Skip to content

Commit 47477c5

Browse files
committed
add sinkhorn2 +v3
1 parent 0fc1124 commit 47477c5

File tree

4 files changed

+133
-28
lines changed

4 files changed

+133
-28
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,15 @@ import ot
8383
# a,b are 1D histograms (sum to 1 and positive)
8484
# M is the ground cost matrix
8585
Wd=ot.emd2(a,b,M) # exact linear program
86+
Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
8687
# if b is a matrix compute all distances to a and return a vector
8788
```
8889
* Compute OT matrix
8990
```python
9091
# a,b are 1D histograms (sum to 1 and positive)
9192
# M is the ground cost matrix
92-
Totp=ot.emd(a,b,M) # exact linear program
93-
Totp_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
93+
T=ot.emd(a,b,M) # exact linear program
94+
T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
9495
```
9596
* Compute Wasserstein barycenter
9697
```python

examples/plot_compute_emd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161

6262
#%%
6363
reg=1e-2
64-
d_sinkhorn=ot.sinkhorn(a,B,M,reg)
65-
d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
64+
d_sinkhorn=ot.sinkhorn2(a,B,M,reg)
65+
d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg)
6666

6767
pl.figure(2)
6868
pl.clf()

ot/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
# OT functions
1818
from .lp import emd, emd2
19-
from .bregman import sinkhorn, barycenter
19+
from .bregman import sinkhorn, sinkhorn2, barycenter
2020
from .da import sinkhorn_lpl1_mm
2121

2222
# utils functions
2323
from .utils import dist, unif, tic, toc, toq
2424

25-
__version__ = "0.2"
25+
__version__ = "0.3"
2626

27-
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
28-
'plot', 'tic', 'toc', 'toq',
27+
__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets',
28+
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
2929
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/bregman.py

Lines changed: 124 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
4141
Regularization term >0
4242
method : str
4343
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
44-
'sinkhorn_epsilon_scaling', see those function for specific parameters
44+
'sinkhorn_epsilon_scaling', see those function for specific parameters
4545
numItermax : int, optional
4646
Max number of iterations
4747
stopThr : float, optional
@@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
9191
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
9292
9393
"""
94-
94+
9595
if method.lower()=='sinkhorn':
9696
sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
9797
stopThr=stopThr, verbose=verbose, log=log,**kwargs)
@@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
100100
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
101101
elif method.lower()=='sinkhorn_epsilon_scaling':
102102
sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
103-
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
103+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
104104
else:
105105
print('Warning : unknown method using classic Sinkhorn Knopp')
106106
sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
107-
107+
108108
return sink()
109+
110+
def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_

    Unlike ``sinkhorn``, the target weights ``b`` are always promoted to a
    2D array (one column per target histogram) before the solver is called,
    which is what makes the solver return the loss instead of the OT matrix.


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed M if b is a matrix (return OT loss + dual variables in log)
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
        'sinkhorn_epsilon_scaling', see those functions for specific parameters.
        The comparison is case-insensitive; any other value prints a warning
        and falls back to the classic Sinkhorn-Knopp solver.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver.


    Returns
    -------
    W : (nbb,) ndarray
        Optimal transportation loss for the given parameters, one value per
        column of b (a single-element array when b is 1D)
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.sinkhorn2(a,b,M,1)
    array([ 0.26894142])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.



    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]

    """

    # Promote b to a column matrix *before* choosing the solver: the solvers
    # return the loss (rather than the OT matrix) only when b is 2D.
    b = np.asarray(b, dtype=np.float64)
    if len(b.shape) < 2:
        b = b.reshape((-1, 1))

    # Pick the solver; each name is looked up only in the branch that needs it.
    method_name = method.lower()
    if method_name == 'sinkhorn':
        solver = sinkhorn_knopp
    elif method_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized
    elif method_name == 'sinkhorn_epsilon_scaling':
        solver = sinkhorn_epsilon_scaling
    else:
        print('Warning : unknown method using classic Sinkhorn Knopp')
        solver = sinkhorn_knopp

    # The fallback branch now forwards the caller's options as well, instead
    # of silently resetting them to the solver defaults.
    return solver(a, b, M, reg, numItermax=numItermax, stopThr=stopThr,
                  verbose=verbose, log=log, **kwargs)
215+
112216

113217
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
114218
"""
@@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
189293
a=np.asarray(a,dtype=np.float64)
190294
b=np.asarray(b,dtype=np.float64)
191295
M=np.asarray(M,dtype=np.float64)
192-
296+
193297

194298
if len(a)==0:
195299
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
196300
if len(b)==0:
197301
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
198-
302+
199303

200304
# init data
201305
Nini = len(a)
202306
Nfin = len(b)
203-
307+
204308
if len(b.shape)>1:
205309
nbb=b.shape[1]
206310
else:
207311
nbb=0
208-
312+
209313

210314
if log:
211315
log={'err':[]}
@@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
217321
else:
218322
u = np.ones(Nini)/Nini
219323
v = np.ones(Nfin)/Nfin
220-
324+
221325

222326
#print(reg)
223327

@@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
261365
if log:
262366
log['u']=u
263367
log['v']=v
264-
265-
if nbb: #return only loss
368+
369+
if nbb: #return only loss
266370
res=np.zeros((nbb))
267371
for i in range(nbb):
268372
res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
269373
if log:
270374
return res,log
271375
else:
272-
return res
273-
376+
return res
377+
274378
else: # return OT matrix
275-
379+
276380
if log:
277381
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
278382
else:
279383
return u.reshape((-1,1))*K*v.reshape((1,-1))
280-
384+
281385

282386
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
283387
"""
@@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
393497
alpha,beta=np.zeros(na),np.zeros(nb)
394498
else:
395499
alpha,beta=warmstart
396-
500+
397501
if nbb:
398502
u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
399503
else:
@@ -420,7 +524,7 @@ def get_Gamma(alpha,beta,u,v):
420524

421525
uprev = u
422526
vprev = v
423-
527+
424528
# sinkhorn update
425529
v = b/(np.dot(K.T,u)+1e-16)
426530
u = a/(np.dot(K,v)+1e-16)
@@ -471,8 +575,8 @@ def get_Gamma(alpha,beta,u,v):
471575
break
472576

473577
cpt = cpt +1
474-
475-
578+
579+
476580
#print('err=',err,' cpt=',cpt)
477581
if log:
478582
log['logu']=alpha/reg+np.log(u)
@@ -493,7 +597,7 @@ def get_Gamma(alpha,beta,u,v):
493597
res=np.zeros((nbb))
494598
for i in range(nbb):
495599
res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
496-
return res
600+
return res
497601
else:
498602
return get_Gamma(alpha,beta,u,v)
499603

0 commit comments

Comments
 (0)