@@ -264,6 +264,112 @@ def test_emd_transport_class():
264
264
assert n_unsup != n_semisup , "semisupervised mode not working"
265
265
266
266
267
+ def test_mapping_transport_class ():
268
+ """test_mapping_transport
269
+ """
270
+
271
+ ns = 150
272
+ nt = 200
273
+
274
+ Xs , ys = get_data_classif ('3gauss' , ns )
275
+ Xt , yt = get_data_classif ('3gauss2' , nt )
276
+ Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
277
+
278
+ ##########################################################################
279
+ # kernel == linear mapping tests
280
+ ##########################################################################
281
+
282
+ # check computation and dimensions if bias == False
283
+ clf = ot .da .MappingTransport (kernel = "linear" , bias = False )
284
+ clf .fit (Xs = Xs , Xt = Xt )
285
+
286
+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
287
+ assert_equal (clf .mapping_ .shape , ((Xs .shape [1 ], Xt .shape [1 ])))
288
+
289
+ # test margin constraints
290
+ mu_s = unif (ns )
291
+ mu_t = unif (nt )
292
+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
293
+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
294
+
295
+ # test transform
296
+ transp_Xs = clf .transform (Xs = Xs )
297
+ assert_equal (transp_Xs .shape , Xs .shape )
298
+
299
+ transp_Xs_new = clf .transform (Xs_new )
300
+
301
+ # check that the oos method is working
302
+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
303
+
304
+ # check computation and dimensions if bias == True
305
+ clf = ot .da .MappingTransport (kernel = "linear" , bias = True )
306
+ clf .fit (Xs = Xs , Xt = Xt )
307
+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
308
+ assert_equal (clf .mapping_ .shape , ((Xs .shape [1 ] + 1 , Xt .shape [1 ])))
309
+
310
+ # test margin constraints
311
+ mu_s = unif (ns )
312
+ mu_t = unif (nt )
313
+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
314
+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
315
+
316
+ # test transform
317
+ transp_Xs = clf .transform (Xs = Xs )
318
+ assert_equal (transp_Xs .shape , Xs .shape )
319
+
320
+ transp_Xs_new = clf .transform (Xs_new )
321
+
322
+ # check that the oos method is working
323
+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
324
+
325
+ ##########################################################################
326
+ # kernel == gaussian mapping tests
327
+ ##########################################################################
328
+
329
+ # check computation and dimensions if bias == False
330
+ clf = ot .da .MappingTransport (kernel = "gaussian" , bias = False )
331
+ clf .fit (Xs = Xs , Xt = Xt )
332
+
333
+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
334
+ assert_equal (clf .mapping_ .shape , ((Xs .shape [0 ], Xt .shape [1 ])))
335
+
336
+ # test margin constraints
337
+ mu_s = unif (ns )
338
+ mu_t = unif (nt )
339
+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
340
+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
341
+
342
+ # test transform
343
+ transp_Xs = clf .transform (Xs = Xs )
344
+ assert_equal (transp_Xs .shape , Xs .shape )
345
+
346
+ transp_Xs_new = clf .transform (Xs_new )
347
+
348
+ # check that the oos method is working
349
+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
350
+
351
+ # check computation and dimensions if bias == True
352
+ clf = ot .da .MappingTransport (kernel = "gaussian" , bias = True )
353
+ clf .fit (Xs = Xs , Xt = Xt )
354
+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
355
+ assert_equal (clf .mapping_ .shape , ((Xs .shape [0 ] + 1 , Xt .shape [1 ])))
356
+
357
+ # test margin constraints
358
+ mu_s = unif (ns )
359
+ mu_t = unif (nt )
360
+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
361
+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
362
+
363
+ # test transform
364
+ transp_Xs = clf .transform (Xs = Xs )
365
+ assert_equal (transp_Xs .shape , Xs .shape )
366
+
367
+ transp_Xs_new = clf .transform (Xs_new )
368
+
369
+ # check that the oos method is working
370
+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
371
+
372
+
267
373
def test_otda ():
268
374
269
375
n_samples = 150 # nb samples
@@ -326,9 +432,10 @@ def test_otda():
326
432
da_emd .predict (xs ) # interpolation of source samples
327
433
328
434
329
- # if __name__ == "__main__":
435
+ if __name__ == "__main__" :
330
436
331
- # test_sinkhorn_transport_class()
332
- # test_emd_transport_class()
333
- # test_sinkhorn_l1l2_transport_class()
334
- # test_sinkhorn_lpl1_transport_class()
437
+ # test_sinkhorn_transport_class()
438
+ # test_emd_transport_class()
439
+ # test_sinkhorn_l1l2_transport_class()
440
+ # test_sinkhorn_lpl1_transport_class()
441
+ test_mapping_transport_class ()
0 commit comments