5
5
6
6
import numpy as np
7
7
# import compiled emd
8
- from .emd import emd_c
8
+ from .emd import emd_c , emd2_c
9
+ from ..utils import parmap
9
10
import multiprocessing
10
11
12
+
13
+
11
14
def emd (a , b , M ):
12
15
"""Solves the Earth Movers distance problem and returns the OT matrix
13
16
@@ -145,41 +148,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
145
148
b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
146
149
147
150
if len (b .shape )== 1 :
148
- return np . sum ( emd_c ( a , b , M ) * M )
151
+ return emd2_c ( a , b , M )
149
152
else :
150
153
nb = b .shape [1 ]
151
- ls = [(a ,b [:,k ],M ) for k in range (nb )]
152
- def f (l ):
153
- return emd2 (l [0 ],l [1 ],l [2 ])
154
- # run emd in multiprocessing
155
- res = parmap (f , ls ,processes )
154
+ #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
155
+ def f (b ):
156
+ return emd2_c (a ,b ,M )
157
+ res = parmap (f , [b [:,i ] for i in range (nb )],processes )
156
158
return np .array (res )
157
- # with Pool(processes) as p:
158
- # res=p.map(f, ls)
159
- # return np.array(res)
160
159
161
160
162
- def fun (f , q_in , q_out ):
163
- while True :
164
- i , x = q_in .get ()
165
- if i is None :
166
- break
167
- q_out .put ((i , f (x )))
168
-
169
- def parmap (f , X , nprocs = multiprocessing .cpu_count ()):
170
- q_in = multiprocessing .Queue (1 )
171
- q_out = multiprocessing .Queue ()
172
-
173
- proc = [multiprocessing .Process (target = fun , args = (f , q_in , q_out ))
174
- for _ in range (nprocs )]
175
- for p in proc :
176
- p .daemon = True
177
- p .start ()
178
-
179
- sent = [q_in .put ((i , x )) for i , x in enumerate (X )]
180
- [q_in .put ((None , None )) for _ in range (nprocs )]
181
- res = [q_out .get () for _ in range (len (sent ))]
182
-
183
- [p .join () for p in proc ]
184
-
185
- return [x for i , x in sorted (res )]
161
+
0 commit comments