@@ -42,14 +42,34 @@ def test_class_jax_tf():
42
42
otda .fit (Xs = Xs , ys = ys , Xt = Xt )
43
43
44
44
45
+ @pytest .skip_backend ("jax" )
46
+ @pytest .skip_backend ("tf" )
47
+ @pytest .mark .parametrize ("class_to_test" , [ot .da .EMDTransport , ot .da .SinkhornTransport , ot .da .SinkhornLpl1Transport , ot .da .SinkhornL1l2Transport , ot .da .SinkhornL1l2Transport ])
48
+ def test_log_da (nx , class_to_test ):
49
+
50
+ ns = 50
51
+ nt = 50
52
+
53
+ Xs , ys = make_data_classif ('3gauss' , ns )
54
+ Xt , yt = make_data_classif ('3gauss2' , nt )
55
+
56
+ Xs , ys , Xt , yt = nx .from_numpy (Xs , ys , Xt , yt )
57
+
58
+ otda = class_to_test (log = True )
59
+
60
+ # test its computed
61
+ otda .fit (Xs = Xs , ys = ys , Xt = Xt )
62
+ assert hasattr (otda , "log_" )
63
+
64
+
45
65
@pytest .skip_backend ("jax" )
46
66
@pytest .skip_backend ("tf" )
47
67
def test_sinkhorn_lpl1_transport_class (nx ):
48
68
"""test_sinkhorn_transport
49
69
"""
50
70
51
- ns = 150
52
- nt = 200
71
+ ns = 50
72
+ nt = 50
53
73
54
74
Xs , ys = make_data_classif ('3gauss' , ns )
55
75
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
136
156
"""
137
157
138
158
ns = 50
139
- nt = 100
159
+ nt = 50
140
160
141
161
Xs , ys = make_data_classif ('3gauss' , ns )
142
162
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
230
250
"""test_sinkhorn_transport
231
251
"""
232
252
233
- ns = 150
234
- nt = 200
253
+ ns = 50
254
+ nt = 50
235
255
236
256
Xs , ys = make_data_classif ('3gauss' , ns )
237
257
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
323
343
"""test_sinkhorn_transport
324
344
"""
325
345
326
- ns = 150
327
- nt = 200
346
+ ns = 50
347
+ nt = 50
328
348
329
349
Xs , ys = make_data_classif ('3gauss' , ns )
330
350
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
402
422
"""test_sinkhorn_transport
403
423
"""
404
424
405
- ns = 150
406
- nt = 200
425
+ ns = 50
426
+ nt = 50
407
427
408
428
Xs , ys = make_data_classif ('3gauss' , ns )
409
429
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
558
578
@pytest .skip_backend ("jax" )
559
579
@pytest .skip_backend ("tf" )
560
580
def test_linear_mapping (nx ):
561
- ns = 150
562
- nt = 200
581
+ ns = 50
582
+ nt = 50
563
583
564
584
Xs , ys = make_data_classif ('3gauss' , ns )
565
585
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -579,8 +599,8 @@ def test_linear_mapping(nx):
579
599
@pytest .skip_backend ("jax" )
580
600
@pytest .skip_backend ("tf" )
581
601
def test_linear_mapping_class (nx ):
582
- ns = 150
583
- nt = 200
602
+ ns = 50
603
+ nt = 50
584
604
585
605
Xs , ys = make_data_classif ('3gauss' , ns )
586
606
Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
609
629
"""test_jcpot_transport
610
630
"""
611
631
612
- ns1 = 150
613
- ns2 = 150
614
- nt = 200
632
+ ns1 = 50
633
+ ns2 = 50
634
+ nt = 50
615
635
616
636
Xs1 , ys1 = make_data_classif ('3gauss' , ns1 )
617
637
Xs2 , ys2 = make_data_classif ('3gauss' , ns2 )
@@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
681
701
"""test_jcpot_barycenter
682
702
"""
683
703
684
- ns1 = 150
685
- ns2 = 150
686
- nt = 200
704
+ ns1 = 50
705
+ ns2 = 50
706
+ nt = 50
687
707
688
708
sigma = 0.1
689
709
np .random .seed (1985 )
@@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
713
733
def test_emd_laplace_class (nx ):
714
734
"""test_emd_laplace_transport
715
735
"""
716
- ns = 150
717
- nt = 200
736
+ ns = 50
737
+ nt = 50
718
738
719
739
Xs , ys = make_data_classif ('3gauss' , ns )
720
740
Xt , yt = make_data_classif ('3gauss2' , nt )
0 commit comments