Skip to content

Commit 7e16b7a

Browse files
committed
add mapping estimation with kernels (still debugging)
1 parent 86b1c88 commit 7e16b7a

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

ot/da.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
247247

248248
def loss(L,G):
249249
"""Compute full loss"""
250-
return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
250+
return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
251251

252252
def solve_L_nobias(G):
253253
""" solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
450450
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
451451
self.computed=True
452452

453-
class OTDA_mapping(OTDA):
453+
class OTDA_mapping_linear(OTDA):
454454
"""Class for optimal transport with joint linear mapping estimation"""
455455

456456

457-
def __init__(self,metric='sqeuclidean'):
457+
def __init__(self):
458458
""" Class initialization"""
459459

460460

@@ -463,8 +463,8 @@ def __init__(self,metric='sqeuclidean'):
463463
self.G=0
464464
self.L=0
465465
self.bias=False
466-
self.metric=metric
467466
self.computed=False
467+
self.metric='sqeuclidean'
468468

469469
def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
470470
""" Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
473473
self.xt=xt
474474
self.bias=bias
475475

476+
476477
self.ws=unif(xs.shape[0])
477478
self.wt=unif(xt.shape[0])
478479

@@ -498,3 +499,42 @@ def predict(self,x):
498499
print("Warning, model not fitted yet, returning None")
499500
return None
500501

502+
class OTDA_mapping_kernel(OTDA_mapping_linear):
503+
"""Class for optimal transport with joint linear mapping estimation"""
504+
505+
506+
507+
def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
508+
""" Fit domain adaptation between samples is xs and xt (with optional
509+
weights)"""
510+
self.xs=xs
511+
self.xt=xt
512+
self.bias=bias
513+
514+
self.ws=unif(xs.shape[0])
515+
self.wt=unif(xt.shape[0])
516+
self.kernel=kerneltype
517+
self.sigma=sigma
518+
self.kwargs=kwargs
519+
520+
521+
self.G,self.L=joint_OT_mapping_kernel(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
522+
self.computed=True
523+
524+
525+
def predict(self,x):
526+
""" Out of sample mapping using the formulation from Ferradans
527+
528+
It basically find the source sample the nearset to the nex sample and
529+
apply the difference to the displaced source sample.
530+
531+
"""
532+
533+
if self.computed:
534+
K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)
535+
if self.bias:
536+
K=np.hstack((K,np.ones((x.shape[0],1))))
537+
return K.dot(self.L)
538+
else:
539+
print("Warning, model not fitted yet, returning None")
540+
return None

ot/datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
108108
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
109109

110110
elif dataset.lower()=='gaussrot' :
111-
rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
112-
m1=np.array([-1,-1])
113-
m2=np.array([1,1])
111+
rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
112+
m1=np.array([-1,1])
113+
m2=np.array([1,-1])
114114
y=np.floor((np.arange(n)*1.0/n*2))+1
115115
n1=np.sum(y==1)
116116
n2=np.sum(y==2)

0 commit comments

Comments
 (0)