Commit 9ddb690

kachayev and rflamary authored
[DA] Sinkhorn L1L2 transport to work on JAX (#587)
[DA] Sinkhorn L1L2 transport to work on JAX (#587)

* Draft sinkhorn_l1l2_transport to work on JAX
* Move label_to_masks in utils
* Move nan_to_num to backend
* Proper test case for semi-supervised DA
* Test case for label to mask computation
* Simplified axis operations for labels
* Allow JAX backend for BaseEstimator
* Label normalization performs copy only when necessary
* Fix comment regarding label transformation
* Update RELEASES
* Additional backend tests for nan_to_num
* min(unique(y)) === min(y)
* Avoid catching all warnings as JAX throws deprecation
* No need to import warnings module

---------

Co-authored-by: Rémi Flamary <[email protected]>
1 parent acd84ed commit 9ddb690

File tree

7 files changed: +131 −70 lines changed

RELEASES.md

+5
@@ -1,5 +1,10 @@
 # Releases
 
+## Next Release
+
+#### New features
++ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587)
+
 ## 0.9.2dev
 
 #### New features
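
Since this is the user-facing change, here is a minimal sketch of the newly supported path (assumes jax is installed; shapes and data are illustrative):

import numpy as np
import jax.numpy as jnp
import ot

rng = np.random.RandomState(0)
Xs = jnp.array(rng.randn(20, 2))       # source samples
ys = jnp.array(rng.randint(0, 3, 20))  # source labels
Xt = jnp.array(rng.randn(30, 2))       # target samples

# before this commit, passing JAX arrays raised a TypeError in
# BaseEstimator._get_backend; now the estimator runs end to end
otda = ot.da.SinkhornL1l2Transport()
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
transp_Xs = otda.transform(Xs=Xs)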

ot/backend.py

+27
@@ -1043,6 +1043,14 @@ def matmul(self, a, b):
         """
         raise NotImplementedError()
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        r"""
+        Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -1392,6 +1400,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return np.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 _register_backend_implementation(NumpyBackend)
@@ -1762,6 +1773,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return jnp.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 if jax:
     # Only register jax backend if it is installed
@@ -2250,6 +2264,10 @@ def detach(self, *args):
     def matmul(self, a, b):
         return torch.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        out = None if copy else x
+        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)
+
 
 if torch:
     # Only register torch backend if it is installed
@@ -2647,6 +2665,9 @@ def detach(self, *args):
     def matmul(self, a, b):
         return cp.matmul(a, b)
 
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+
 
 if cp:
     # Only register cp backend if it is installed
@@ -3070,6 +3091,12 @@ def detach(self, *args):
     def matmul(self, a, b):
         return tnp.matmul(a, b)
 
+    # todo(okachaiev): replace this with a more reasonable implementation
+    def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
+        x = self.to_numpy(x)
+        x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
+        return self.from_numpy(x)
+
 
 if tf:
     # Only register tensorflow backend if it is installed
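
The new primitive exists because JAX arrays are immutable: the previous in-place pattern transp[~nx.isfinite(transp)] = 0 cannot work there. A short backend-agnostic sketch of the replacement (NumPy shown, but any registered backend behaves the same):

import numpy as np
import ot

M = np.array([[0.5, np.nan], [np.inf, 0.5]])
nx = ot.backend.get_backend(M)
# replaces NaN/inf without mutating the input, so it is safe on JAX too
M = nx.nan_to_num(M, nan=0, posinf=0, neginf=0)
print(M)  # [[0.5 0. ] [0.  0.5]]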

ot/da.py

+24-49
@@ -18,7 +18,7 @@
 from .bregman import sinkhorn, jcpot_barycenter
 from .lp import emd
 from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator, deprecated
+from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array
 from .unbalanced import sinkhorn_unbalanced
 from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
 from .optim import cg
@@ -499,18 +499,12 @@ class label
             if self.limit_max != np.infty:
                 self.limit_max = self.limit_max * nx.max(self.cost_)
 
-            # assumes labeled source samples occupy the first rows
-            # and labeled target samples occupy the first columns
-            classes = [c for c in nx.unique(ys) if c != -1]
-            for c in classes:
-                idx_s = nx.where((ys != c) & (ys != -1))
-                idx_t = nx.where(yt == c)
-
-                # all the coefficients corresponding to a source sample
-                # and a target sample :
-                # with different labels get a infinite
-                for j in idx_t[0]:
-                    self.cost_[idx_s[0], j] = self.limit_max
+            # zeros where source label is missing (masked with -1)
+            missing_labels = ys + nx.ones(ys.shape, type_as=ys)
+            missing_labels = nx.repeat(missing_labels[:, None], ys.shape[0], 1)
+            # zeros where labels match
+            label_match = ys[:, None] - yt[None, :]
+            self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)
 
         # distribution estimation
         self.mu_s = self.distribution_estimation(Xs)
@@ -581,12 +575,11 @@ class label
         if check_params(Xs=Xs):
 
             if nx.array_equal(self.xs_, Xs):
-
                 # perform standard barycentric mapping
                 transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
 
                 # set nans to 0
-                transp[~ nx.isfinite(transp)] = 0
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xs = nx.dot(transp, self.xt_)
@@ -604,9 +597,8 @@ class label
                 idx = nx.argmin(D0, axis=1)
 
                 # transport the source samples
-                transp = self.coupling_ / nx.sum(
-                    self.coupling_, axis=1)[:, None]
-                transp[~ nx.isfinite(transp)] = 0
+                transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
+                transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
                 transp_Xs_ = nx.dot(transp, self.xt_)
 
                 # define the transported points
@@ -645,23 +637,16 @@ def transform_labels(self, ys=None):
 
         # check the necessary inputs parameters are here
         if check_params(ys=ys):
-
-            ysTemp = label_normalization(nx.copy(ys))
-            classes = nx.unique(ysTemp)
-            n = len(classes)
-            D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
-
             # perform label propagation
             transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
 
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
-
-            for c in classes:
-                D1[int(c), ysTemp == c] = 1
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
             # compute propagated labels
-            transp_ys = nx.dot(D1, transp)
+            labels = label_normalization(ys)
+            masks = labels_to_masks(labels, nx=nx, type_as=transp)
+            transp_ys = nx.dot(masks.T, transp)
 
             return transp_ys.T
@@ -697,12 +682,11 @@ class label
         if check_params(Xt=Xt):
 
             if nx.array_equal(self.xt_, Xt):
-
                 # perform standard barycentric mapping
                 transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
 
                 # set nans to 0
-                transp_[~ nx.isfinite(transp_)] = 0
+                transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
 
                 # compute transported samples
                 transp_Xt = nx.dot(transp_, self.xs_)
@@ -719,9 +703,8 @@ class label
                 idx = nx.argmin(D0, axis=1)
 
                 # transport the target samples
-                transp_ = self.coupling_.T / nx.sum(
-                    self.coupling_, 0)[:, None]
-                transp_[~ nx.isfinite(transp_)] = 0
+                transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
+                transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
                 transp_Xt_ = nx.dot(transp_, self.xs_)
 
                 # define the transported points
@@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):
 
         # check the necessary inputs parameters are here
         if check_params(yt=yt):
-
-            ytTemp = label_normalization(nx.copy(yt))
-            classes = nx.unique(ytTemp)
-            n = len(classes)
-            D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
-
             # perform label propagation
             transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
-
             # set nans to 0
-            transp[~ nx.isfinite(transp)] = 0
+            transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
 
-            for c in classes:
-                D1[int(c), ytTemp == c] = 1
-
-            # compute propagated samples
-            transp_ys = nx.dot(D1, transp.T)
+            # compute propagated labels
+            labels = label_normalization(yt)
+            masks = labels_to_masks(labels, nx=nx, type_as=transp)
+            transp_ys = nx.dot(masks.T, transp.T)
 
             return transp_ys.T
@@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
                 type_as=ys[0]
             )
             for i in range(len(ys)):
-                ysTemp = label_normalization(nx.copy(ys[i]))
+                ysTemp = label_normalization(ys[i])
                 classes = nx.unique(ysTemp)
                 n = len(classes)
                 ns = len(ysTemp)
@@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
         # check the necessary inputs parameters are here
         if check_params(yt=yt):
             transp_ys = []
-            ytTemp = label_normalization(nx.copy(yt))
+            ytTemp = label_normalization(yt)
             classes = nx.unique(ytTemp)
             n = len(classes)
             D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
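
The vectorized block that replaces the double loop in fit() is compact; here is a plain NumPy sketch with toy labels (hypothetical data; the repeat count is written with yt.shape[0], which equals ys.shape[0] at the committed call site):

import numpy as np

ys = np.array([0, 1, -1])  # -1 marks an unlabeled source sample
yt = np.array([0, 1, 1])
cost = np.zeros((3, 3))
limit_max = 10.0

# zeros exactly where the source label is -1
missing_labels = np.abs(np.repeat((ys + 1)[:, None], yt.shape[0], 1))
# zeros exactly where source and target labels match
label_match = np.abs(ys[:, None] - yt[None, :])
cost = np.maximum(cost, label_match * missing_labels * limit_max)
print(cost)
# [[ 0. 10. 10.]
#  [20.  0.  0.]
#  [ 0.  0.  0.]]

Entries larger than limit_max (like the 20 above) only make an already forbidden move more expensive, while rows for unlabeled source samples stay untouched.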

ot/utils.py

+35-10
@@ -390,7 +390,7 @@ def is_all_finite(*args):
     return all(not nx.any(~nx.isfinite(arg)) for arg in args)
 
 
-def label_normalization(y, start=0):
+def label_normalization(y, start=0, nx=None):
     r""" Transform labels to start at a given value
 
     Parameters
@@ -399,18 +399,45 @@ def label_normalization(y, start=0):
         The vector of labels to be normalized.
     start : int
         Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
+    nx : Backend, optional
+        Backend to perform computations on. If omitted, the backend defaults to that of `y`.
 
     Returns
     -------
     y : array-like, shape (`n1`, )
         The input vector of labels normalized according to given start value.
     """
-    nx = get_backend(y)
+    if nx is None:
+        nx = get_backend(y)
+    diff = nx.min(y) - start
+    return y if diff == 0 else (y - diff)
+
+
+def labels_to_masks(y, type_as=None, nx=None):
+    r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.
+
+    Parameters
+    ----------
+    y : array-like, shape (n_samples, )
+        The vector of labels.
+    type_as : array_like
+        Array of the same type of the expected output.
+    nx : Backend, optional
+        Backend to perform computations on. If omitted, the backend defaults to that of `y`.
 
-    diff = nx.min(nx.unique(y)) - start
-    if diff != 0:
-        y -= diff
-    return y
+    Returns
+    -------
+    masks : array-like, shape (n_samples, n_labels)
+        The (n_samples, n_labels) matrix of label masks.
+    """
+    if nx is None:
+        nx = get_backend(y)
+    if type_as is None:
+        type_as = y
+    labels_u, labels_idx = nx.unique(y, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    masks = nx.eye(n_labels, type_as=type_as)[labels_idx]
+    return masks
 
 
 def parmap(f, X, nprocs="default"):
@@ -755,10 +782,8 @@ def _get_backend(self, *arrays):
         nx = get_backend(
             *[input_ for input_ in arrays if input_ is not None]
         )
-        if nx.__name__ in ("jax", "tf"):
-            raise TypeError(
-                """JAX or TF arrays have been received but domain
-                adaptation does not support those backend.""")
+        if nx.__name__ in ("tf",):
+            raise TypeError("Domain adaptation does not support TF backend.")
         self.nx = nx
         return nx
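
A quick sketch of the two helpers on the default NumPy backend (values chosen arbitrarily):

import numpy as np
from ot.utils import label_normalization, labels_to_masks

y = np.array([2, 5, 2, 9])
print(label_normalization(y))  # [0 3 0 7] -- shifted so the smallest label is 0
print(labels_to_masks(y))      # one-hot rows, one column per distinct label
# [[1 0 0]
#  [0 1 0]
#  [1 0 0]
#  [0 0 1]]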

test/test_backend.py

+7
@@ -264,6 +264,8 @@ def test_empty_backend():
         nx.detach(M)
     with pytest.raises(NotImplementedError):
         nx.matmul(M, M.T)
+    with pytest.raises(NotImplementedError):
+        nx.nan_to_num(M)
 
 
 def test_func_backends(nx):
@@ -667,6 +669,11 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append("matmul broadcast")
 
+    vec = nx.from_numpy(np.array([1, np.nan, -1]))
+    vec = nx.nan_to_num(vec, nan=0)
+    lst_b.append(nx.to_numpy(vec))
+    lst_name.append("nan_to_num")
+
     assert not nx.array_equal(Mb, vb), "array_equal (shape)"
     assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
     assert not nx.array_equal(
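
A standalone parity check in the spirit of the harness above, for the NumPy backend only:

import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
vec = np.array([1.0, np.nan, -1.0])
# the backend method should agree with the NumPy reference implementation
assert np.allclose(nx.nan_to_num(vec, nan=0), np.nan_to_num(vec, nan=0))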

test/test_da.py

+8-11
@@ -7,7 +7,6 @@
 import numpy as np
 from numpy.testing import assert_allclose, assert_equal
 import pytest
-import warnings
 
 import ot
 from ot.datasets import make_data_classif
@@ -158,7 +157,6 @@ def test_sinkhorn_lpl1_transport_class(nx):
     assert mass_semi == 0, "semisupervised mode not working"
 
 
-@pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_l1l2_transport_class(nx):
     """test_sinkhorn_transport
@@ -169,15 +167,16 @@ def test_sinkhorn_l1l2_transport_class(nx):
 
     Xs, ys = make_data_classif('3gauss', ns, random_state=42)
     Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
+    # prepare semi-supervised labels
+    yt_semi = np.copy(yt)
+    yt_semi[np.arange(0, nt, 2)] = -1
 
-    Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+    Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)
 
     otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
+    otda.fit(Xs=Xs, ys=ys, Xt=Xt)
 
     # test its computed
-    with warnings.catch_warnings():
-        warnings.simplefilter("error")
-        otda.fit(Xs=Xs, ys=ys, Xt=Xt)
     assert hasattr(otda, "cost_")
     assert hasattr(otda, "coupling_")
     assert hasattr(otda, "log_")
@@ -234,7 +233,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
     n_unsup = nx.sum(otda_unsup.cost_)
 
     otda_semi = ot.da.SinkhornL1l2Transport()
-    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
     assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
     n_semisup = nx.sum(otda_semi.cost_)
 
@@ -243,11 +242,9 @@ def test_sinkhorn_l1l2_transport_class(nx):
 
     # check that the coupling forbids mass transport between labeled source
     # and labeled target samples
-    mass_semi = nx.sum(
-        otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+    mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
     mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
-    assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
-                    rtol=1e-9, atol=1e-9)
+    assert_allclose(nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9)
 
     # check everything runs well with log=True
     otda = ot.da.SinkhornL1l2Transport(log=True)
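
Outside the test harness, the semi-supervised contract checked by the assertions above can be demonstrated directly (a sketch; sizes kept small so the solver converges quickly):

import numpy as np
import ot
from ot.datasets import make_data_classif

ns, nt = 50, 50
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)

# hide every other target label: -1 is the "unlabeled" marker
yt_semi = np.copy(yt)
yt_semi[np.arange(0, nt, 2)] = -1

otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)

# no mass moves between a labeled source sample and a target sample
# carrying a different known label
mass = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
assert np.allclose(mass, 0, atol=1e-9)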
