Skip to content

Commit a84f2c3

Browse files
committed
add emd2+ multiproc
1 parent 84219d9 commit a84f2c3

File tree

6 files changed

+1683
-2069
lines changed

6 files changed

+1683
-2069
lines changed

ot/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020

2121
__version__ = "0.1.12"
2222

23-
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', 'plot',
23+
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
24+
'plot', 'tic', 'toc', 'toq',
2425
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/lp/__init__.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
import numpy as np
77
# import compiled emd
8-
from .emd import emd_c
8+
from .emd import emd_c, emd2_c
9+
from ..utils import parmap
910
import multiprocessing
1011

12+
13+
1114
def emd(a, b, M):
1215
"""Solves the Earth Movers distance problem and returns the OT matrix
1316
@@ -145,41 +148,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
145148
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
146149

147150
if len(b.shape)==1:
148-
return np.sum(emd_c(a, b, M)*M)
151+
return emd2_c(a, b, M)
149152
else:
150153
nb=b.shape[1]
151-
ls=[(a,b[:,k],M) for k in range(nb)]
152-
def f(l):
153-
return emd2(l[0],l[1],l[2])
154-
# run emd in multiprocessing
155-
res=parmap(f, ls,processes)
154+
#res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
155+
def f(b):
156+
return emd2_c(a,b,M)
157+
res= parmap(f, [b[:,i] for i in range(nb)],processes)
156158
return np.array(res)
157-
# with Pool(processes) as p:
158-
# res=p.map(f, ls)
159-
# return np.array(res)
160159

161160

162-
def fun(f, q_in, q_out):
163-
while True:
164-
i, x = q_in.get()
165-
if i is None:
166-
break
167-
q_out.put((i, f(x)))
168-
169-
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
170-
q_in = multiprocessing.Queue(1)
171-
q_out = multiprocessing.Queue()
172-
173-
proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
174-
for _ in range(nprocs)]
175-
for p in proc:
176-
p.daemon = True
177-
p.start()
178-
179-
sent = [q_in.put((i, x)) for i, x in enumerate(X)]
180-
[q_in.put((None, None)) for _ in range(nprocs)]
181-
res = [q_out.get() for _ in range(len(sent))]
182-
183-
[p.join() for p in proc]
184-
185-
return [x for i, x in sorted(res)]
161+

0 commit comments

Comments
 (0)