@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
247
247
248
248
def loss (L ,G ):
249
249
"""Compute full loss"""
250
- return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum ( sel ( L ) ** 2 )
250
+ return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .trace ( L . T . dot ( K0 ). dot ( L ) )
251
251
252
252
def solve_L_nobias (G ):
253
253
""" solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
450
450
self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
451
451
self .computed = True
452
452
453
- class OTDA_mapping (OTDA ):
453
+ class OTDA_mapping_linear (OTDA ):
454
454
"""Class for optimal transport with joint linear mapping estimation"""
455
455
456
456
457
- def __init__ (self , metric = 'sqeuclidean' ):
457
+ def __init__ (self ):
458
458
""" Class initialization"""
459
459
460
460
@@ -463,8 +463,8 @@ def __init__(self,metric='sqeuclidean'):
463
463
self .G = 0
464
464
self .L = 0
465
465
self .bias = False
466
- self .metric = metric
467
466
self .computed = False
467
+ self .metric = 'sqeuclidean'
468
468
469
469
def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,** kwargs ):
470
470
""" Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
473
473
self .xt = xt
474
474
self .bias = bias
475
475
476
+
476
477
self .ws = unif (xs .shape [0 ])
477
478
self .wt = unif (xt .shape [0 ])
478
479
@@ -498,3 +499,42 @@ def predict(self,x):
498
499
print ("Warning, model not fitted yet, returning None" )
499
500
return None
500
501
502
+ class OTDA_mapping_kernel (OTDA_mapping_linear ):
503
+ """Class for optimal transport with joint linear mapping estimation"""
504
+
505
+
506
+
507
+ def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,kerneltype = 'gaussian' ,sigma = 1 ,** kwargs ):
508
+ """ Fit domain adaptation between samples is xs and xt (with optional
509
+ weights)"""
510
+ self .xs = xs
511
+ self .xt = xt
512
+ self .bias = bias
513
+
514
+ self .ws = unif (xs .shape [0 ])
515
+ self .wt = unif (xt .shape [0 ])
516
+ self .kernel = kerneltype
517
+ self .sigma = sigma
518
+ self .kwargs = kwargs
519
+
520
+
521
+ self .G ,self .L = joint_OT_mapping_kernel (xs ,xt ,mu = mu ,eta = eta ,bias = bias ,** kwargs )
522
+ self .computed = True
523
+
524
+
525
+ def predict (self ,x ):
526
+ """ Out of sample mapping using the formulation from Ferradans
527
+
528
+ It basically find the source sample the nearset to the nex sample and
529
+ apply the difference to the displaced source sample.
530
+
531
+ """
532
+
533
+ if self .computed :
534
+ K = kernel (x ,self .xs ,method = self .kernel ,sigma = self .sigma ,** self .kwargs )
535
+ if self .bias :
536
+ K = np .hstack ((K ,np .ones ((x .shape [0 ],1 ))))
537
+ return K .dot (self .L )
538
+ else :
539
+ print ("Warning, model not fitted yet, returning None" )
540
+ return None
0 commit comments