Skip to content

Commit 86b1c88

Browse files
committed
add mapping estimation with kernels (still debugging)
1 parent 566645a commit 86b1c88

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

ot/da.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -217,25 +217,23 @@ def df(G):
217217
return G,L
218218

219219

220-
def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
220+
def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
221221
"""Joint Ot and mapping estimation (uniform weights and )
222222
"""
223223

224224
ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
225225

226+
K=kernel(xs,xs,method=kerneltype,sigma=sigma)
226227
if bias:
227-
K=
228-
xs1=np.hstack((xs,np.ones((ns,1))))
229-
xstxs=xs1.T.dot(xs1)
230-
I=np.eye(d+1)
228+
K1=np.hstack((K,np.ones((ns,1))))
229+
I=np.eye(ns+1)
231230
I[-1]=0
232-
I0=I[:,:-1]
231+
K0 = K1.T.dot(K1)+eta*I
233232
sel=lambda x : x[:-1,:]
234233
else:
235-
xs1=xs
236-
xstxs=xs1.T.dot(xs1)
237-
I=np.eye(d)
238-
I0=I
234+
K1=K
235+
I=np.eye(ns)
236+
K0=K+eta*I
239237
sel=lambda x : x
240238

241239
if log:
@@ -249,23 +247,32 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=
249247

250248
def loss(L,G):
251249
"""Compute full loss"""
252-
return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
250+
return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
253251

254-
def solve_L(G):
252+
def solve_L_nobias(G):
255253
""" solve L problem with fixed G (least square)"""
256254
xst=ns*G.dot(xt)
257-
return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
255+
return np.linalg.solve(K0,xst)
256+
257+
def solve_L_bias(G):
258+
""" solve L problem with fixed G (least square)"""
259+
xst=ns*G.dot(xt)
260+
return np.linalg.solve(K0,K1.T.dot(xst))
258261

259262
def solve_G(L,G0):
260263
"""Update G with CG algorithm"""
261-
xsi=xs1.dot(L)
264+
xsi=K1.dot(L)
262265
def f(G):
263266
return np.sum((xsi-ns*G.dot(xt))**2)
264267
def df(G):
265268
return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
266269
G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
267270
return G
268271

272+
if bias:
273+
solve_L=solve_L_bias
274+
else:
275+
solve_L=solve_L_nobias
269276

270277
L=solve_L(G)
271278

0 commit comments

Comments
 (0)