14
14
import multiprocessing
15
15
16
16
17
- def emd (a , b , M , max_iter = 100000 ):
17
+ def emd (a , b , M , numItermax = 100000 ):
18
18
"""Solves the Earth Movers distance problem and returns the OT matrix
19
19
20
20
@@ -39,7 +39,7 @@ def emd(a, b, M, max_iter=100000):
39
39
Target histogram (uniform weigth if empty list)
40
40
M : (ns,nt) ndarray, float64
41
41
loss matrix
42
- max_iter : int, optional (default=100000)
42
+ numItermax : int, optional (default=100000)
43
43
The maximum number of iterations before stopping the optimization
44
44
algorithm if it has not converged.
45
45
@@ -86,10 +86,10 @@ def emd(a, b, M, max_iter=100000):
86
86
if len (b ) == 0 :
87
87
b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
88
88
89
- return emd_c (a , b , M , max_iter )
89
+ return emd_c (a , b , M , numItermax )
90
90
91
91
92
- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), max_iter = 100000 ):
92
+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 ):
93
93
"""Solves the Earth Movers distance problem and returns the loss
94
94
95
95
.. math::
@@ -113,7 +113,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
113
113
Target histogram (uniform weigth if empty list)
114
114
M : (ns,nt) ndarray, float64
115
115
loss matrix
116
- max_iter : int, optional (default=100000)
116
+ numItermax : int, optional (default=100000)
117
117
The maximum number of iterations before stopping the optimization
118
118
algorithm if it has not converged.
119
119
@@ -161,12 +161,12 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
161
161
b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
162
162
163
163
if len (b .shape ) == 1 :
164
- return emd2_c (a , b , M , max_iter )
164
+ return emd2_c (a , b , M , numItermax )
165
165
else :
166
166
nb = b .shape [1 ]
167
- # res = [emd2_c(a, b[:, i].copy(), M, max_iter ) for i in range(nb)]
167
+ # res = [emd2_c(a, b[:, i].copy(), M, numItermax ) for i in range(nb)]
168
168
169
169
def f (b ):
170
- return emd2_c (a , b , M , max_iter )
170
+ return emd2_c (a , b , M , numItermax )
171
171
res = parmap (f , [b [:, i ] for i in range (nb )], processes )
172
172
return np .array (res )
0 commit comments