Skip to content

Commit feef989

Browse files
committed
Rename for emd and emd2
1 parent 982f36c commit feef989

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ot/lp/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import multiprocessing
1515

1616

17-
def emd(a, b, M, max_iter=100000):
17+
def emd(a, b, M, numItermax=100000):
1818
"""Solves the Earth Movers distance problem and returns the OT matrix
1919
2020
@@ -39,7 +39,7 @@ def emd(a, b, M, max_iter=100000):
3939
Target histogram (uniform weigth if empty list)
4040
M : (ns,nt) ndarray, float64
4141
loss matrix
42-
max_iter : int, optional (default=100000)
42+
numItermax : int, optional (default=100000)
4343
The maximum number of iterations before stopping the optimization
4444
algorithm if it has not converged.
4545
@@ -86,10 +86,10 @@ def emd(a, b, M, max_iter=100000):
8686
if len(b) == 0:
8787
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
8888

89-
return emd_c(a, b, M, max_iter)
89+
return emd_c(a, b, M, numItermax)
9090

9191

92-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
92+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
9393
"""Solves the Earth Movers distance problem and returns the loss
9494
9595
.. math::
@@ -113,7 +113,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
113113
Target histogram (uniform weigth if empty list)
114114
M : (ns,nt) ndarray, float64
115115
loss matrix
116-
max_iter : int, optional (default=100000)
116+
numItermax : int, optional (default=100000)
117117
The maximum number of iterations before stopping the optimization
118118
algorithm if it has not converged.
119119
@@ -161,12 +161,12 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
161161
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
162162

163163
if len(b.shape) == 1:
164-
return emd2_c(a, b, M, max_iter)
164+
return emd2_c(a, b, M, numItermax)
165165
else:
166166
nb = b.shape[1]
167-
# res = [emd2_c(a, b[:, i].copy(), M, max_iter) for i in range(nb)]
167+
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
168168

169169
def f(b):
170-
return emd2_c(a, b, M, max_iter)
170+
return emd2_c(a, b, M, numItermax)
171171
res = parmap(f, [b[:, i] for i in range(nb)], processes)
172172
return np.array(res)

0 commit comments

Comments
 (0)