@@ -151,7 +151,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
151
151
152
152
The algorithm used for solving the problem is the block coordinate
153
153
descent that alternates between updates of G (using conditionnal gradient)
154
- abd the update of L using a classical least square solver.
154
+ and the update of L using a classical least square solver.
155
155
156
156
157
157
Parameters
@@ -320,7 +320,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
320
320
321
321
The algorithm used for solving the problem is the block coordinate
322
322
descent that alternates between updates of G (using conditionnal gradient)
323
- abd the update of L using a classical kernel least square solver.
323
+ and the update of L using a classical kernel least square solver.
324
324
325
325
326
326
Parameters
@@ -492,7 +492,15 @@ def df(G):
492
492
493
493
494
494
class OTDA (object ):
495
- """Class for domain adaptation with optimal transport"""
495
+ """Class for domain adaptation with optimal transport as proposed in [5]
496
+
497
+
498
+ References
499
+ ----------
500
+
501
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
502
+
503
+ """
496
504
497
505
def __init__ (self ,metric = 'sqeuclidean' ):
498
506
""" Class initialization"""
@@ -504,8 +512,7 @@ def __init__(self,metric='sqeuclidean'):
504
512
505
513
506
514
def fit (self ,xs ,xt ,ws = None ,wt = None ):
507
- """ Fit domain adaptation between samples is xs and xt (with optional
508
- weights)"""
515
+ """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
509
516
self .xs = xs
510
517
self .xt = xt
511
518
@@ -522,7 +529,7 @@ def fit(self,xs,xt,ws=None,wt=None):
522
529
self .computed = True
523
530
524
531
def interp (self ,direction = 1 ):
525
- """Barycentric interpolation for the source (1) or target (-1)
532
+ """Barycentric interpolation for the source (1) or target (-1) samples
526
533
527
534
This Barycentric interpolation solves for each source (resp target)
528
535
sample xs (resp xt) the following optimization problem:
@@ -558,10 +565,16 @@ def interp(self,direction=1):
558
565
559
566
560
567
def predict (self ,x ,direction = 1 ):
561
- """ Out of sample mapping using the formulation from Ferradans
568
+ """ Out of sample mapping using the formulation from [6]
569
+
570
+ For each sample x to map, it finds the nearest source sample xs and
571
+ map the samle x to the position xst+(x-xs) wher xst is the barycentric
572
+ interpolation of source sample xs.
573
+
574
+ References
575
+ ----------
562
576
563
- It basically find the source sample the nearset to the nex sample and
564
- apply the difference to the displaced source sample.
577
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
565
578
566
579
"""
567
580
if direction > 0 : # >0 then source to target
@@ -582,8 +595,7 @@ class OTDA_sinkhorn(OTDA):
582
595
"""Class for domain adaptation with optimal transport with entropic regularization"""
583
596
584
597
def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
585
- """ Fit domain adaptation between samples is xs and xt (with optional
586
- weights)"""
598
+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
587
599
self .xs = xs
588
600
self .xt = xt
589
601
@@ -601,12 +613,12 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
601
613
602
614
603
615
class OTDA_lpl1 (OTDA ):
604
- """Class for domain adaptation with optimal transport with entropic an group regularization"""
616
+ """Class for domain adaptation with optimal transport with entropic and group regularization"""
605
617
606
618
607
619
def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
608
- """ Fit domain adaptation between samples is xs and xt (with optional
609
- weights) """
620
+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights),
621
+ See ot.da.sinkhorn_lpl1_mm for fit parameters" "" "
610
622
self .xs = xs
611
623
self .xt = xt
612
624
@@ -623,7 +635,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
623
635
self .computed = True
624
636
625
637
class OTDA_mapping_linear (OTDA ):
626
- """Class for optimal transport with joint linear mapping estimation"""
638
+ """Class for optimal transport with joint linear mapping estimation as in [8] """
627
639
628
640
629
641
def __init__ (self ):
@@ -657,12 +669,7 @@ def mapping(self):
657
669
658
670
659
671
def predict (self ,x ):
660
- """ Out of sample mapping using the formulation from Ferradans
661
-
662
- It basically find the source sample the nearset to the nex sample and
663
- apply the difference to the displaced source sample.
664
-
665
- """
672
+ """ Out of sample mapping estimated during the call to fit"""
666
673
if self .computed :
667
674
if self .bias :
668
675
x = np .hstack ((x ,np .ones ((x .shape [0 ],1 ))))
@@ -672,13 +679,12 @@ def predict(self,x):
672
679
return None
673
680
674
681
class OTDA_mapping_kernel (OTDA_mapping_linear ):
675
- """Class for optimal transport with joint linear mapping estimation"""
682
+ """Class for optimal transport with joint nonlinear mapping estimation as in [8] """
676
683
677
684
678
685
679
686
def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,kerneltype = 'gaussian' ,sigma = 1 ,** kwargs ):
680
- """ Fit domain adaptation between samples is xs and xt (with optional
681
- weights)"""
687
+ """ Fit domain adaptation between samples is xs and xt """
682
688
self .xs = xs
683
689
self .xt = xt
684
690
self .bias = bias
@@ -695,12 +701,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs)
695
701
696
702
697
703
def predict (self ,x ):
698
- """ Out of sample mapping using the formulation from Ferradans
699
-
700
- It basically find the source sample the nearset to the nex sample and
701
- apply the difference to the displaced source sample.
702
-
703
- """
704
+ """ Out of sample mapping estimated during the call to fit"""
704
705
705
706
if self .computed :
706
707
K = kernel (x ,self .xs ,method = self .kernel ,sigma = self .sigma ,** self .kwargs )
0 commit comments