16
16
from ..utils import parmap
17
17
18
18
19
- def emd (a , b , M , max_iter = 100000 , log = False ):
19
+ def emd (a , b , M , num_iter_max = 100000 , log = False ):
20
20
"""Solves the Earth Movers distance problem and returns the OT matrix
21
21
22
22
@@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False):
41
41
Target histogram (uniform weigth if empty list)
42
42
M : (ns,nt) ndarray, float64
43
43
loss matrix
44
- max_iter : int, optional (default=100000)
44
+ num_iter_max : int, optional (default=100000)
45
45
The maximum number of iterations before stopping the optimization
46
46
algorithm if it has not converged.
47
47
log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False):
94
94
if len (b ) == 0 :
95
95
b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
96
96
97
- G , cost , u , v , result_code = emd_c (a , b , M , max_iter )
97
+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
98
98
result_code_string = check_result (result_code )
99
99
if log :
100
100
log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False):
107
107
return G
108
108
109
109
110
- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), max_iter = 100000 , log = False ):
110
+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
111
111
"""Solves the Earth Movers distance problem and returns the loss
112
112
113
113
.. math::
@@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa
183
183
184
184
if log :
185
185
def f (b ):
186
- G , cost , u , v , resultCode = emd_c (a , b , M , max_iter )
186
+ G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187
187
result_code_string = check_result (resultCode )
188
188
log = {}
189
189
log ['G' ] = G
@@ -194,7 +194,7 @@ def f(b):
194
194
return [cost , log ]
195
195
else :
196
196
def f (b ):
197
- G , cost , u , v , result_code = emd_c (a , b , M , max_iter )
197
+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
198
198
check_result (result_code )
199
199
return cost
200
200
0 commit comments