12
12
import numpy as np
13
13
14
14
# import compiled emd
15
- from .emd_wrap import emd_c , checkResult
15
+ from .emd_wrap import emd_c , check_result
16
16
from ..utils import parmap
17
17
18
18
19
- def emd (a , b , M , numItermax = 100000 , log = False ):
19
+ def emd (a , b , M , num_iter_max = 100000 , log = False ):
20
20
"""Solves the Earth Movers distance problem and returns the OT matrix
21
21
22
22
@@ -41,7 +41,7 @@ def emd(a, b, M, numItermax=100000, log=False):
41
41
Target histogram (uniform weigth if empty list)
42
42
M : (ns,nt) ndarray, float64
43
43
loss matrix
44
- numItermax : int, optional (default=100000)
44
+ num_iter_max : int, optional (default=100000)
45
45
The maximum number of iterations before stopping the optimization
46
46
algorithm if it has not converged.
47
47
log: boolean, optional (default=False)
@@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=100000, log=False):
54
54
Optimal transportation matrix for the given parameters
55
55
log: dict
56
56
If input log is true, a dictionary containing the cost and dual
57
- variables
57
+ variables and exit status
58
58
59
59
60
60
Examples
@@ -94,20 +94,20 @@ def emd(a, b, M, numItermax=100000, log=False):
94
94
if len (b ) == 0 :
95
95
b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
96
96
97
- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
98
- resultCodeString = checkResult ( resultCode )
97
+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
98
+ resultCodeString = check_result ( result_code )
99
99
if log :
100
100
log = {}
101
101
log ['cost' ] = cost
102
102
log ['u' ] = u
103
103
log ['v' ] = v
104
104
log ['warning' ] = resultCodeString
105
- log ['resultCode ' ] = resultCode
105
+ log ['result_code ' ] = result_code
106
106
return G , log
107
107
return G
108
108
109
109
110
- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 , log = False ):
110
+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
111
111
"""Solves the Earth Movers distance problem and returns the loss
112
112
113
113
.. math::
@@ -131,14 +131,17 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
131
131
Target histogram (uniform weigth if empty list)
132
132
M : (ns,nt) ndarray, float64
133
133
loss matrix
134
- numItermax : int, optional (default=100000)
134
+ num_iter_max : int, optional (default=100000)
135
135
The maximum number of iterations before stopping the optimization
136
136
algorithm if it has not converged.
137
137
138
138
Returns
139
139
-------
140
140
gamma: (ns x nt) ndarray
141
141
Optimal transportation matrix for the given parameters
142
+ log: dict
143
+ If input log is true, a dictionary containing the cost and dual
144
+ variables and exit status
142
145
143
146
144
147
Examples
@@ -180,19 +183,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
180
183
181
184
if log :
182
185
def f (b ):
183
- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
184
- resultCodeString = checkResult (resultCode )
186
+ G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187
+ resultCodeString = check_result (resultCode )
185
188
log = {}
186
189
log ['G' ] = G
187
190
log ['u' ] = u
188
191
log ['v' ] = v
189
192
log ['warning' ] = resultCodeString
190
- log ['resultCode ' ] = resultCode
193
+ log ['result_code ' ] = resultCode
191
194
return [cost , log ]
192
195
else :
193
196
def f (b ):
194
- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
195
- checkResult ( resultCode )
197
+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
198
+ check_result ( result_code )
196
199
return cost
197
200
198
201
if len (b .shape ) == 1 :
0 commit comments