Skip to content

Commit a4426fd

Browse files
authored
Merge branch 'master' into laplace_da
2 parents 470fce2 + f10f323 commit a4426fd

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

examples/plot_partial_wass_and_gromov.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
==========================
3+
==================================================
44
Partial Wasserstein and Gromov-Wasserstein example
5-
==========================
5+
==================================================
66
77
This example is designed to show how to use the Partial (Gromov-)Wassertsein
88
distance computation in POT.
@@ -50,8 +50,7 @@
5050

5151
#############################################################################
5252
#
53-
# Compute partial Wasserstein plans and distance,
54-
# by transporting 50% of the mass
53+
# Compute partial Wasserstein plans and distance
5554
# ----------------------------------------------
5655

5756
p = ot.unif(n_samples + n_noise)
@@ -113,34 +112,33 @@
113112

114113
#############################################################################
115114
#
116-
# Compute partial Gromov-Wasserstein plans and distance,
117-
# by transporting 100% and 2/3 of the mass
115+
# Compute partial Gromov-Wasserstein plans and distance
118116
# -----------------------------------------------------
119117

120118
C1 = sp.spatial.distance.cdist(xs, xs)
121119
C2 = sp.spatial.distance.cdist(xt, xt)
122120

121+
# transport 100% of the mass
123122
print('-----m = 1')
124123
m = 1
125-
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
126-
log=True)
124+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
127125
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
128126
m=m, log=True)
129127

130-
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
131-
print('Entropic partial Wasserstein distance (m = 1): ' +
132-
str(log['partial_gw_dist']))
128+
print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
129+
print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
133130

134131
pl.figure(1, (10, 5))
135132
pl.title("mass to be transported m = 1")
136133
pl.subplot(1, 2, 1)
137134
pl.imshow(res0, cmap='jet')
138-
pl.title('Partial Wasserstein')
135+
pl.title('Wasserstein')
139136
pl.subplot(1, 2, 2)
140137
pl.imshow(res, cmap='jet')
141-
pl.title('Entropic partial Wasserstein')
138+
pl.title('Entropic Wasserstein')
142139
pl.show()
143140

141+
# transport 2/3 of the mass
144142
print('-----m = 2/3')
145143
m = 2 / 3
146144
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)

ot/partial.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
6666
instabilities, increase its value if an error is raised)
6767
log : bool, optional
6868
record log if True
69+
**kwargs : dict
70+
parameters can be directly passed to the emd solver
6971
7072
.. warning::
7173
When dealing with a large number of points, the EMD solver may face
@@ -190,6 +192,8 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
190192
instabilities, increase its value if an error is raised)
191193
log : bool, optional
192194
record log if True
195+
**kwargs : dict
196+
parameters can be directly passed to the emd solver
193197
194198
195199
.. warning::
@@ -304,6 +308,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
304308
instabilities, increase its value if an error is raised)
305309
log : bool, optional
306310
record log if True
311+
**kwargs : dict
312+
parameters can be directly passed to the emd solver
307313
308314
309315
.. warning::

0 commit comments

Comments
 (0)