Skip to content

Commit dd6f826

Browse files
committed
Made the return of the matrix optional in emd2
1 parent 7c61692 commit dd6f826

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

ot/lp/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107107
return G
108108

109109

110-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False):
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False, return_matrix=False):
111111
"""Solves the Earth Movers distance problem and returns the loss
112112
113113
.. math::
@@ -134,6 +134,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
134134
num_iter_max : int, optional (default=100000)
135135
The maximum number of iterations before stopping the optimization
136136
algorithm if it has not converged.
137+
log: boolean, optional (default=False)
138+
If True, returns a dictionary containing the cost and dual
139+
variables. Otherwise returns only the optimal transportation cost.
140+
return_matrix: boolean, optional (default=False)
141+
If True, returns the optimal transportation matrix in the log.
137142
138143
Returns
139144
-------
@@ -181,12 +186,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
181186
if len(b) == 0:
182187
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
183188

184-
if log:
189+
if log or return_matrix:
185190
def f(b):
186191
G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max)
187192
result_code_string = check_result(resultCode)
188193
log = {}
189-
log['G'] = G
194+
if return_matrix:
195+
log['G'] = G
190196
log['u'] = u
191197
log['v'] = v
192198
log['warning'] = result_code_string

test/test_ot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_emd2_multi():
103103

104104
# emd loss multipro proc with log
105105
ot.tic()
106-
emdn = ot.emd2(a, b, M, log=True)
106+
emdn = ot.emd2(a, b, M, log=True, return_matrix=True)
107107
ot.toc('multi proc : {} s')
108108

109109
for i in range(len(emdn)):
@@ -140,8 +140,8 @@ def test_warnings():
140140
with warnings.catch_warnings(record=True) as w:
141141
warnings.simplefilter("always")
142142
print('Computing {} EMD '.format(1))
143-
ot.emd(a, b, M, num_iter_max=1)
144-
assert "num_iter_max" in str(w[-1].message)
143+
ot.emd(a, b, M, numItermax=1)
144+
assert "numItermax" in str(w[-1].message)
145145
assert len(w) == 1
146146
a[0] = 100
147147
print('Computing {} EMD '.format(2))

0 commit comments

Comments
 (0)