@@ -245,6 +245,79 @@ def test_sinkhorn_transport_class():
245
245
assert len (otda .log_ .keys ()) != 0
246
246
247
247
248
+ def test_unbalanced_sinkhorn_transport_class ():
249
+ """test_sinkhorn_transport
250
+ """
251
+
252
+ ns = 150
253
+ nt = 200
254
+
255
+ Xs , ys = make_data_classif ('3gauss' , ns )
256
+ Xt , yt = make_data_classif ('3gauss2' , nt )
257
+
258
+ otda = ot .da .UnbalancedSinkhornTransport ()
259
+
260
+ # test its computed
261
+ otda .fit (Xs = Xs , Xt = Xt )
262
+ assert hasattr (otda , "cost_" )
263
+ assert hasattr (otda , "coupling_" )
264
+ assert hasattr (otda , "log_" )
265
+
266
+ # test dimensions of coupling
267
+ assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
268
+ assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
269
+
270
+ # test margin constraints
271
+ mu_s = unif (ns )
272
+ mu_t = unif (nt )
273
+ assert_allclose (
274
+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
275
+ assert_allclose (
276
+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
277
+
278
+ # test transform
279
+ transp_Xs = otda .transform (Xs = Xs )
280
+ assert_equal (transp_Xs .shape , Xs .shape )
281
+
282
+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
283
+ transp_Xs_new = otda .transform (Xs_new )
284
+
285
+ # check that the oos method is working
286
+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
287
+
288
+ # test inverse transform
289
+ transp_Xt = otda .inverse_transform (Xt = Xt )
290
+ assert_equal (transp_Xt .shape , Xt .shape )
291
+
292
+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
293
+ transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
294
+
295
+ # check that the oos method is working
296
+ assert_equal (transp_Xt_new .shape , Xt_new .shape )
297
+
298
+ # test fit_transform
299
+ transp_Xs = otda .fit_transform (Xs = Xs , Xt = Xt )
300
+ assert_equal (transp_Xs .shape , Xs .shape )
301
+
302
+ # test unsupervised vs semi-supervised mode
303
+ otda_unsup = ot .da .SinkhornTransport ()
304
+ otda_unsup .fit (Xs = Xs , Xt = Xt )
305
+ n_unsup = np .sum (otda_unsup .cost_ )
306
+
307
+ otda_semi = ot .da .SinkhornTransport ()
308
+ otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
309
+ assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
310
+ n_semisup = np .sum (otda_semi .cost_ )
311
+
312
+ # check that the cost matrix norms are indeed different
313
+ assert n_unsup != n_semisup , "semisupervised mode not working"
314
+
315
+ # check everything runs well with log=True
316
+ otda = ot .da .SinkhornTransport (log = True )
317
+ otda .fit (Xs = Xs , ys = ys , Xt = Xt )
318
+ assert len (otda .log_ .keys ()) != 0
319
+
320
+
248
321
def test_emd_transport_class ():
249
322
"""test_sinkhorn_transport
250
323
"""
0 commit comments