Skip to content

Commit fadaf2a

Browse files
committed
Move norm out of fit to init for deprecated OTDA
1 parent 0316d55 commit fadaf2a

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

ot/da.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -650,15 +650,16 @@ class OTDA(object):
650650
651651
"""
652652

653-
def __init__(self, metric='sqeuclidean'):
653+
def __init__(self, metric='sqeuclidean', norm=None):
654654
""" Class initialization"""
655655
self.xs = 0
656656
self.xt = 0
657657
self.G = 0
658658
self.metric = metric
659+
self.norm = norm
659660
self.computed = False
660661

661-
def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
662+
def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
662663
"""Fit domain adaptation between samples is xs and xt
663664
(with optional weights)"""
664665
self.xs = xs
@@ -673,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673674
self.wt = wt
674675

675676
self.M = dist(xs, xt, metric=self.metric)
676-
self.M = cost_normalization(self.M, norm)
677+
self.M = cost_normalization(self.M, self.norm)
677678
self.G = emd(ws, wt, self.M, max_iter)
678679
self.computed = True
679680

@@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
752753
753754
"""
754755

755-
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
756+
def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
756757
"""Fit regularized domain adaptation between samples is xs and xt
757758
(with optional weights)"""
758759
self.xs = xs
@@ -767,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
767768
self.wt = wt
768769

769770
self.M = dist(xs, xt, metric=self.metric)
770-
self.M = cost_normalization(self.M, norm)
771+
self.M = cost_normalization(self.M, self.norm)
771772
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
772773
self.computed = True
773774

@@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA):
779780
"""Class for domain adaptation with optimal transport with entropic and
780781
group regularization"""
781782

782-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
783-
**kwargs):
783+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
784784
"""Fit regularized domain adaptation between samples is xs and xt
785785
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
786786
parameters"""
@@ -796,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
796796
self.wt = wt
797797

798798
self.M = dist(xs, xt, metric=self.metric)
799-
self.M = cost_normalization(self.M, norm)
799+
self.M = cost_normalization(self.M, self.norm)
800800
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
801801
self.computed = True
802802

@@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA):
808808
"""Class for domain adaptation with optimal transport with entropic
809809
and group lasso regularization"""
810810

811-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
812-
**kwargs):
811+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
813812
"""Fit regularized domain adaptation between samples is xs and xt
814813
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
815814
parameters"""
@@ -825,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
825824
self.wt = wt
826825

827826
self.M = dist(xs, xt, metric=self.metric)
828-
self.M = cost_normalization(self.M, norm)
827+
self.M = cost_normalization(self.M, self.norm)
829828
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
830829
self.computed = True
831830

0 commit comments

Comments
 (0)