@@ -36,8 +36,10 @@ def test_sinkhorn_lpl1_transport_class():
36
36
# test margin constraints
37
37
mu_s = unif (ns )
38
38
mu_t = unif (nt )
39
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
40
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
39
+ assert_allclose (
40
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
41
+ assert_allclose (
42
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
41
43
42
44
# test transform
43
45
transp_Xs = otda .transform (Xs = Xs )
@@ -108,8 +110,10 @@ def test_sinkhorn_l1l2_transport_class():
108
110
# test margin constraints
109
111
mu_s = unif (ns )
110
112
mu_t = unif (nt )
111
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
112
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
113
+ assert_allclose (
114
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
115
+ assert_allclose (
116
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
113
117
114
118
# test transform
115
119
transp_Xs = otda .transform (Xs = Xs )
@@ -187,8 +191,10 @@ def test_sinkhorn_transport_class():
187
191
# test margin constraints
188
192
mu_s = unif (ns )
189
193
mu_t = unif (nt )
190
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
191
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
194
+ assert_allclose (
195
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
196
+ assert_allclose (
197
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
192
198
193
199
# test transform
194
200
transp_Xs = otda .transform (Xs = Xs )
@@ -263,8 +269,10 @@ def test_emd_transport_class():
263
269
# test margin constraints
264
270
mu_s = unif (ns )
265
271
mu_t = unif (nt )
266
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
267
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
272
+ assert_allclose (
273
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
274
+ assert_allclose (
275
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
268
276
269
277
# test transform
270
278
transp_Xs = otda .transform (Xs = Xs )
@@ -342,8 +350,10 @@ def test_mapping_transport_class():
342
350
# test margin constraints
343
351
mu_s = unif (ns )
344
352
mu_t = unif (nt )
345
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
346
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
353
+ assert_allclose (
354
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
355
+ assert_allclose (
356
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
347
357
348
358
# test transform
349
359
transp_Xs = otda .transform (Xs = Xs )
@@ -363,8 +373,10 @@ def test_mapping_transport_class():
363
373
# test margin constraints
364
374
mu_s = unif (ns )
365
375
mu_t = unif (nt )
366
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
367
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
376
+ assert_allclose (
377
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
378
+ assert_allclose (
379
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
368
380
369
381
# test transform
370
382
transp_Xs = otda .transform (Xs = Xs )
@@ -389,8 +401,10 @@ def test_mapping_transport_class():
389
401
# test margin constraints
390
402
mu_s = unif (ns )
391
403
mu_t = unif (nt )
392
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
393
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
404
+ assert_allclose (
405
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
406
+ assert_allclose (
407
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
394
408
395
409
# test transform
396
410
transp_Xs = otda .transform (Xs = Xs )
@@ -410,8 +424,10 @@ def test_mapping_transport_class():
410
424
# test margin constraints
411
425
mu_s = unif (ns )
412
426
mu_t = unif (nt )
413
- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
414
- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
427
+ assert_allclose (
428
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
429
+ assert_allclose (
430
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
415
431
416
432
# test transform
417
433
transp_Xs = otda .transform (Xs = Xs )
@@ -454,7 +470,8 @@ def test_otda():
454
470
da_entrop .interp ()
455
471
da_entrop .predict (xs )
456
472
457
- np .testing .assert_allclose (a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
473
+ np .testing .assert_allclose (
474
+ a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
458
475
np .testing .assert_allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
459
476
460
477
# non-convex Group lasso regularization
@@ -488,13 +505,3 @@ def test_otda():
488
505
da_emd = ot .da .OTDA_mapping_kernel () # init class
489
506
da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
490
507
da_emd .predict (xs ) # interpolation of source samples
491
-
492
-
493
- # if __name__ == "__main__":
494
-
495
- # test_sinkhorn_transport_class()
496
- # test_emd_transport_class()
497
- # test_sinkhorn_l1l2_transport_class()
498
- # test_sinkhorn_lpl1_transport_class()
499
- # test_mapping_transport_class()
500
-
0 commit comments