Skip to content

Commit 06429e5

Browse files
committed
Returned to old variable name to follow repo convention
1 parent 8cc04ef commit 06429e5

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

ot/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,7 @@ class label
13701370

13711371
# coupling estimation
13721372
self.coupling_ = emd(
1373-
a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter
1373+
a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter
13741374
)
13751375

13761376
return self

ot/lp/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..utils import parmap
1717

1818

19-
def emd(a, b, M, max_iter=100000, log=False):
19+
def emd(a, b, M, num_iter_max=100000, log=False):
2020
"""Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False):
4141
Target histogram (uniform weigth if empty list)
4242
M : (ns,nt) ndarray, float64
4343
loss matrix
44-
max_iter : int, optional (default=100000)
44+
num_iter_max : int, optional (default=100000)
4545
The maximum number of iterations before stopping the optimization
4646
algorithm if it has not converged.
4747
log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False):
9494
if len(b) == 0:
9595
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
9696

97-
G, cost, u, v, result_code = emd_c(a, b, M, max_iter)
97+
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
9898
result_code_string = check_result(result_code)
9999
if log:
100100
log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False):
107107
return G
108108

109109

110-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False):
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False):
111111
"""Solves the Earth Movers distance problem and returns the loss
112112
113113
.. math::
@@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa
183183

184184
if log:
185185
def f(b):
186-
G, cost, u, v, resultCode = emd_c(a, b, M, max_iter)
186+
G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max)
187187
result_code_string = check_result(resultCode)
188188
log = {}
189189
log['G'] = G
@@ -194,7 +194,7 @@ def f(b):
194194
return [cost, log]
195195
else:
196196
def f(b):
197-
G, cost, u, v, result_code = emd_c(a, b, M, max_iter)
197+
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
198198
check_result(result_code)
199199
return cost
200200

ot/lp/emd_wrap.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def check_result(result_code):
2929
elif result_code == UNBOUNDED:
3030
message = "Problem unbounded"
3131
elif result_code == MAX_ITER_REACHED:
32-
message = "max_iter reached before optimality. Try to increase max_iter."
32+
message = "num_iter_max reached before optimality. Try to increase num_iter_max."
3333
warnings.warn(message)
3434
return message
3535

test/test_ot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ def test_warnings():
140140
with warnings.catch_warnings(record=True) as w:
141141
warnings.simplefilter("always")
142142
print('Computing {} EMD '.format(1))
143-
ot.emd(a, b, M, max_iter=1)
144-
assert "max_iter" in str(w[-1].message)
143+
ot.emd(a, b, M, num_iter_max=1)
144+
assert "num_iter_max" in str(w[-1].message)
145145
assert len(w) == 1
146146
a[0] = 100
147147
print('Computing {} EMD '.format(2))

0 commit comments

Comments
 (0)