Skip to content

Commit 405f352

Browse files
committed
add mapping estimation with kernels works!
1 parent 0ea30e9 commit 405f352

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

ot/da.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
123123

124124
return transp
125125

126-
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
126+
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
127127
"""Joint Ot and mapping estimation (uniform weights and )
128128
"""
129129

@@ -209,15 +209,15 @@ def df(G):
209209
if verbose:
210210
if it%20==0:
211211
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
212-
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
212+
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
213213
if log:
214214
log['loss']=vloss
215215
return G,L,log
216216
else:
217217
return G,L
218218

219219

220-
def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
220+
def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
221221
"""Joint Ot and mapping estimation (uniform weights and )
222222
"""
223223

@@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
228228
K1=np.hstack((K,np.ones((ns,1))))
229229
I=np.eye(ns+1)
230230
I[-1]=0
231-
K0 = K1.T.dot(K1)+eta*I
232-
Kreg=I
233-
sel=lambda x : x[:-1,:]
231+
Kp=np.eye(ns+1)
232+
Kp[:ns,:ns]=K
233+
234+
# ls regu
235+
#K0 = K1.T.dot(K1)+eta*I
236+
#Kreg=I
237+
238+
# RKHS regul
239+
K0 = K1.T.dot(K1)+eta*Kp
240+
Kreg=Kp
241+
234242
else:
235243
K1=K
236244
I=np.eye(ns)
245+
246+
# ls regul
247+
#K0 = K1.T.dot(K1)+eta*I
248+
#Kreg=I
249+
250+
# proper kernel ridge
237251
K0=K+eta*I
238252
Kreg=K
239-
sel=lambda x : x
253+
254+
255+
240256

241257
if log:
242258
log={'err':[]}
@@ -313,7 +329,7 @@ def df(G):
313329
if verbose:
314330
if it%20==0:
315331
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
316-
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
332+
print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
317333
if log:
318334
log['loss']=vloss
319335
return G,L,log

ot/optim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def cost(G):
159159

160160
# problem linearization
161161
Mi=M+reg*df(G)
162+
# set M positive
163+
Mi+=Mi.min()
162164

163165
# solve linear program
164166
Gc=emd(a,b,Mi)

ot/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
1010
"""Compute kernel matrix"""
1111
if method.lower() in ['gaussian','gauss','rbf']:
12-
K=np.exp(dist(x1,x2)/(2*sigma**2))
12+
K=np.exp(-dist(x1,x2)/(2*sigma**2))
1313
return K
1414

1515
def unif(n):

0 commit comments

Comments
 (0)