Skip to content

Commit 326d163

Browse files
committed
test functions for MappingTransport Class
1 parent 791a4a6 commit 326d163

File tree

2 files changed

+125
-10
lines changed

2 files changed

+125
-10
lines changed

ot/da.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,14 @@ class MappingTransport(BaseEstimator):
16651665
16661666
Attributes
16671667
----------
1668-
coupling_ : the optimal coupling
1669-
mapping_ : the mapping associated
1668+
coupling_ : array-like, shape (n_source_samples, n_features)
1669+
The optimal coupling
1670+
mapping_ : array-like, shape (n_features (+ 1), n_features)
1671+
(if bias) for kernel == linear
1672+
The associated mapping
1673+
1674+
array-like, shape (n_source_samples (+ 1), n_features)
1675+
(if bias) for kernel == gaussian
16701676
16711677
References
16721678
----------
@@ -1679,20 +1685,22 @@ class MappingTransport(BaseEstimator):
16791685

16801686
def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
16811687
kernel="linear", sigma=1, max_iter=100, tol=1e-5,
1682-
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False):
1688+
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
1689+
verbose2=False):
16831690

16841691
self.metric = metric
16851692
self.mu = mu
16861693
self.eta = eta
16871694
self.bias = bias
16881695
self.kernel = kernel
1689-
self.sigma
1696+
self.sigma = sigma
16901697
self.max_iter = max_iter
16911698
self.tol = tol
16921699
self.max_inner_iter = max_inner_iter
16931700
self.inner_tol = inner_tol
16941701
self.log = log
16951702
self.verbose = verbose
1703+
self.verbose2 = verbose2
16961704

16971705
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
16981706
"""Builds an optimal coupling and estimates the associated mapping
@@ -1712,7 +1720,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
17121720
Returns
17131721
-------
17141722
self : object
1715-
Returns self.
1723+
Returns self
17161724
"""
17171725

17181726
self.Xs = Xs

test/test_da.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,112 @@ def test_emd_transport_class():
264264
assert n_unsup != n_semisup, "semisupervised mode not working"
265265

266266

267+
def test_mapping_transport_class():
268+
"""test_mapping_transport
269+
"""
270+
271+
ns = 150
272+
nt = 200
273+
274+
Xs, ys = get_data_classif('3gauss', ns)
275+
Xt, yt = get_data_classif('3gauss2', nt)
276+
Xs_new, _ = get_data_classif('3gauss', ns + 1)
277+
278+
##########################################################################
279+
# kernel == linear mapping tests
280+
##########################################################################
281+
282+
# check computation and dimensions if bias == False
283+
clf = ot.da.MappingTransport(kernel="linear", bias=False)
284+
clf.fit(Xs=Xs, Xt=Xt)
285+
286+
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
287+
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
288+
289+
# test margin constraints
290+
mu_s = unif(ns)
291+
mu_t = unif(nt)
292+
assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
293+
assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
294+
295+
# test transform
296+
transp_Xs = clf.transform(Xs=Xs)
297+
assert_equal(transp_Xs.shape, Xs.shape)
298+
299+
transp_Xs_new = clf.transform(Xs_new)
300+
301+
# check that the oos method is working
302+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
303+
304+
# check computation and dimensions if bias == True
305+
clf = ot.da.MappingTransport(kernel="linear", bias=True)
306+
clf.fit(Xs=Xs, Xt=Xt)
307+
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
308+
assert_equal(clf.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
309+
310+
# test margin constraints
311+
mu_s = unif(ns)
312+
mu_t = unif(nt)
313+
assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
314+
assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
315+
316+
# test transform
317+
transp_Xs = clf.transform(Xs=Xs)
318+
assert_equal(transp_Xs.shape, Xs.shape)
319+
320+
transp_Xs_new = clf.transform(Xs_new)
321+
322+
# check that the oos method is working
323+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
324+
325+
##########################################################################
326+
# kernel == gaussian mapping tests
327+
##########################################################################
328+
329+
# check computation and dimensions if bias == False
330+
clf = ot.da.MappingTransport(kernel="gaussian", bias=False)
331+
clf.fit(Xs=Xs, Xt=Xt)
332+
333+
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
334+
assert_equal(clf.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
335+
336+
# test margin constraints
337+
mu_s = unif(ns)
338+
mu_t = unif(nt)
339+
assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
340+
assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
341+
342+
# test transform
343+
transp_Xs = clf.transform(Xs=Xs)
344+
assert_equal(transp_Xs.shape, Xs.shape)
345+
346+
transp_Xs_new = clf.transform(Xs_new)
347+
348+
# check that the oos method is working
349+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
350+
351+
# check computation and dimensions if bias == True
352+
clf = ot.da.MappingTransport(kernel="gaussian", bias=True)
353+
clf.fit(Xs=Xs, Xt=Xt)
354+
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
355+
assert_equal(clf.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
356+
357+
# test margin constraints
358+
mu_s = unif(ns)
359+
mu_t = unif(nt)
360+
assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
361+
assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
362+
363+
# test transform
364+
transp_Xs = clf.transform(Xs=Xs)
365+
assert_equal(transp_Xs.shape, Xs.shape)
366+
367+
transp_Xs_new = clf.transform(Xs_new)
368+
369+
# check that the oos method is working
370+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
371+
372+
267373
def test_otda():
268374

269375
n_samples = 150 # nb samples
@@ -326,9 +432,10 @@ def test_otda():
326432
da_emd.predict(xs) # interpolation of source samples
327433

328434

329-
# if __name__ == "__main__":
435+
if __name__ == "__main__":
330436

331-
# test_sinkhorn_transport_class()
332-
# test_emd_transport_class()
333-
# test_sinkhorn_l1l2_transport_class()
334-
# test_sinkhorn_lpl1_transport_class()
437+
# test_sinkhorn_transport_class()
438+
# test_emd_transport_class()
439+
# test_sinkhorn_l1l2_transport_class()
440+
# test_sinkhorn_lpl1_transport_class()
441+
test_mapping_transport_class()

0 commit comments

Comments
 (0)