Skip to content

Commit 13444ca

Browse files
committed
partial with tests
1 parent 8c724ad commit 13444ca

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed

test/test_partial.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Tests for module partial """
2+
3+
# Author:
4+
# Laetitia Chapel <[email protected]>
5+
#
6+
# License: MIT License
7+
8+
import numpy as np
9+
import scipy as sp
10+
import ot
11+
12+
13+
def test_partial_wasserstein():
14+
15+
n_samples = 20 # nb samples (gaussian)
16+
n_noise = 20 # nb of samples (noise)
17+
18+
mu = np.array([0, 0])
19+
cov = np.array([[1, 0], [0, 2]])
20+
21+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
22+
xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
23+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
24+
xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
25+
26+
M = ot.dist(xs, xt)
27+
28+
p = ot.unif(n_samples + n_noise)
29+
q = ot.unif(n_samples + n_noise)
30+
31+
m = 0.5
32+
33+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
34+
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
35+
log=True)
36+
37+
# check constratints
38+
np.testing.assert_equal(
39+
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
40+
np.testing.assert_equal(
41+
w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
42+
np.testing.assert_equal(
43+
w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
44+
np.testing.assert_equal(
45+
w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
46+
47+
# check transported mass
48+
np.testing.assert_allclose(
49+
np.sum(w0), m, atol=1e-04)
50+
np.testing.assert_allclose(
51+
np.sum(w), m, atol=1e-04)
52+
53+
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
54+
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
55+
56+
G = log0['T']
57+
58+
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
59+
60+
# check constratints
61+
np.testing.assert_equal(
62+
G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
63+
np.testing.assert_equal(
64+
G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
65+
np.testing.assert_allclose(
66+
np.sum(G), m, atol=1e-04)
67+
68+
69+
def test_partial_gromov_wasserstein():
70+
n_samples = 20 # nb samples
71+
n_noise = 10 # nb of samples (noise)
72+
73+
p = ot.unif(n_samples + n_noise)
74+
q = ot.unif(n_samples + n_noise)
75+
76+
mu_s = np.array([0, 0])
77+
cov_s = np.array([[1, 0], [0, 1]])
78+
79+
mu_t = np.array([0, 0, 0])
80+
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
81+
82+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
83+
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
84+
P = sp.linalg.sqrtm(cov_t)
85+
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
86+
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
87+
xt2 = xs[::-1].copy()
88+
89+
C1 = ot.dist(xs, xs)
90+
C2 = ot.dist(xt, xt)
91+
C3 = ot.dist(xt2, xt2)
92+
93+
m = 2 / 3
94+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C3, p, q, m=m,
95+
log=True)
96+
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C3, p, q, 10,
97+
m=m, log=True)
98+
np.testing.assert_allclose(res0, 0, atol=1e-1, rtol=1e-1)
99+
np.testing.assert_allclose(res, 0, atol=1e-1, rtol=1e-1)
100+
101+
C1 = sp.spatial.distance.cdist(xs, xs)
102+
C2 = sp.spatial.distance.cdist(xt, xt)
103+
104+
m = 1
105+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
106+
log=True)
107+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
108+
np.testing.assert_allclose(G, res0, atol=1e-04)
109+
110+
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
111+
m=m, log=True)
112+
G = ot.gromov.entropic_gromov_wasserstein(
113+
C1, C2, p, q, 'square_loss', epsilon=10)
114+
np.testing.assert_allclose(G, res, atol=1e-02)
115+
116+
w0, log0 = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
117+
log=True)
118+
w0_val = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
119+
log=False)
120+
G = log0['T']
121+
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
122+
123+
m = 2 / 3
124+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
125+
log=True)
126+
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
127+
m=m, log=True)
128+
# check constratints
129+
np.testing.assert_equal(
130+
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
131+
np.testing.assert_equal(
132+
res0.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
133+
np.testing.assert_allclose(
134+
np.sum(res0), m, atol=1e-04)
135+
136+
np.testing.assert_equal(
137+
res.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
138+
np.testing.assert_equal(
139+
res.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
140+
np.testing.assert_allclose(
141+
np.sum(res), m, atol=1e-04)

0 commit comments

Comments
 (0)