16
16
from ..utils import parmap
17
17
18
18
19
- def emd (a , b , M , num_iter_max = 100000 , log = False ):
19
+ def emd (a , b , M , numItermax = 100000 , log = False ):
20
20
"""Solves the Earth Movers distance problem and returns the OT matrix
21
21
22
22
@@ -41,7 +41,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
41
41
Target histogram (uniform weigth if empty list)
42
42
M : (ns,nt) ndarray, float64
43
43
loss matrix
44
- num_iter_max : int, optional (default=100000)
44
+ numItermax : int, optional (default=100000)
45
45
The maximum number of iterations before stopping the optimization
46
46
algorithm if it has not converged.
47
47
log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
94
94
if len (b ) == 0 :
95
95
b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
96
96
97
- G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
97
+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
98
98
result_code_string = check_result (result_code )
99
99
if log :
100
100
log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107
107
return G
108
108
109
109
110
- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False , return_matrix = False ):
110
+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 , log = False , return_matrix = False ):
111
111
"""Solves the Earth Movers distance problem and returns the loss
112
112
113
113
.. math::
@@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
131
131
Target histogram (uniform weigth if empty list)
132
132
M : (ns,nt) ndarray, float64
133
133
loss matrix
134
- num_iter_max : int, optional (default=100000)
134
+ numItermax : int, optional (default=100000)
135
135
The maximum number of iterations before stopping the optimization
136
136
algorithm if it has not converged.
137
137
log: boolean, optional (default=False)
@@ -188,7 +188,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
188
188
189
189
if log or return_matrix :
190
190
def f (b ):
191
- G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
191
+ G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
192
192
result_code_string = check_result (resultCode )
193
193
log = {}
194
194
if return_matrix :
@@ -200,7 +200,7 @@ def f(b):
200
200
return [cost , log ]
201
201
else :
202
202
def f (b ):
203
- G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
203
+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
204
204
check_result (result_code )
205
205
return cost
206
206
0 commit comments