Skip to content

Commit 7950b11

Browse files
authored
Fix DA cost correction when cost limit is set to Inf (#593)
* Introduce the test that actually fails with cost = nan * Update cost correction algorithm * Better explanation in the code for how missing_labels and label_match interact * All close to check semi-supervised mode * Update RELEASE file * Suppress runtime warning from numpy about using Inf in the multiplication
1 parent 98a58d2 commit 7950b11

File tree

3 files changed

+51
-11
lines changed

3 files changed

+51
-11
lines changed

RELEASES.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Releases
22

3+
## 0.9.3
4+
5+
#### Closed issues
6+
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced in PR #587 (PR #593)
7+
8+
39
## 0.9.2
410
*December 2023*
511

ot/da.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# License: MIT License
1414

1515
import numpy as np
16+
import warnings
1617

1718
from .backend import get_backend
1819
from .bregman import sinkhorn, jcpot_barycenter
@@ -499,12 +500,27 @@ class label
499500
if self.limit_max != np.infty:
500501
self.limit_max = self.limit_max * nx.max(self.cost_)
501502

502-
# zeros where source label is missing (masked with -1)
503-
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
504-
missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1)
505-
# zeros where labels match
506-
label_match = ys[:, None] - yt[None, :]
507-
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)
503+
# missing_labels is a (ns, nt) matrix of {0, 1} such that
504+
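# the cell (i, j) is 0 iff either ys[i] or yt[j] is masked
# NOTE(review): as written, (y == -1) marks masked entries with 1, so the product below is 1 only where both ys[i] and yt[j] are masked -- confirm intent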
505+
missing_ys = (ys == -1) + nx.zeros(ys.shape, type_as=ys)
506+
missing_yt = (yt == -1) + nx.zeros(yt.shape, type_as=yt)
507+
missing_labels = missing_ys[:, None] @ missing_yt[None, :]
508+
# labels_match is a (ns, nt) matrix of {True, False} such that
509+
# the cell (i, j) is True iff ys[i] != yt[j]
510+
label_match = (ys[:, None] - yt[None, :]) != 0
511+
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
512+
# that the cell (i, j) is -Inf where no correction is necessary;
513+
# by 'correction' we mean setting cost to a large value when
514+
# labels do not match
515+
# we suppress potential RuntimeWarning caused by Inf multiplication
516+
# (as we explicitly cover potential NANs later)
517+
with warnings.catch_warnings():
518+
warnings.simplefilter('ignore', category=RuntimeWarning)
519+
cost_correction = label_match * missing_labels * self.limit_max
520+
# this operation is necessary because 0 * Inf = NAN
521+
# and is therefore irrelevant when limit_max is finite
522+
cost_correction = nx.nan_to_num(cost_correction, -np.infty)
523+
self.cost_ = nx.maximum(self.cost_, cost_correction)
508524

509525
# distribution estimation
510526
self.mu_s = self.distribution_estimation(Xs)

test/test_da.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx):
8989
# test its computed
9090
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
9191
assert hasattr(otda, "cost_")
92+
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
9293
assert hasattr(otda, "coupling_")
94+
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"
9395

9496
# test dimensions of coupling
9597
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
148150
n_semisup = nx.sum(otda_semi.cost_)
149151

150152
# check that the cost matrix norms are indeed different
151-
assert n_unsup != n_semisup, "semisupervised mode not working"
153+
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"
152154

153155
# check that the coupling forbids mass transport between labeled source
154156
# and labeled target samples
@@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
238240
n_semisup = nx.sum(otda_semi.cost_)
239241

240242
# check that the cost matrix norms are indeed different
241-
assert n_unsup != n_semisup, "semisupervised mode not working"
243+
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"
242244

243245
# check that the coupling forbids mass transport between labeled source
244246
# and labeled target samples
@@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx):
331333
n_semisup = nx.sum(otda_semi.cost_)
332334

333335
# check that the cost matrix norms are indeed different
334-
assert n_unsup != n_semisup, "semisupervised mode not working"
336+
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"
335337

336338
# check that the coupling forbids mass transport between labeled source
337339
# and labeled target samples
@@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx):
371373
# test dimensions of coupling
372374
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
373375
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
376+
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
377+
378+
# test coupling
379+
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"
374380

375381
# test transform
376382
transp_Xs = otda.transform(Xs=Xs)
@@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx):
409415
# test unsupervised vs semi-supervised mode
410416
otda_unsup = ot.da.SinkhornTransport()
411417
otda_unsup.fit(Xs=Xs, Xt=Xt)
418+
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
412419
n_unsup = nx.sum(otda_unsup.cost_)
413420

414421
otda_semi = ot.da.SinkhornTransport()
415422
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
423+
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
416424
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
417425
n_semisup = nx.sum(otda_semi.cost_)
418426

419427
# check that the cost matrix norms are indeed different
420-
assert n_unsup != n_semisup, "semisupervised mode not working"
428+
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"
421429

422430
# check everything runs well with log=True
423431
otda = ot.da.SinkhornTransport(log=True)
424432
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
433+
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
425434
assert len(otda.log_.keys()) != 0
426435

427436

@@ -448,7 +457,9 @@ def test_emd_transport_class(nx):
448457

449458
# test dimensions of coupling
450459
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
460+
assert not np.any(np.isnan(nx.to_numpy(otda.cost_))), "cost is finite"
451461
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
462+
assert np.all(np.isfinite(nx.to_numpy(otda.coupling_))), "coupling is finite"
452463

453464
# test margin constraints
454465
mu_s = unif(ns)
@@ -495,15 +506,22 @@ def test_emd_transport_class(nx):
495506
# test unsupervised vs semi-supervised mode
496507
otda_unsup = ot.da.EMDTransport()
497508
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
509+
assert_equal(otda_unsup.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
510+
assert not np.any(np.isnan(nx.to_numpy(otda_unsup.cost_))), "cost is finite"
511+
assert_equal(otda_unsup.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
512+
assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "coupling is finite"
498513
n_unsup = nx.sum(otda_unsup.cost_)
499514

500515
otda_semi = ot.da.EMDTransport()
501516
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
502517
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
518+
assert not np.any(np.isnan(nx.to_numpy(otda_semi.cost_))), "cost is finite"
519+
assert_equal(otda_semi.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
520+
assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "coupling is finite"
503521
n_semisup = nx.sum(otda_semi.cost_)
504522

505523
# check that the cost matrix norms are indeed different
506-
assert n_unsup != n_semisup, "semisupervised mode not working"
524+
assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working"
507525

508526
# check that the coupling forbids mass transport between labeled source
509527
# and labeled target samples

0 commit comments

Comments
 (0)