
Commit 0eac835

[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
Squashed commit messages:

* First draft: making pytest use gpu for torch testing
* bug solve
* Revert "bug solve" (reverts commit 29b013a)
* Revert "First draft: making pytest use gpu for torch testing" (reverts commit 2778175)
* sliced
* sliced
* ot 1dsolver
* bregman
* better print
* jax works with sinkhorn, sinkhorn_log and sinkhorn_stabilized, no need to skip them
* gromov & entropic gromov
1 parent 0e431c2 commit 0eac835

File tree

8 files changed
+247 -33 lines changed

ot/backend.py

Lines changed: 52 additions & 7 deletions
```diff
@@ -653,6 +653,18 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         """
         raise NotImplementedError()
 
+    def dtype_device(self, a):
+        r"""
+        Returns the dtype and the device of the given tensor.
+        """
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        r"""
+        Checks whether or not the two given inputs have the same dtype as well as the same device
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -880,6 +892,16 @@ def copy(self, a):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        if hasattr(a, "dtype"):
+            return a.dtype, "cpu"
+        else:
+            return type(a), "cpu"
+
+    def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion so we automatically validate the test
+        pass
+
 
 class JaxBackend(Backend):
     """
@@ -899,17 +921,20 @@ def __init__(self):
         self.rng_ = jax.random.PRNGKey(42)
 
         for d in jax.devices():
-            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
-                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
 
     def to_numpy(self, a):
         return np.array(a)
 
+    def _change_device(self, a, type_as):
+        return jax.device_put(a, type_as.device_buffer.device())
+
     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
+            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
@@ -928,13 +953,13 @@ def zeros(self, shape, type_as=None):
         if type_as is None:
             return jnp.zeros(shape)
         else:
-            return jnp.zeros(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
             return jnp.ones(shape)
         else:
-            return jnp.ones(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
 
     def arange(self, stop, start=0, step=1, type_as=None):
         return jnp.arange(start, stop, step)
@@ -943,13 +968,13 @@ def full(self, shape, fill_value, type_as=None):
         if type_as is None:
             return jnp.full(shape, fill_value)
         else:
-            return jnp.full(shape, fill_value, dtype=type_as.dtype)
+            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
 
     def eye(self, N, M=None, type_as=None):
         if type_as is None:
             return jnp.eye(N, M)
         else:
-            return jnp.eye(N, M, dtype=type_as.dtype)
+            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
 
     def sum(self, a, axis=None, keepdims=False):
         return jnp.sum(a, axis, keepdims=keepdims)
@@ -1127,6 +1152,16 @@ def copy(self, a):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        return a.dtype, a.device_buffer.device()
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
 
 class TorchBackend(Backend):
     """
@@ -1455,3 +1490,13 @@ def copy(self, a):
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
```

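Taken together, these hooks give every backend a uniform way to report and compare tensor metadata. A minimal usage sketch, assuming a POT build that includes this commit (the exact device repr differs per backend):

```python
import numpy as np
from ot.backend import get_backend_list

# Iterate over whichever backends (numpy, jax, torch) are installed.
# dtype_device() returns a (dtype, device) pair; assert_same_dtype_device()
# raises AssertionError on mismatch and is a deliberate no-op for numpy,
# which converts dtypes implicitly.
for nx in get_backend_list():
    a = nx.from_numpy(np.ones(3, dtype=np.float32))
    b = nx.full(3, 1 / 3, type_as=a)  # b inherits a's dtype and device
    print(str(nx), nx.dtype_device(a))
    nx.assert_same_dtype_device(a, b)
```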
ot/sliced.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
                                                                  X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]
 
@@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
                                                                  X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]
```

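The `type_as=X_s` additions are what the new tests exercise: the implicit uniform weights are now allocated with the dtype and device of the input samples rather than the backend defaults. A sketch of the behaviour under test, assuming a CUDA-enabled torch install:

```python
import torch
import ot

# Float32 GPU inputs; with this commit the implicit uniform weights are
# created via nx.full(..., type_as=X_s), so the whole computation stays
# float32 on the GPU instead of mixing in float64 CPU weights.
xs = torch.randn(50, 2, dtype=torch.float32, device="cuda")
xt = torch.randn(40, 2, dtype=torch.float32, device="cuda")

res = ot.sliced_wasserstein_distance(xs, xt, n_projections=10, seed=0)
assert res.dtype == torch.float32 and res.device.type == "cuda"
```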
test/conftest.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -11,31 +11,44 @@
 
 if jax:
     from jax.config import config
+    config.update("jax_enable_x64", True)
 
 backend_list = get_backend_list()
 
 
 @pytest.fixture(params=backend_list)
 def nx(request):
     backend = request.param
-    if backend.__name__ == "jax":
-        config.update("jax_enable_x64", True)
 
     yield backend
 
-    if backend.__name__ == "jax":
-        config.update("jax_enable_x64", False)
-
 
 def skip_arg(arg, value, reason=None, getter=lambda x: x):
+    if isinstance(arg, tuple) or isinstance(arg, list):
+        n = len(arg)
+    else:
+        arg = (arg, )
+        n = 1
+    if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+        pass
+    else:
+        value = (value, )
+    if isinstance(getter, tuple) or isinstance(value, list):
+        pass
+    else:
+        getter = [getter] * n
+
     if reason is None:
         reason = f"Param {arg} should be skipped for value {value}"
 
     def wrapper(function):
 
         @functools.wraps(function)
         def wrapped(*args, **kwargs):
-            if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+            if all(
+                arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+                for i in range(n)
+            ):
                 pytest.skip(reason)
             return function(*args, **kwargs)
```

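`skip_arg` used to match a single keyword argument; after this change it also accepts tuples of argument names and values, skipping a test only when every pair matches. (The `isinstance(value, list)` check in the getter branch looks like it was meant to be `isinstance(getter, list)`; since the tests pass a single callable, the fallback branch still behaves correctly.) A usage sketch mirroring the decorators added in test_bregman.py, with a hypothetical test name:

```python
import pytest  # POT's conftest attaches skip_arg to the pytest namespace

# Skip only when the backend fixture stringifies to "jax" AND the
# parametrized method equals "greenkhorn"; getter=str converts each
# fixture value before comparison.
@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"),
                 reason="jax does not support greenkhorn", getter=str)
def test_plan_matches_reference(nx, method):  # hypothetical test name
    ...
```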
test/test_1d_solver.py

Lines changed: 5 additions & 11 deletions
```diff
@@ -85,7 +85,6 @@ def test_wasserstein_1d(nx):
     np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
 
 
-@pytest.mark.parametrize('nx', backend_list)
 def test_wasserstein_1d_type_devices(nx):
 
     rng = np.random.RandomState(0)
@@ -98,17 +97,15 @@ def test_wasserstein_1d_type_devices(nx):
     rho_v /= rho_v.sum()
 
     for tp in nx.__type_list__:
-
-        print(tp.dtype)
+        print(nx.dtype_device(tp))
 
         xb = nx.from_numpy(x, type_as=tp)
         rho_ub = nx.from_numpy(rho_u, type_as=tp)
         rho_vb = nx.from_numpy(rho_v, type_as=tp)
 
         res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
 
-        if not str(nx) == 'numpy':
-            assert res.dtype == xb.dtype
+        nx.assert_same_dtype_device(xb, res)
 
 
 def test_emd_1d_emd2_1d():
@@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx):
     rho_v /= rho_v.sum()
 
     for tp in nx.__type_list__:
-
-        print(tp.dtype)
+        print(nx.dtype_device(tp))
 
         xb = nx.from_numpy(x, type_as=tp)
         rho_ub = nx.from_numpy(rho_u, type_as=tp)
         rho_vb = nx.from_numpy(rho_v, type_as=tp)
 
         emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
-
         emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
 
-        assert emd.dtype == xb.dtype
-        if not str(nx) == 'numpy':
-            assert emd2.dtype == xb.dtype
+        nx.assert_same_dtype_device(xb, emd)
+        nx.assert_same_dtype_device(xb, emd2)
```

test/test_bregman.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx):
     np.testing.assert_allclose(G0, G_green, atol=1e-5)
 
 
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+                                    "sinkhorn_epsilon_scaling",
+                                    "greenkhorn",
+                                    "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, lossb)
+
+
 @pytest.skip_backend("jax")
 def test_sinkhorn_variants_multi_b(nx):
     # test sinkhorn
```

test/test_gromov.py

Lines changed: 75 additions & 0 deletions
```diff
@@ -75,6 +75,41 @@ def test_gromov(nx):
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
 
+def test_gromov_dtype_device(nx):
+    # setup
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        C1b = nx.from_numpy(C1, type_as=tp)
+        C2b = nx.from_numpy(C2, type_as=tp)
+        pb = nx.from_numpy(p, type_as=tp)
+        qb = nx.from_numpy(q, type_as=tp)
+
+        Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+        gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+
+
 def test_gromov2_gradients():
     n_samples = 50  # nb samples
 
@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
 
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+    # setup
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        C1b = nx.from_numpy(C1, type_as=tp)
+        C2b = nx.from_numpy(C2, type_as=tp)
+        pb = nx.from_numpy(p, type_as=tp)
+        qb = nx.from_numpy(q, type_as=tp)
+
+        Gb = ot.gromov.entropic_gromov_wasserstein(
+            C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+        )
+        gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+            C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+        )
+
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+
+
 def test_pointwise_gromov(nx):
     n_samples = 50  # nb samples
```

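To run just the new checks locally, the usual pytest keyword filter applies, e.g. `pytest test/ -k dtype_device`. On a CPU-only machine the loops cover only the CPU entries of each backend's `__type_list__`; with CUDA available, the torch (and jax) backends also contribute GPU tensors, which is where `assert_same_dtype_device` earns its keep.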