Skip to content

Commit bb6020d

Browse files
committed
more efficient sinkhorn
1 parent 17315cb commit bb6020d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ot/bregman.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
141141
cpt = cpt +1
142142
#print 'err=',err,' cpt=',cpt
143143
if log:
144-
return np.dot(np.diag(u),np.dot(K,np.diag(v))),log
144+
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
145145
else:
146-
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
146+
return u.reshape((-1,1))*K*v.reshape((1,-1))
147147

148148
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False):
149149
"""

0 commit comments

Comments
 (0)