@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107
107
return G
108
108
109
109
110
- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
110
+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False , return_matrix = False ):
111
111
"""Solves the Earth Movers distance problem and returns the loss
112
112
113
113
.. math::
@@ -134,6 +134,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
134
134
num_iter_max : int, optional (default=100000)
135
135
The maximum number of iterations before stopping the optimization
136
136
algorithm if it has not converged.
137
+ log: boolean, optional (default=False)
138
+ If True, returns a dictionary containing the cost and dual
139
+ variables. Otherwise returns only the optimal transportation cost.
140
+ return_matrix: boolean, optional (default=False)
141
+ If True, returns the optimal transportation matrix in the log.
137
142
138
143
Returns
139
144
-------
@@ -181,12 +186,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
181
186
if len (b ) == 0 :
182
187
b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
183
188
184
- if log :
189
+ if log or return_matrix :
185
190
def f (b ):
186
191
G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187
192
result_code_string = check_result (resultCode )
188
193
log = {}
189
- log ['G' ] = G
194
+ if return_matrix :
195
+ log ['G' ] = G
190
196
log ['u' ] = u
191
197
log ['v' ] = v
192
198
log ['warning' ] = result_code_string
0 commit comments