Skip to content

Commit cd8c042

Browse files
committed
Renamed variable
1 parent 1ba2c83 commit cd8c042

File tree

6 files changed

+16
-16
lines changed

6 files changed

+16
-16
lines changed

ot/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,7 @@ class label
13701370

13711371
# coupling estimation
13721372
self.coupling_ = emd(
1373-
a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter
1373+
a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter
13741374
)
13751375

13761376
return self

ot/lp/EMD.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ enum ProblemType {
3030
MAX_ITER_REACHED
3131
};
3232

33-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter);
33+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
3434

3535
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
19-
double* alpha, double* beta, double *cost, int max_iter) {
19+
double* alpha, double* beta, double *cost, int maxIter) {
2020
// beware M and C anre strored in row major C style!!!
2121
int n, m, i, cur;
2222

@@ -48,7 +48,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
4848
std::vector<int> indI(n), indJ(m);
4949
std::vector<double> weights1(n), weights2(m);
5050
Digraph di(n, m);
51-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
51+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
5252

5353
// Set supply and demand, don't account for 0 values (faster)
5454

ot/lp/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..utils import parmap
1717

1818

19-
def emd(a, b, M, num_iter_max=100000, log=False):
19+
def emd(a, b, M, max_iter=100000, log=False):
2020
"""Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
4141
Target histogram (uniform weigth if empty list)
4242
M : (ns,nt) ndarray, float64
4343
loss matrix
44-
num_iter_max : int, optional (default=100000)
44+
max_iter : int, optional (default=100000)
4545
The maximum number of iterations before stopping the optimization
4646
algorithm if it has not converged.
4747
log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
9494
if len(b) == 0:
9595
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
9696

97-
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
97+
G, cost, u, v, result_code = emd_c(a, b, M, max_iter)
9898
result_code_string = check_result(result_code)
9999
if log:
100100
log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107107
return G
108108

109109

110-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False):
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False):
111111
"""Solves the Earth Movers distance problem and returns the loss
112112
113113
.. math::
@@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
131131
Target histogram (uniform weigth if empty list)
132132
M : (ns,nt) ndarray, float64
133133
loss matrix
134-
num_iter_max : int, optional (default=100000)
134+
max_iter : int, optional (default=100000)
135135
The maximum number of iterations before stopping the optimization
136136
algorithm if it has not converged.
137137
@@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
183183

184184
if log:
185185
def f(b):
186-
G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max)
186+
G, cost, u, v, resultCode = emd_c(a, b, M, max_iter)
187187
result_code_string = check_result(resultCode)
188188
log = {}
189189
log['G'] = G
@@ -194,7 +194,7 @@ def f(b):
194194
return [cost, log]
195195
else:
196196
def f(b):
197-
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
197+
G, cost, u, v, result_code = emd_c(a, b, M, max_iter)
198198
check_result(result_code)
199199
return cost
200200

ot/lp/emd_wrap.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import warnings
1616

1717

1818
cdef extern from "EMD.h":
19-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int numItermax)
19+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
2020
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2121

2222

@@ -36,7 +36,7 @@ def check_result(result_code):
3636

3737
@cython.boundscheck(False)
3838
@cython.wraparound(False)
39-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int num_iter_max):
39+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
4040
"""
4141
Solves the Earth Movers distance problem and returns the optimal transport matrix
4242
@@ -63,7 +63,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
6363
target histogram
6464
M : (ns,nt) ndarray, float64
6565
loss matrix
66-
num_iter_max : int
66+
max_iter : int
6767
The maximum number of iterations before stopping the optimization
6868
algorithm if it has not converged.
6969
@@ -90,6 +90,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
9090
b=np.ones((n2,))/n2
9191

9292
# calling the function
93-
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, num_iter_max)
93+
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
9494

9595
return G, cost, alpha, beta, result_code

test/test_ot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_warnings():
140140
with warnings.catch_warnings(record=True) as w:
141141
warnings.simplefilter("always")
142142
print('Computing {} EMD '.format(1))
143-
ot.emd(a, b, M, num_iter_max=1)
143+
ot.emd(a, b, M, max_iter=1)
144144
assert "numItermax" in str(w[-1].message)
145145
assert len(w) == 1
146146
a[0] = 100

0 commit comments

Comments
 (0)