Skip to content

Commit 36bf599

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
Corrections on Gromov
1 parent 8ea74ad commit 36bf599

File tree

6 files changed

+9
-12
lines changed

6 files changed

+9
-12
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.

examples/plot_gromov.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
"""
2323
Sample two Gaussian distributions (2D and 3D)
2424
=============================================
25-
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
26-
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
25+
The Gromov-Wasserstein distance allows to compute distances with samples that
26+
do not belong to the same metric space. For demonstration purpose, we sample
27+
two Gaussian distributions in 2- and 3-dimensional spaces.
2728
"""
2829

2930
n_samples = 30 # nb samples

examples/plot_gromov_barycenter.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,10 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
4848
eps : float
4949
relative tolerance w.r.t stress to declare converge
5050
51-
5251
Returns
5352
-------
5453
npos : ndarray, shape (R, dim)
5554
Embedded coordinates of the interpolated point cloud (defined with one isometry)
56-
57-
5855
"""
5956

6057
rng = np.random.RandomState(seed=3)
@@ -91,12 +88,12 @@ def im2mat(I):
9188
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
9289

9390

94-
square = spi.imread('../data/carre.png').astype(np.float64) / 256
95-
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
96-
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
97-
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
91+
square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
92+
cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
93+
triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
94+
star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
9895

99-
shapes = [square, circle, triangle, arrow]
96+
shapes = [square, cross, triangle, star]
10097

10198
S = 4
10299
xs = [[] for i in range(S)]

test/test_gromov.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ def test_gromov():
1717

1818
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
1919

20-
xt = xs[::-1]
21-
xt = np.array(xt)
20+
xt = xs[::-1].copy()
2221

2322
p = ot.unif(n_samples)
2423
q = ot.unif(n_samples)

0 commit comments

Comments
 (0)