Skip to content

Commit 42772c2

Browse files
authored
Merge branch 'master' into doc_travis
2 parents 8933a84 + f10f323 commit 42772c2

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

examples/plot_partial_wass_and_gromov.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
==========================
3+
==================================================
44
Partial Wasserstein and Gromov-Wasserstein example
5-
==========================
5+
==================================================
66
77
This example is designed to show how to use the Partial (Gromov-)Wassertsein
88
distance computation in POT.
@@ -52,8 +52,7 @@
5252

5353
#############################################################################
5454
#
55-
# Compute partial Wasserstein plans and distance,
56-
# by transporting 50% of the mass
55+
# Compute partial Wasserstein plans and distance
5756
# ----------------------------------------------
5857

5958
p = ot.unif(n_samples + n_noise)
@@ -115,34 +114,33 @@
115114

116115
#############################################################################
117116
#
118-
# Compute partial Gromov-Wasserstein plans and distance,
119-
# by transporting 100% and 2/3 of the mass
117+
# Compute partial Gromov-Wasserstein plans and distance
120118
# -----------------------------------------------------
121119

122120
C1 = sp.spatial.distance.cdist(xs, xs)
123121
C2 = sp.spatial.distance.cdist(xt, xt)
124122

123+
# transport 100% of the mass
125124
print('-----m = 1')
126125
m = 1
127-
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
128-
log=True)
126+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
129127
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
130128
m=m, log=True)
131129

132-
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
133-
print('Entropic partial Wasserstein distance (m = 1): ' +
134-
str(log['partial_gw_dist']))
130+
print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
131+
print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
135132

136133
pl.figure(1, (10, 5))
137134
pl.title("mass to be transported m = 1")
138135
pl.subplot(1, 2, 1)
139136
pl.imshow(res0, cmap='jet')
140-
pl.title('Partial Wasserstein')
137+
pl.title('Wasserstein')
141138
pl.subplot(1, 2, 2)
142139
pl.imshow(res, cmap='jet')
143-
pl.title('Entropic partial Wasserstein')
140+
pl.title('Entropic Wasserstein')
144141
pl.show()
145142

143+
# transport 2/3 of the mass
146144
print('-----m = 2/3')
147145
m = 2 / 3
148146
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)

ot/partial.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
6666
instabilities, increase its value if an error is raised)
6767
log : bool, optional
6868
record log if True
69+
**kwargs : dict
70+
parameters can be directly passed to the emd solver
6971
7072
.. warning::
7173
When dealing with a large number of points, the EMD solver may face
@@ -190,6 +192,8 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
190192
instabilities, increase its value if an error is raised)
191193
log : bool, optional
192194
record log if True
195+
**kwargs : dict
196+
parameters can be directly passed to the emd solver
193197
194198
195199
.. warning::
@@ -304,6 +308,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
304308
instabilities, increase its value if an error is raised)
305309
log : bool, optional
306310
record log if True
311+
**kwargs : dict
312+
parameters can be directly passed to the emd solver
307313
308314
309315
.. warning::

0 commit comments

Comments
 (0)