Skip to content

Commit df32d77

Browse files
committed
first try
1 parent 0b80637 commit df32d77

File tree

2 files changed

+108
-3
lines changed

2 files changed

+108
-3
lines changed

ot/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import da
1212

1313
# OT functions
14-
from .lp import emd
14+
from .lp import emd, emd2
1515
from .bregman import sinkhorn, barycenter
1616
from .da import sinkhorn_lpl1_mm
1717

@@ -20,5 +20,5 @@
2020

2121
__version__ = "0.1.12"
2222

23-
__all__ = ["emd", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', 'plot',
23+
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', 'plot',
2424
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/lp/__init__.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
# import compiled emd
88
from .emd import emd_c
9-
9+
import multiprocessing
1010

1111
def emd(a, b, M):
1212
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -70,9 +70,114 @@ def emd(a, b, M):
7070
b = np.asarray(b, dtype=np.float64)
7171
M = np.asarray(M, dtype=np.float64)
7272

73+
# if empty array given then use unifor distributions
7374
if len(a) == 0:
7475
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
7576
if len(b) == 0:
7677
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
7778

7879
return emd_c(a, b, M)
80+
81+
def emd2(a, b, M,processes=None):
82+
"""Solves the Earth Movers distance problem and returns the loss
83+
84+
.. math::
85+
\gamma = arg\min_\gamma <\gamma,M>_F
86+
87+
s.t. \gamma 1 = a
88+
\gamma^T 1= b
89+
\gamma\geq 0
90+
where :
91+
92+
- M is the metric cost matrix
93+
- a and b are the sample weights
94+
95+
Uses the algorithm proposed in [1]_
96+
97+
Parameters
98+
----------
99+
a : (ns,) ndarray, float64
100+
Source histogram (uniform weigth if empty list)
101+
b : (nt,) ndarray, float64
102+
Target histogram (uniform weigth if empty list)
103+
M : (ns,nt) ndarray, float64
104+
loss matrix
105+
106+
Returns
107+
-------
108+
gamma: (ns x nt) ndarray
109+
Optimal transportation matrix for the given parameters
110+
111+
112+
Examples
113+
--------
114+
115+
Simple example with obvious solution. The function emd accepts lists and
116+
perform automatic conversion to numpy arrays
117+
>>> import ot
118+
>>> a=[.5,.5]
119+
>>> b=[.5,.5]
120+
>>> M=[[0.,1.],[1.,0.]]
121+
>>> ot.emd2(a,b,M)
122+
0.0
123+
124+
References
125+
----------
126+
127+
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
128+
(2011, December). Displacement interpolation using Lagrangian mass
129+
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
130+
158). ACM.
131+
132+
See Also
133+
--------
134+
ot.bregman.sinkhorn : Entropic regularized OT
135+
ot.optim.cg : General regularized OT"""
136+
137+
a = np.asarray(a, dtype=np.float64)
138+
b = np.asarray(b, dtype=np.float64)
139+
M = np.asarray(M, dtype=np.float64)
140+
141+
# if empty array given then use unifor distributions
142+
if len(a) == 0:
143+
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
144+
if len(b) == 0:
145+
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
146+
147+
if len(b.shape)==1:
148+
return np.sum(emd_c(a, b, M)*M)
149+
else:
150+
nb=b.shape[1]
151+
ls=[(a,b[:,k],M) for k in range(nb)]
152+
# run emd in multiprocessing
153+
res=parmap(emd2, ls,processes)
154+
np.array(res)
155+
# with Pool(processes) as p:
156+
# res=p.map(f, ls)
157+
# return np.array(res)
158+
159+
160+
def fun(f, q_in, q_out):
161+
while True:
162+
i, x = q_in.get()
163+
if i is None:
164+
break
165+
q_out.put((i, f(x)))
166+
167+
def parmap(f, X, nprocs):
168+
q_in = multiprocessing.Queue(1)
169+
q_out = multiprocessing.Queue()
170+
171+
proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
172+
for _ in range(nprocs)]
173+
for p in proc:
174+
p.daemon = True
175+
p.start()
176+
177+
sent = [q_in.put((i, x)) for i, x in enumerate(X)]
178+
[q_in.put((None, None)) for _ in range(nprocs)]
179+
res = [q_out.get() for _ in range(len(sent))]
180+
181+
[p.join() for p in proc]
182+
183+
return [x for i, x in sorted(res)]

0 commit comments

Comments
 (0)