Skip to content

Commit 0ea30e9

Browse files
committed
add mapping estimation with kernels (smaller bugs)
1 parent 7e16b7a commit 0ea30e9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

ot/da.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,13 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
229229
I=np.eye(ns+1)
230230
I[-1]=0
231231
K0 = K1.T.dot(K1)+eta*I
232+
Kreg=I
232233
sel=lambda x : x[:-1,:]
233234
else:
234235
K1=K
235236
I=np.eye(ns)
236237
K0=K+eta*I
238+
Kreg=K
237239
sel=lambda x : x
238240

239241
if log:
@@ -247,7 +249,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
247249

248250
def loss(L,G):
249251
"""Compute full loss"""
250-
return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
252+
return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(Kreg).dot(L))
251253

252254
def solve_L_nobias(G):
253255
""" solve L problem with fixed G (least square)"""

0 commit comments

Comments
 (0)