6
6
import numpy as np
7
7
# import compiled emd
8
8
from .emd import emd_c
9
-
9
+ import multiprocessing
10
10
11
11
def emd (a , b , M ):
12
12
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -70,9 +70,114 @@ def emd(a, b, M):
70
70
b = np .asarray (b , dtype = np .float64 )
71
71
M = np .asarray (M , dtype = np .float64 )
72
72
73
+ # if empty array given then use uniform distributions
73
74
if len (a ) == 0 :
74
75
a = np .ones ((M .shape [0 ], ), dtype = np .float64 )/ M .shape [0 ]
75
76
if len (b ) == 0 :
76
77
b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
77
78
78
79
return emd_c (a , b , M )
80
+
81
def _emd2_task(args):
    """Compute the EMD loss for a single ``(a, b, M)`` tuple.

    Kept at module level so it can be handed to worker processes by
    ``parmap``, which calls its function with exactly one argument.
    """
    a, b, M = args
    return np.sum(emd_c(a, b, M) * M)


def emd2(a, b, M, processes=None):
    r"""Solves the Earth Movers distance problem and returns the loss

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F

        s.t. \gamma 1 = a
             \gamma^T 1= b
             \gamma\geq 0
    where :

    - M is the metric cost matrix
    - a and b are the sample weights

    Uses the algorithm proposed in [1]_

    Parameters
    ----------
    a : (ns,) ndarray, float64
        Source histogram (uniform weight if empty list)
    b : (nt,) or (nt, nb) ndarray, float64
        Target histogram (uniform weight if empty list). When ``b`` is
        2-D, each column is treated as a separate target histogram.
    M : (ns,nt) ndarray, float64
        loss matrix
    processes : int, optional
        Number of worker processes used when ``b`` is 2-D. Defaults to
        ``multiprocessing.cpu_count()``. Ignored when ``b`` is 1-D.

    Returns
    -------
    loss : float or (nb,) ndarray
        Sum of the optimal transportation matrix weighted by ``M``.
        A scalar when ``b`` is 1-D, otherwise one value per column of ``b``.


    Examples
    --------

    Simple example with obvious solution. The function emd2 accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.emd2(a,b,M)
    0.0

    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
        (2011, December).  Displacement interpolation using Lagrangian mass
        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
        158). ACM.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT"""

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # if empty array given then use uniform distributions
    if len(a) == 0:
        a = np.ones((M.shape[0], ), dtype=np.float64) / M.shape[0]
    if len(b) == 0:
        b = np.ones((M.shape[1], ), dtype=np.float64) / M.shape[1]

    # single target histogram: solve directly in this process
    if len(b.shape) == 1:
        return np.sum(emd_c(a, b, M) * M)

    # several target histograms: one independent problem per column of b
    if processes is None:
        # parmap needs a concrete worker count
        processes = multiprocessing.cpu_count()

    tasks = [(a, b[:, k], M) for k in range(b.shape[1])]
    # parmap passes each task tuple as a single argument, hence the
    # unpacking helper rather than emd2 itself
    res = parmap(_emd2_task, tasks, processes)
    return np.array(res)
158
+
159
+
160
def fun(f, q_in, q_out):
    """Worker loop: apply ``f`` to items read from ``q_in``.

    Each message on ``q_in`` is an ``(index, item)`` pair. For every pair
    the tuple ``(index, f(item))`` is put on ``q_out``. The loop stops as
    soon as a message whose index is ``None`` is received.
    """
    index, item = q_in.get()
    while index is not None:
        q_out.put((index, f(item)))
        index, item = q_in.get()
166
+
167
def parmap(f, X, nprocs=None):
    """Apply ``f`` to every item of ``X`` in worker processes, keeping order.

    Parameters
    ----------
    f : callable
        Function of one argument applied to each item of ``X``.
    X : iterable
        Items to process.
    nprocs : int, optional
        Number of worker processes to start. Defaults to
        ``multiprocessing.cpu_count()`` when ``None``.

    Returns
    -------
    list
        The values ``f(x)`` for each ``x`` in ``X``, in the order of ``X``.

    Raises
    ------
    ValueError
        If ``nprocs`` is smaller than 1 (no worker would ever consume the
        input queue, so the call would block forever).
    """
    if nprocs is None:
        nprocs = multiprocessing.cpu_count()
    if nprocs < 1:
        raise ValueError("nprocs must be a positive integer")

    # bounded input queue: items are handed over one at a time
    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    workers = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
               for _ in range(nprocs)]
    for p in workers:
        p.daemon = True
        p.start()

    # feed (index, item) pairs so results can be reordered afterwards
    n_sent = 0
    for message in enumerate(X):
        q_in.put(message)
        n_sent += 1

    # one stop marker per worker
    for _ in range(nprocs):
        q_in.put((None, None))

    # NOTE(review): the worker loop does not catch exceptions from f, so a
    # failing f presumably leaves this collection step waiting -- confirm.
    res = [q_out.get() for _ in range(n_sent)]

    for p in workers:
        p.join()

    # order by submission index only; result payloads are never compared
    res.sort(key=lambda pair: pair[0])
    return [value for _, value in res]
0 commit comments