13
13
14
14
from .bregman import sinkhorn
15
15
from .lp import emd
16
- from .utils import unif , dist , kernel
16
+ from .utils import unif , dist , kernel , cost_normalization
17
17
from .utils import check_params , deprecated , BaseEstimator
18
18
from .optim import cg
19
19
from .optim import gcg
@@ -673,7 +673,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673
673
self .wt = wt
674
674
675
675
self .M = dist (xs , xt , metric = self .metric )
676
- self .normalizeM ( norm )
676
+ self .M = cost_normalization ( self . M , norm )
677
677
self .G = emd (ws , wt , self .M , max_iter )
678
678
self .computed = True
679
679
@@ -741,26 +741,6 @@ def predict(self, x, direction=1):
741
741
# aply the delta to the interpolation
742
742
return xf [idx , :] + x - x0 [idx , :]
743
743
744
- def normalizeM (self , norm ):
745
- """ Apply normalization to the loss matrix
746
-
747
-
748
- Parameters
749
- ----------
750
- norm : str
751
- type of normalization from 'median','max','log','loglog'
752
-
753
- """
754
-
755
- if norm == "median" :
756
- self .M /= float (np .median (self .M ))
757
- elif norm == "max" :
758
- self .M /= float (np .max (self .M ))
759
- elif norm == "log" :
760
- self .M = np .log (1 + self .M )
761
- elif norm == "loglog" :
762
- self .M = np .log (1 + np .log (1 + self .M ))
763
-
764
744
765
745
@deprecated ("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766
746
" removed in 0.5 \n Use class SinkhornTransport instead." )
@@ -787,7 +767,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787
767
self .wt = wt
788
768
789
769
self .M = dist (xs , xt , metric = self .metric )
790
- self .normalizeM ( norm )
770
+ self .M = cost_normalization ( self . M , norm )
791
771
self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
792
772
self .computed = True
793
773
@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816
796
self .wt = wt
817
797
818
798
self .M = dist (xs , xt , metric = self .metric )
819
- self .normalizeM ( norm )
799
+ self .M = cost_normalization ( self . M , norm )
820
800
self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
821
801
self .computed = True
822
802
@@ -845,7 +825,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845
825
self .wt = wt
846
826
847
827
self .M = dist (xs , xt , metric = self .metric )
848
- self .normalizeM ( norm )
828
+ self .M = cost_normalization ( self . M , norm )
849
829
self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
850
830
self .computed = True
851
831
@@ -1001,7 +981,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001
981
1002
982
# pairwise distance
1003
983
self .cost_ = dist (Xs , Xt , metric = self .metric )
1004
- self .normalizeCost_ ( self .norm )
984
+ self .cost_ = cost_normalization ( self . cost_ , self .norm )
1005
985
1006
986
if (ys is not None ) and (yt is not None ):
1007
987
@@ -1183,26 +1163,6 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
1183
1163
1184
1164
return transp_Xt
1185
1165
1186
- def normalizeCost_ (self , norm ):
1187
- """ Apply normalization to the loss matrix
1188
-
1189
-
1190
- Parameters
1191
- ----------
1192
- norm : str
1193
- type of normalization from 'median','max','log','loglog'
1194
-
1195
- """
1196
-
1197
- if norm == "median" :
1198
- self .cost_ /= float (np .median (self .cost_ ))
1199
- elif norm == "max" :
1200
- self .cost_ /= float (np .max (self .cost_ ))
1201
- elif norm == "log" :
1202
- self .cost_ = np .log (1 + self .cost_ )
1203
- elif norm == "loglog" :
1204
- self .cost_ = np .log (1 + np .log (1 + self .cost_ ))
1205
-
1206
1166
1207
1167
class SinkhornTransport (BaseTransport ):
1208
1168
"""Domain Adapatation OT method based on Sinkhorn Algorithm
0 commit comments