Skip to content

Commit 64a5d3c

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
docstrings and naming
2 parents ab6ed1d + 986f46d commit 64a5d3c

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

examples/plot_gromov.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
2727
"""
2828

29+
<<<<<<< HEAD
2930
n_samples = 30 # nb samples
31+
=======
32+
n = 30 # nb samples
33+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
3034

3135
mu_s = np.array([0, 0])
3236
cov_s = np.array([[1, 0], [0, 1]])
@@ -35,9 +39,15 @@
3539
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
3640

3741

42+
<<<<<<< HEAD
3843
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
3944
P = sp.linalg.sqrtm(cov_t)
4045
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
46+
=======
47+
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
48+
P = sp.linalg.sqrtm(cov_t)
49+
xt = np.random.randn(n, 3).dot(P) + mu_t
50+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
4151

4252

4353
"""
@@ -75,8 +85,13 @@
7585
=============================================
7686
"""
7787

88+
<<<<<<< HEAD
7889
p = ot.unif(n_samples)
7990
q = ot.unif(n_samples)
91+
=======
92+
p = ot.unif(n)
93+
q = ot.unif(n)
94+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
8095

8196
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
8297
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

examples/plot_gromov_barycenter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,21 @@ def im2mat(I):
9191
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
9292

9393

94+
<<<<<<< HEAD
9495
square = spi.imread('../data/carre.png').astype(np.float64) / 256
9596
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
9697
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
9798
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
9899

99100
shapes = [square, circle, triangle, arrow]
101+
=======
102+
carre = spi.imread('../data/carre.png').astype(np.float64) / 256
103+
rond = spi.imread('../data/rond.png').astype(np.float64) / 256
104+
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
105+
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
106+
107+
shapes = [carre, rond, triangle, fleche]
108+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
100109

101110
S = 4
102111
xs = [[] for i in range(S)]
@@ -118,36 +127,60 @@ def im2mat(I):
118127
The four distributions are constructed from 4 simple images
119128
"""
120129
ns = [len(xs[s]) for s in range(S)]
130+
<<<<<<< HEAD
121131
n_samples = 30
132+
=======
133+
N = 30
134+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
122135

123136
"""Compute all distances matrices for the four shapes"""
124137
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
125138
Cs = [cs / cs.max() for cs in Cs]
126139

127140
ps = [ot.unif(ns[s]) for s in range(S)]
141+
<<<<<<< HEAD
128142
p = ot.unif(n_samples)
143+
=======
144+
p = ot.unif(N)
145+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
129146

130147

131148
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
132149

133150
Ct01 = [0 for i in range(2)]
134151
for i in range(2):
152+
<<<<<<< HEAD
135153
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
154+
=======
155+
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
156+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
136157
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
137158

138159
Ct02 = [0 for i in range(2)]
139160
for i in range(2):
161+
<<<<<<< HEAD
140162
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
163+
=======
164+
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
165+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
141166
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
142167

143168
Ct13 = [0 for i in range(2)]
144169
for i in range(2):
170+
<<<<<<< HEAD
145171
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
172+
=======
173+
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
174+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
146175
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
147176

148177
Ct23 = [0 for i in range(2)]
149178
for i in range(2):
179+
<<<<<<< HEAD
150180
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
181+
=======
182+
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
183+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
151184
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
152185

153186
"""

ot/gromov.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,11 @@ def update_kl_loss(p, lambdas, T, Cs):
208208
return(np.exp(np.divide(tmpsum, ppt)))
209209

210210

211+
<<<<<<< HEAD
211212
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
213+
=======
214+
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
215+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
212216
"""
213217
Returns the gromov-wasserstein coupling between the two measured similarity matrices
214218
@@ -248,7 +252,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
248252
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
249253
epsilon : float
250254
Regularization term >0
255+
<<<<<<< HEAD
251256
max_iter : int, optional
257+
=======
258+
numItermax : int, optional
259+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
252260
Max number of iterations
253261
stopThr : float, optional
254262
Stop threshold on error (>0)
@@ -274,7 +282,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
274282
cpt = 0
275283
err = 1
276284

285+
<<<<<<< HEAD
277286
while (err > stopThr and cpt < max_iter):
287+
=======
288+
while (err > stopThr and cpt < numItermax):
289+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
278290

279291
Tprev = T
280292

@@ -307,7 +319,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
307319
return T
308320

309321

322+
<<<<<<< HEAD
310323
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
324+
=======
325+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
326+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
311327
"""
312328
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
313329
@@ -362,10 +378,17 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
362378

363379
if log:
364380
gw, logv = gromov_wasserstein(
381+
<<<<<<< HEAD
365382
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
366383
else:
367384
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
368385
epsilon, max_iter, stopThr, verbose, log)
386+
=======
387+
C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
388+
else:
389+
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
390+
epsilon, numItermax, stopThr, verbose, log)
391+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
369392

370393
if loss_fun == 'square_loss':
371394
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -379,7 +402,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
379402
return gw_dist
380403

381404

405+
<<<<<<< HEAD
382406
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
407+
=======
408+
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
409+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
383410
"""
384411
Returns the gromov-wasserstein barycenters of S measured similarity matrices
385412
@@ -442,12 +469,20 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
442469

443470
error = []
444471

472+
<<<<<<< HEAD
445473
while(err > stopThr and cpt < max_iter):
474+
=======
475+
while(err > stopThr and cpt < numItermax):
476+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
446477

447478
Cprev = C
448479

449480
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
481+
<<<<<<< HEAD
450482
max_iter, 1e-5, verbose, log) for s in range(S)]
483+
=======
484+
numItermax, 1e-5, verbose, log) for s in range(S)]
485+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
451486

452487
if loss_fun == 'square_loss':
453488
C = update_square_loss(p, lambdas, T, Cs)

test/test_gromov.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,32 @@
1010

1111

1212
def test_gromov():
13+
<<<<<<< HEAD
1314
n_samples = 50 # nb samples
15+
=======
16+
n = 50 # nb samples
17+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
1418

1519
mu_s = np.array([0, 0])
1620
cov_s = np.array([[1, 0], [0, 1]])
1721

22+
<<<<<<< HEAD
1823
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
1924

2025
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
2126
xt = np.array(xt)
2227

2328
p = ot.unif(n_samples)
2429
q = ot.unif(n_samples)
30+
=======
31+
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
32+
33+
xt = [xs[n - (i + 1)] for i in range(n)]
34+
xt = np.array(xt)
35+
36+
p = ot.unif(n)
37+
q = ot.unif(n)
38+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
2539

2640
C1 = ot.dist(xs, xs)
2741
C2 = ot.dist(xt, xt)

0 commit comments

Comments
 (0)