@@ -650,15 +650,16 @@ class OTDA(object):
650
650
651
651
"""
652
652
653
- def __init__ (self , metric = 'sqeuclidean' ):
653
+ def __init__ (self , metric = 'sqeuclidean' , norm = None ):
654
654
""" Class initialization"""
655
655
self .xs = 0
656
656
self .xt = 0
657
657
self .G = 0
658
658
self .metric = metric
659
+ self .norm = norm
659
660
self .computed = False
660
661
661
- def fit (self , xs , xt , ws = None , wt = None , norm = None , max_iter = 100000 ):
662
+ def fit (self , xs , xt , ws = None , wt = None , max_iter = 100000 ):
662
663
"""Fit domain adaptation between samples is xs and xt
663
664
(with optional weights)"""
664
665
self .xs = xs
@@ -673,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673
674
self .wt = wt
674
675
675
676
self .M = dist (xs , xt , metric = self .metric )
676
- self .M = cost_normalization (self .M , norm )
677
+ self .M = cost_normalization (self .M , self . norm )
677
678
self .G = emd (ws , wt , self .M , max_iter )
678
679
self .computed = True
679
680
@@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
752
753
753
754
"""
754
755
755
- def fit (self , xs , xt , reg = 1 , ws = None , wt = None , norm = None , ** kwargs ):
756
+ def fit (self , xs , xt , reg = 1 , ws = None , wt = None , ** kwargs ):
756
757
"""Fit regularized domain adaptation between samples is xs and xt
757
758
(with optional weights)"""
758
759
self .xs = xs
@@ -767,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
767
768
self .wt = wt
768
769
769
770
self .M = dist (xs , xt , metric = self .metric )
770
- self .M = cost_normalization (self .M , norm )
771
+ self .M = cost_normalization (self .M , self . norm )
771
772
self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
772
773
self .computed = True
773
774
@@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA):
779
780
"""Class for domain adaptation with optimal transport with entropic and
780
781
group regularization"""
781
782
782
- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
783
- ** kwargs ):
783
+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
784
784
"""Fit regularized domain adaptation between samples is xs and xt
785
785
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
786
786
parameters"""
@@ -796,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
796
796
self .wt = wt
797
797
798
798
self .M = dist (xs , xt , metric = self .metric )
799
- self .M = cost_normalization (self .M , norm )
799
+ self .M = cost_normalization (self .M , self . norm )
800
800
self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
801
801
self .computed = True
802
802
@@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA):
808
808
"""Class for domain adaptation with optimal transport with entropic
809
809
and group lasso regularization"""
810
810
811
- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
812
- ** kwargs ):
811
+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
813
812
"""Fit regularized domain adaptation between samples is xs and xt
814
813
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
815
814
parameters"""
@@ -825,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
825
824
self .wt = wt
826
825
827
826
self .M = dist (xs , xt , metric = self .metric )
828
- self .M = cost_normalization (self .M , norm )
827
+ self .M = cost_normalization (self .M , self . norm )
829
828
self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
830
829
self .computed = True
831
830
0 commit comments