Skip to content

Commit ab6ed1d

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
docstrings and naming
1 parent 4ec5b33 commit ab6ed1d

File tree

4 files changed

+29
-29
lines changed

4 files changed

+29
-29
lines changed

examples/plot_gromov.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
2727
"""
2828

29-
n = 30 # nb samples
29+
n_samples = 30 # nb samples
3030

3131
mu_s = np.array([0, 0])
3232
cov_s = np.array([[1, 0], [0, 1]])
@@ -35,9 +35,9 @@
3535
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
3636

3737

38-
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
38+
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
3939
P = sp.linalg.sqrtm(cov_t)
40-
xt = np.random.randn(n, 3).dot(P) + mu_t
40+
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
4141

4242

4343
"""
@@ -75,8 +75,8 @@
7575
=============================================
7676
"""
7777

78-
p = ot.unif(n)
79-
q = ot.unif(n)
78+
p = ot.unif(n_samples)
79+
q = ot.unif(n_samples)
8080

8181
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
8282
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

examples/plot_gromov_barycenter.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ def im2mat(I):
9191
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
9292

9393

94-
carre = spi.imread('../data/carre.png').astype(np.float64) / 256
95-
rond = spi.imread('../data/rond.png').astype(np.float64) / 256
94+
square = spi.imread('../data/carre.png').astype(np.float64) / 256
95+
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
9696
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
97-
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
97+
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
9898

99-
shapes = [carre, rond, triangle, fleche]
99+
shapes = [square, circle, triangle, arrow]
100100

101101
S = 4
102102
xs = [[] for i in range(S)]
@@ -118,36 +118,36 @@ def im2mat(I):
118118
The four distributions are constructed from 4 simple images
119119
"""
120120
ns = [len(xs[s]) for s in range(S)]
121-
N = 30
121+
n_samples = 30
122122

123123
"""Compute all distances matrices for the four shapes"""
124124
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
125125
Cs = [cs / cs.max() for cs in Cs]
126126

127127
ps = [ot.unif(ns[s]) for s in range(S)]
128-
p = ot.unif(N)
128+
p = ot.unif(n_samples)
129129

130130

131131
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
132132

133133
Ct01 = [0 for i in range(2)]
134134
for i in range(2):
135-
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
135+
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
136136
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
137137

138138
Ct02 = [0 for i in range(2)]
139139
for i in range(2):
140-
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
140+
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
141141
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
142142

143143
Ct13 = [0 for i in range(2)]
144144
for i in range(2):
145-
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
145+
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
146146
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
147147

148148
Ct23 = [0 for i in range(2)]
149149
for i in range(2):
150-
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
150+
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
151151
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
152152

153153
"""

ot/gromov.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs):
208208
return(np.exp(np.divide(tmpsum, ppt)))
209209

210210

211-
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
211+
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
212212
"""
213213
Returns the gromov-wasserstein coupling between the two measured similarity matrices
214214
@@ -248,7 +248,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
248248
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
249249
epsilon : float
250250
Regularization term >0
251-
numItermax : int, optional
251+
max_iter : int, optional
252252
Max number of iterations
253253
stopThr : float, optional
254254
Stop threshold on error (>0)
@@ -274,7 +274,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
274274
cpt = 0
275275
err = 1
276276

277-
while (err > stopThr and cpt < numItermax):
277+
while (err > stopThr and cpt < max_iter):
278278

279279
Tprev = T
280280

@@ -307,7 +307,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
307307
return T
308308

309309

310-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
310+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
311311
"""
312312
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
313313
@@ -362,10 +362,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
362362

363363
if log:
364364
gw, logv = gromov_wasserstein(
365-
C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
365+
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
366366
else:
367367
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
368-
epsilon, numItermax, stopThr, verbose, log)
368+
epsilon, max_iter, stopThr, verbose, log)
369369

370370
if loss_fun == 'square_loss':
371371
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -379,7 +379,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
379379
return gw_dist
380380

381381

382-
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
382+
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
383383
"""
384384
Returns the gromov-wasserstein barycenters of S measured similarity matrices
385385
@@ -442,12 +442,12 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
442442

443443
error = []
444444

445-
while(err > stopThr and cpt < numItermax):
445+
while(err > stopThr and cpt < max_iter):
446446

447447
Cprev = C
448448

449449
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
450-
numItermax, 1e-5, verbose, log) for s in range(S)]
450+
max_iter, 1e-5, verbose, log) for s in range(S)]
451451

452452
if loss_fun == 'square_loss':
453453
C = update_square_loss(p, lambdas, T, Cs)

test/test_gromov.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@
1010

1111

1212
def test_gromov():
13-
n = 50 # nb samples
13+
n_samples = 50 # nb samples
1414

1515
mu_s = np.array([0, 0])
1616
cov_s = np.array([[1, 0], [0, 1]])
1717

18-
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
18+
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
1919

20-
xt = [xs[n - (i + 1)] for i in range(n)]
20+
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
2121
xt = np.array(xt)
2222

23-
p = ot.unif(n)
24-
q = ot.unif(n)
23+
p = ot.unif(n_samples)
24+
q = ot.unif(n_samples)
2525

2626
C1 = ot.dist(xs, xs)
2727
C2 = ot.dist(xt, xt)

0 commit comments

Comments
 (0)