Skip to content

Commit 85c56d9

Browse files
committed
Renamed variables
1 parent c4aca9e commit 85c56d9

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

ot/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,7 @@ class label
13701370

13711371
# coupling estimation
13721372
self.coupling_ = emd(
1373-
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
1373+
a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter
13741374
)
13751375

13761376
return self

ot/lp/__init__.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import numpy as np
1313

1414
# import compiled emd
15-
from .emd_wrap import emd_c, checkResult
15+
from .emd_wrap import emd_c, check_result
1616
from ..utils import parmap
1717

1818

19-
def emd(a, b, M, numItermax=100000, log=False):
19+
def emd(a, b, M, num_iter_max=100000, log=False):
2020
"""Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, numItermax=100000, log=False):
4141
Target histogram (uniform weigth if empty list)
4242
M : (ns,nt) ndarray, float64
4343
loss matrix
44-
numItermax : int, optional (default=100000)
44+
num_iter_max : int, optional (default=100000)
4545
The maximum number of iterations before stopping the optimization
4646
algorithm if it has not converged.
4747
log: boolean, optional (default=False)
@@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=100000, log=False):
5454
Optimal transportation matrix for the given parameters
5555
log: dict
5656
If input log is true, a dictionary containing the cost and dual
57-
variables
57+
variables and exit status
5858
5959
6060
Examples
@@ -94,20 +94,20 @@ def emd(a, b, M, numItermax=100000, log=False):
9494
if len(b) == 0:
9595
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
9696

97-
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
98-
resultCodeString = checkResult(resultCode)
97+
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
98+
resultCodeString = check_result(result_code)
9999
if log:
100100
log = {}
101101
log['cost'] = cost
102102
log['u'] = u
103103
log['v'] = v
104104
log['warning'] = resultCodeString
105-
log['resultCode'] = resultCode
105+
log['result_code'] = result_code
106106
return G, log
107107
return G
108108

109109

110-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False):
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False):
111111
"""Solves the Earth Movers distance problem and returns the loss
112112
113113
.. math::
@@ -131,14 +131,17 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
131131
Target histogram (uniform weigth if empty list)
132132
M : (ns,nt) ndarray, float64
133133
loss matrix
134-
numItermax : int, optional (default=100000)
134+
num_iter_max : int, optional (default=100000)
135135
The maximum number of iterations before stopping the optimization
136136
algorithm if it has not converged.
137137
138138
Returns
139139
-------
140140
gamma: (ns x nt) ndarray
141141
Optimal transportation matrix for the given parameters
142+
log: dict
143+
If input log is true, a dictionary containing the cost and dual
144+
variables and exit status
142145
143146
144147
Examples
@@ -180,19 +183,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
180183

181184
if log:
182185
def f(b):
183-
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
184-
resultCodeString = checkResult(resultCode)
186+
G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max)
187+
resultCodeString = check_result(resultCode)
185188
log = {}
186189
log['G'] = G
187190
log['u'] = u
188191
log['v'] = v
189192
log['warning'] = resultCodeString
190-
log['resultCode'] = resultCode
193+
log['result_code'] = resultCode
191194
return [cost, log]
192195
else:
193196
def f(b):
194-
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
195-
checkResult(resultCode)
197+
G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max)
198+
check_result(result_code)
196199
return cost
197200

198201
if len(b.shape) == 1:

ot/lp/emd_wrap.pyx

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,23 @@ cdef extern from "EMD.h":
2020
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2121

2222

23-
def checkResult(resultCode):
24-
if resultCode == OPTIMAL:
23+
def check_result(result_code):
24+
if result_code == OPTIMAL:
2525
return None
2626

27-
if resultCode == INFEASIBLE:
27+
if result_code == INFEASIBLE:
2828
message = "Problem infeasible. Check that a and b are in the simplex"
29-
elif resultCode == UNBOUNDED:
29+
elif result_code == UNBOUNDED:
3030
message = "Problem unbounded"
31-
elif resultCode == MAX_ITER_REACHED:
31+
elif result_code == MAX_ITER_REACHED:
3232
message = "numItermax reached before optimality. Try to increase numItermax."
3333
warnings.warn(message)
3434
return message
3535

3636

3737
@cython.boundscheck(False)
3838
@cython.wraparound(False)
39-
def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
39+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int num_iter_max):
4040
"""
4141
Solves the Earth Movers distance problem and returns the optimal transport matrix
4242
@@ -63,7 +63,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
6363
target histogram
6464
M : (ns,nt) ndarray, float64
6565
loss matrix
66-
numItermax : int
66+
num_iter_max : int
6767
The maximum number of iterations before stopping the optimization
6868
algorithm if it has not converged.
6969
@@ -90,6 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
9090
b=np.ones((n2,))/n2
9191

9292
# calling the function
93-
cdef int resultCode = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
93+
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, num_iter_max)
9494

95-
return G, cost, alpha, beta, resultCode
95+
return G, cost, alpha, beta, result_code

test/test_ot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_warnings():
140140
with warnings.catch_warnings(record=True) as w:
141141
warnings.simplefilter("always")
142142
print('Computing {} EMD '.format(1))
143-
ot.emd(a, b, M, numItermax=1)
143+
ot.emd(a, b, M, num_iter_max=1)
144144
assert "numItermax" in str(w[-1].message)
145145
assert len(w) == 1
146146
a[0] = 100

0 commit comments

Comments
 (0)