Skip to content

Commit 46fc12a

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
solving conflicts :/
1 parent 64a5d3c commit 46fc12a

File tree

4 files changed

+5
-100
lines changed

4 files changed

+5
-100
lines changed

examples/plot_gromov.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@
2626
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
2727
"""
2828

29-
<<<<<<< HEAD
3029
n_samples = 30 # nb samples
31-
=======
32-
n = 30 # nb samples
33-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
3430

3531
mu_s = np.array([0, 0])
3632
cov_s = np.array([[1, 0], [0, 1]])
@@ -39,15 +35,9 @@
3935
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
4036

4137

42-
<<<<<<< HEAD
4338
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
4439
P = sp.linalg.sqrtm(cov_t)
4540
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
46-
=======
47-
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
48-
P = sp.linalg.sqrtm(cov_t)
49-
xt = np.random.randn(n, 3).dot(P) + mu_t
50-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
5141

5242

5343
"""
@@ -85,13 +75,8 @@
8575
=============================================
8676
"""
8777

88-
<<<<<<< HEAD
8978
p = ot.unif(n_samples)
9079
q = ot.unif(n_samples)
91-
=======
92-
p = ot.unif(n)
93-
q = ot.unif(n)
94-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
9580

9681
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
9782
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

examples/plot_gromov_barycenter.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,12 @@ def im2mat(I):
9191
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
9292

9393

94-
<<<<<<< HEAD
9594
square = spi.imread('../data/carre.png').astype(np.float64) / 256
9695
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
9796
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
9897
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
9998

10099
shapes = [square, circle, triangle, arrow]
101-
=======
102-
carre = spi.imread('../data/carre.png').astype(np.float64) / 256
103-
rond = spi.imread('../data/rond.png').astype(np.float64) / 256
104-
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
105-
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
106-
107-
shapes = [carre, rond, triangle, fleche]
108-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
109100

110101
S = 4
111102
xs = [[] for i in range(S)]
@@ -127,60 +118,36 @@ def im2mat(I):
127118
The four distributions are constructed from 4 simple images
128119
"""
129120
ns = [len(xs[s]) for s in range(S)]
130-
<<<<<<< HEAD
131121
n_samples = 30
132-
=======
133-
N = 30
134-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
135122

136123
"""Compute all distances matrices for the four shapes"""
137124
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
138125
Cs = [cs / cs.max() for cs in Cs]
139126

140127
ps = [ot.unif(ns[s]) for s in range(S)]
141-
<<<<<<< HEAD
142128
p = ot.unif(n_samples)
143-
=======
144-
p = ot.unif(N)
145-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
146129

147130

148131
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
149132

150133
Ct01 = [0 for i in range(2)]
151134
for i in range(2):
152-
<<<<<<< HEAD
153135
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
154-
=======
155-
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
156-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
157136
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
158137

159138
Ct02 = [0 for i in range(2)]
160139
for i in range(2):
161-
<<<<<<< HEAD
162140
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
163-
=======
164-
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
165-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
166141
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
167142

168143
Ct13 = [0 for i in range(2)]
169144
for i in range(2):
170-
<<<<<<< HEAD
171145
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
172-
=======
173-
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
174-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
175146
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
176147

177148
Ct23 = [0 for i in range(2)]
178149
for i in range(2):
179-
<<<<<<< HEAD
180150
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
181-
=======
182-
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
183-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
184151
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
185152

186153
"""

ot/gromov.py

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs):
208208
return(np.exp(np.divide(tmpsum, ppt)))
209209

210210

211-
<<<<<<< HEAD
212211
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
213-
=======
214-
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
215-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
216212
"""
217213
Returns the gromov-wasserstein coupling between the two measured similarity matrices
218214
@@ -252,11 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
252248
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
253249
epsilon : float
254250
Regularization term >0
255-
<<<<<<< HEAD
251+
<<<<<<< HEAD
256252
max_iter : int, optional
257-
=======
253+
=======
258254
numItermax : int, optional
259-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
255+
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
260256
Max number of iterations
261257
stopThr : float, optional
262258
Stop threshold on error (>0)
@@ -282,11 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
282278
cpt = 0
283279
err = 1
284280

285-
<<<<<<< HEAD
286281
while (err > stopThr and cpt < max_iter):
287-
=======
288-
while (err > stopThr and cpt < numItermax):
289-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
290282

291283
Tprev = T
292284

@@ -319,11 +311,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
319311
return T
320312

321313

322-
<<<<<<< HEAD
323314
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
324-
=======
325-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
326-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
327315
"""
328316
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
329317
@@ -358,7 +346,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
358346
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
359347
epsilon : float
360348
Regularization term >0
361-
numItermax : int, optional
349+
max_iter : int, optional
362350
Max number of iterations
363351
stopThr : float, optional
364352
Stop threshold on error (>0)
@@ -378,17 +366,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
378366

379367
if log:
380368
gw, logv = gromov_wasserstein(
381-
<<<<<<< HEAD
382369
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
383370
else:
384371
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
385372
epsilon, max_iter, stopThr, verbose, log)
386-
=======
387-
C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
388-
else:
389-
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
390-
epsilon, numItermax, stopThr, verbose, log)
391-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
392373

393374
if loss_fun == 'square_loss':
394375
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -402,11 +383,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
402383
return gw_dist
403384

404385

405-
<<<<<<< HEAD
406386
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
407-
=======
408-
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
409-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
410387
"""
411388
Returns the gromov-wasserstein barycenters of S measured similarity matrices
412389
@@ -439,7 +416,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
439416
with the S Ts couplings calculated at each iteration
440417
epsilon : float
441418
Regularization term >0
442-
numItermax : int, optional
419+
max_iter : int, optional
443420
Max number of iterations
444421
stopThr : float, optional
445422
Stop threshol on error (>0)
@@ -469,21 +446,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
469446

470447
error = []
471448

472-
<<<<<<< HEAD
473449
while(err > stopThr and cpt < max_iter):
474-
=======
475-
while(err > stopThr and cpt < numItermax):
476-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
477-
478450
Cprev = C
479451

480452
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
481-
<<<<<<< HEAD
482453
max_iter, 1e-5, verbose, log) for s in range(S)]
483-
=======
484-
numItermax, 1e-5, verbose, log) for s in range(S)]
485-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
486-
487454
if loss_fun == 'square_loss':
488455
C = update_square_loss(p, lambdas, T, Cs)
489456

test/test_gromov.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,18 @@
1010

1111

1212
def test_gromov():
13-
<<<<<<< HEAD
1413
n_samples = 50 # nb samples
15-
=======
16-
n = 50 # nb samples
17-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
1814

1915
mu_s = np.array([0, 0])
2016
cov_s = np.array([[1, 0], [0, 1]])
2117

22-
<<<<<<< HEAD
2318
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
2419

2520
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
2621
xt = np.array(xt)
2722

2823
p = ot.unif(n_samples)
2924
q = ot.unif(n_samples)
30-
=======
31-
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
32-
33-
xt = [xs[n - (i + 1)] for i in range(n)]
34-
xt = np.array(xt)
35-
36-
p = ot.unif(n)
37-
q = ot.unif(n)
38-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
3925

4026
C1 = ot.dist(xs, xs)
4127
C2 = ot.dist(xt, xt)

0 commit comments

Comments
 (0)