@@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
41
41
Regularization term >0
42
42
method : str
43
43
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
44
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
44
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
45
45
numItermax : int, optional
46
46
Max number of iterations
47
47
stopThr : float, optional
@@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
91
91
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
92
92
93
93
"""
94
-
94
+
95
95
if method .lower ()== 'sinkhorn' :
96
96
sink = lambda : sinkhorn_knopp (a ,b , M , reg ,numItermax = numItermax ,
97
97
stopThr = stopThr , verbose = verbose , log = log ,** kwargs )
@@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
100
100
stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
101
101
elif method .lower ()== 'sinkhorn_epsilon_scaling' :
102
102
sink = lambda : sinkhorn_epsilon_scaling (a ,b , M , reg ,numItermax = numItermax ,
103
- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
103
+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
104
104
else :
105
105
print ('Warning : unknown method using classic Sinkhorn Knopp' )
106
106
sink = lambda : sinkhorn_knopp (a ,b , M , reg , ** kwargs )
107
-
107
+
108
108
return sink ()
109
+
110
+ def sinkhorn2 (a ,b , M , reg ,method = 'sinkhorn' , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
111
+ u"""
112
+ Solve the entropic regularization optimal transport problem and return the loss
113
+
114
+ The function solves the following optimization problem:
115
+
116
+ .. math::
117
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
118
+
119
+ s.t. \gamma 1 = a
120
+
121
+ \gamma^T 1= b
122
+
123
+ \gamma\geq 0
124
+ where :
125
+
126
+ - M is the (ns,nt) metric cost matrix
127
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
128
+ - a and b are source and target weights (sum to 1)
129
+
130
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
131
+
132
+
133
+ Parameters
134
+ ----------
135
+ a : np.ndarray (ns,)
136
+ samples weights in the source domain
137
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
138
+ samples in the target domain, compute sinkhorn with multiple targets
139
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
140
+ M : np.ndarray (ns,nt)
141
+ loss matrix
142
+ reg : float
143
+ Regularization term >0
144
+ method : str
145
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
146
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
147
+ numItermax : int, optional
148
+ Max number of iterations
149
+ stopThr : float, optional
150
+ Stop threshol on error (>0)
151
+ verbose : bool, optional
152
+ Print information along iterations
153
+ log : bool, optional
154
+ record log if True
155
+
156
+
157
+ Returns
158
+ -------
159
+ W : (nt) ndarray or float
160
+ Optimal transportation matrix for the given parameters
161
+ log : dict
162
+ log dictionary return only if log==True in parameters
163
+
164
+ Examples
165
+ --------
166
+
167
+ >>> import ot
168
+ >>> a=[.5,.5]
169
+ >>> b=[.5,.5]
170
+ >>> M=[[0.,1.],[1.,0.]]
171
+ >>> ot.sinkhorn2(a,b,M,1)
172
+ array([ 0.26894142])
109
173
110
174
111
175
176
+ References
177
+ ----------
178
+
179
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
180
+
181
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
182
+
183
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
184
+
185
+
186
+
187
+ See Also
188
+ --------
189
+ ot.lp.emd : Unregularized OT
190
+ ot.optim.cg : General regularized OT
191
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
192
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
193
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
194
+
195
+ """
196
+
197
+ if method .lower ()== 'sinkhorn' :
198
+ sink = lambda : sinkhorn_knopp (a ,b , M , reg ,numItermax = numItermax ,
199
+ stopThr = stopThr , verbose = verbose , log = log ,** kwargs )
200
+ elif method .lower ()== 'sinkhorn_stabilized' :
201
+ sink = lambda : sinkhorn_stabilized (a ,b , M , reg ,numItermax = numItermax ,
202
+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
203
+ elif method .lower ()== 'sinkhorn_epsilon_scaling' :
204
+ sink = lambda : sinkhorn_epsilon_scaling (a ,b , M , reg ,numItermax = numItermax ,
205
+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
206
+ else :
207
+ print ('Warning : unknown method using classic Sinkhorn Knopp' )
208
+ sink = lambda : sinkhorn_knopp (a ,b , M , reg , ** kwargs )
209
+
210
+ b = np .asarray (b ,dtype = np .float64 )
211
+ if len (b .shape )< 2 :
212
+ b = b .reshape ((- 1 ,1 ))
213
+
214
+ return sink ()
215
+
112
216
113
217
def sinkhorn_knopp (a ,b , M , reg , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
114
218
"""
@@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
189
293
a = np .asarray (a ,dtype = np .float64 )
190
294
b = np .asarray (b ,dtype = np .float64 )
191
295
M = np .asarray (M ,dtype = np .float64 )
192
-
296
+
193
297
194
298
if len (a )== 0 :
195
299
a = np .ones ((M .shape [0 ],),dtype = np .float64 )/ M .shape [0 ]
196
300
if len (b )== 0 :
197
301
b = np .ones ((M .shape [1 ],),dtype = np .float64 )/ M .shape [1 ]
198
-
302
+
199
303
200
304
# init data
201
305
Nini = len (a )
202
306
Nfin = len (b )
203
-
307
+
204
308
if len (b .shape )> 1 :
205
309
nbb = b .shape [1 ]
206
310
else :
207
311
nbb = 0
208
-
312
+
209
313
210
314
if log :
211
315
log = {'err' :[]}
@@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
217
321
else :
218
322
u = np .ones (Nini )/ Nini
219
323
v = np .ones (Nfin )/ Nfin
220
-
324
+
221
325
222
326
#print(reg)
223
327
@@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
261
365
if log :
262
366
log ['u' ]= u
263
367
log ['v' ]= v
264
-
265
- if nbb : #return only loss
368
+
369
+ if nbb : #return only loss
266
370
res = np .zeros ((nbb ))
267
371
for i in range (nbb ):
268
372
res [i ]= np .sum (u [:,i ].reshape ((- 1 ,1 ))* K * v [:,i ].reshape ((1 ,- 1 ))* M )
269
373
if log :
270
374
return res ,log
271
375
else :
272
- return res
273
-
376
+ return res
377
+
274
378
else : # return OT matrix
275
-
379
+
276
380
if log :
277
381
return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 )),log
278
382
else :
279
383
return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 ))
280
-
384
+
281
385
282
386
def sinkhorn_stabilized (a ,b , M , reg , numItermax = 1000 ,tau = 1e3 , stopThr = 1e-9 ,warmstart = None , verbose = False ,print_period = 20 , log = False ,** kwargs ):
283
387
"""
@@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
393
497
alpha ,beta = np .zeros (na ),np .zeros (nb )
394
498
else :
395
499
alpha ,beta = warmstart
396
-
500
+
397
501
if nbb :
398
502
u ,v = np .ones ((na ,nbb ))/ na ,np .ones ((nb ,nbb ))/ nb
399
503
else :
@@ -420,7 +524,7 @@ def get_Gamma(alpha,beta,u,v):
420
524
421
525
uprev = u
422
526
vprev = v
423
-
527
+
424
528
# sinkhorn update
425
529
v = b / (np .dot (K .T ,u )+ 1e-16 )
426
530
u = a / (np .dot (K ,v )+ 1e-16 )
@@ -471,8 +575,8 @@ def get_Gamma(alpha,beta,u,v):
471
575
break
472
576
473
577
cpt = cpt + 1
474
-
475
-
578
+
579
+
476
580
#print('err=',err,' cpt=',cpt)
477
581
if log :
478
582
log ['logu' ]= alpha / reg + np .log (u )
@@ -493,7 +597,7 @@ def get_Gamma(alpha,beta,u,v):
493
597
res = np .zeros ((nbb ))
494
598
for i in range (nbb ):
495
599
res [i ]= np .sum (get_Gamma (alpha ,beta ,u [:,i ],v [:,i ])* M )
496
- return res
600
+ return res
497
601
else :
498
602
return get_Gamma (alpha ,beta ,u ,v )
499
603
0 commit comments