14
14
from .bregman import sinkhorn
15
15
from .lp import emd
16
16
from .utils import unif , dist , kernel
17
- from .utils import deprecated , BaseEstimator
17
+ from .utils import check_params , deprecated , BaseEstimator
18
18
from .optim import cg
19
19
from .optim import gcg
20
20
@@ -954,6 +954,26 @@ def distribution_estimation_uniform(X):
954
954
955
955
956
956
class BaseTransport (BaseEstimator ):
957
+ """Base class for OTDA objects
958
+
959
+ Notes
960
+ -----
961
+ All estimators should specify all the parameters that can be set
962
+ at the class level in their ``__init__`` as explicit keyword
963
+ arguments (no ``*args`` or ``**kwargs``).
964
+
965
+ fit method should:
966
+ - estimate a cost matrix and store it in a `cost_` attribute
967
+ - estimate a coupling matrix and store it in a `coupling_`
968
+ attribute
969
+ - estimate distributions from source and target data and store them in
970
+ mu_s and mu_t attributes
971
+ - store Xs and Xt in attributes to be used later on in transform and
972
+ inverse_transform methods
973
+
974
+ transform method should always get as input an Xs parameter
975
+ inverse_transform method should always get as input an Xt parameter
976
+ """
957
977
958
978
def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
959
979
"""Build a coupling matrix from source and target sets of samples
@@ -976,7 +996,9 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
976
996
Returns self.
977
997
"""
978
998
979
- if Xs is not None and Xt is not None :
999
+ # check that the necessary input parameters (Xs, Xt) are provided
1000
+ if check_params (Xs = Xs , Xt = Xt ):
1001
+
980
1002
# pairwise distance
981
1003
self .cost_ = dist (Xs , Xt , metric = self .metric )
982
1004
@@ -1003,14 +1025,10 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1003
1025
self .mu_t = self .distribution_estimation (Xt )
1004
1026
1005
1027
# store arrays of samples
1006
- self .Xs = Xs
1007
- self .Xt = Xt
1028
+ self .xs_ = Xs
1029
+ self .xt_ = Xt
1008
1030
1009
- return self
1010
- else :
1011
- print ("POT-Warning" )
1012
- print ("Please provide both Xs and Xt arguments when calling" )
1013
- print ("fit method" )
1031
+ return self
1014
1032
1015
1033
def fit_transform (self , Xs = None , ys = None , Xt = None , yt = None ):
1016
1034
"""Build a coupling matrix from source and target sets of samples
@@ -1058,16 +1076,19 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
1058
1076
The transport source samples.
1059
1077
"""
1060
1078
1061
- if Xs is not None :
1062
- if np .array_equal (self .Xs , Xs ):
1079
+ # check that the necessary input parameter (Xs) is provided
1080
+ if check_params (Xs = Xs ):
1081
+
1082
+ if np .array_equal (self .xs_ , Xs ):
1083
+
1063
1084
# perform standard barycentric mapping
1064
1085
transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
1065
1086
1066
1087
# set nans to 0
1067
1088
transp [~ np .isfinite (transp )] = 0
1068
1089
1069
1090
# compute transported samples
1070
- transp_Xs = np .dot (transp , self .Xt )
1091
+ transp_Xs = np .dot (transp , self .xt_ )
1071
1092
else :
1072
1093
# perform out of sample mapping
1073
1094
indices = np .arange (Xs .shape [0 ])
@@ -1079,26 +1100,23 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
1079
1100
for bi in batch_ind :
1080
1101
1081
1102
# get the nearest neighbor in the source domain
1082
- D0 = dist (Xs [bi ], self .Xs )
1103
+ D0 = dist (Xs [bi ], self .xs_ )
1083
1104
idx = np .argmin (D0 , axis = 1 )
1084
1105
1085
1106
# transport the source samples
1086
1107
transp = self .coupling_ / np .sum (
1087
1108
self .coupling_ , 1 )[:, None ]
1088
1109
transp [~ np .isfinite (transp )] = 0
1089
- transp_Xs_ = np .dot (transp , self .Xt )
1110
+ transp_Xs_ = np .dot (transp , self .xt_ )
1090
1111
1091
1112
# define the transported points
1092
- transp_Xs_ = transp_Xs_ [idx , :] + Xs [bi ] - self .Xs [idx , :]
1113
+ transp_Xs_ = transp_Xs_ [idx , :] + Xs [bi ] - self .xs_ [idx , :]
1093
1114
1094
1115
transp_Xs .append (transp_Xs_ )
1095
1116
1096
1117
transp_Xs = np .concatenate (transp_Xs , axis = 0 )
1097
1118
1098
1119
return transp_Xs
1099
- else :
1100
- print ("POT-Warning" )
1101
- print ("Please provide Xs argument when calling transform method" )
1102
1120
1103
1121
def inverse_transform (self , Xs = None , ys = None , Xt = None , yt = None ,
1104
1122
batch_size = 128 ):
@@ -1123,16 +1141,19 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
1123
1141
The transported target samples.
1124
1142
"""
1125
1143
1126
- if Xt is not None :
1127
- if np .array_equal (self .Xt , Xt ):
1144
+ # check that the necessary input parameter (Xt) is provided
1145
+ if check_params (Xt = Xt ):
1146
+
1147
+ if np .array_equal (self .xt_ , Xt ):
1148
+
1128
1149
# perform standard barycentric mapping
1129
1150
transp_ = self .coupling_ .T / np .sum (self .coupling_ , 0 )[:, None ]
1130
1151
1131
1152
# set nans to 0
1132
1153
transp_ [~ np .isfinite (transp_ )] = 0
1133
1154
1134
1155
# compute transported samples
1135
- transp_Xt = np .dot (transp_ , self .Xs )
1156
+ transp_Xt = np .dot (transp_ , self .xs_ )
1136
1157
else :
1137
1158
# perform out of sample mapping
1138
1159
indices = np .arange (Xt .shape [0 ])
@@ -1143,26 +1164,23 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
1143
1164
transp_Xt = []
1144
1165
for bi in batch_ind :
1145
1166
1146
- D0 = dist (Xt [bi ], self .Xt )
1167
+ D0 = dist (Xt [bi ], self .xt_ )
1147
1168
idx = np .argmin (D0 , axis = 1 )
1148
1169
1149
1170
# transport the target samples
1150
1171
transp_ = self .coupling_ .T / np .sum (
1151
1172
self .coupling_ , 0 )[:, None ]
1152
1173
transp_ [~ np .isfinite (transp_ )] = 0
1153
- transp_Xt_ = np .dot (transp_ , self .Xs )
1174
+ transp_Xt_ = np .dot (transp_ , self .xs_ )
1154
1175
1155
1176
# define the transported points
1156
- transp_Xt_ = transp_Xt_ [idx , :] + Xt [bi ] - self .Xt [idx , :]
1177
+ transp_Xt_ = transp_Xt_ [idx , :] + Xt [bi ] - self .xt_ [idx , :]
1157
1178
1158
1179
transp_Xt .append (transp_Xt_ )
1159
1180
1160
1181
transp_Xt = np .concatenate (transp_Xt , axis = 0 )
1161
1182
1162
1183
return transp_Xt
1163
- else :
1164
- print ("POT-Warning" )
1165
- print ("Please provide Xt argument when calling inverse_transform" )
1166
1184
1167
1185
1168
1186
class SinkhornTransport (BaseTransport ):
@@ -1428,7 +1446,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1428
1446
Returns self.
1429
1447
"""
1430
1448
1431
- if Xs is not None and Xt is not None and ys is not None :
1449
+ # check that the necessary input parameters (Xs, Xt, ys) are provided
1450
+ if check_params (Xs = Xs , Xt = Xt , ys = ys ):
1432
1451
1433
1452
super (SinkhornLpl1Transport , self ).fit (Xs , ys , Xt , yt )
1434
1453
@@ -1438,10 +1457,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1438
1457
numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1439
1458
verbose = self .verbose )
1440
1459
1441
- return self
1442
- else :
1443
- print ("POT-Warning" )
1444
- print ("Please provide both Xs, Xt, ys arguments to fit method" )
1460
+ return self
1445
1461
1446
1462
1447
1463
class SinkhornL1l2Transport (BaseTransport ):
@@ -1537,7 +1553,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1537
1553
Returns self.
1538
1554
"""
1539
1555
1540
- if Xs is not None and Xt is not None and ys is not None :
1556
+ # check that the necessary input parameters (Xs, Xt, ys) are provided
1557
+ if check_params (Xs = Xs , Xt = Xt , ys = ys ):
1541
1558
1542
1559
super (SinkhornL1l2Transport , self ).fit (Xs , ys , Xt , yt )
1543
1560
@@ -1554,10 +1571,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1554
1571
self .coupling_ = returned_
1555
1572
self .log_ = dict ()
1556
1573
1557
- return self
1558
- else :
1559
- print ("POT-Warning" )
1560
- print ("Please, provide both Xs, Xt and ys argument to fit method" )
1574
+ return self
1561
1575
1562
1576
1563
1577
class MappingTransport (BaseEstimator ):
@@ -1652,29 +1666,35 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1652
1666
Returns self
1653
1667
"""
1654
1668
1655
- self .Xs = Xs
1656
- self .Xt = Xt
1657
-
1658
- if self .kernel == "linear" :
1659
- returned_ = joint_OT_mapping_linear (
1660
- Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1661
- verbose = self .verbose , verbose2 = self .verbose2 ,
1662
- numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1663
- stopThr = self .tol , stopInnerThr = self .inner_tol , log = self .log )
1669
+ # check that the necessary input parameters (Xs, Xt) are provided
1670
+ if check_params (Xs = Xs , Xt = Xt ):
1671
+
1672
+ self .xs_ = Xs
1673
+ self .xt_ = Xt
1674
+
1675
+ if self .kernel == "linear" :
1676
+ returned_ = joint_OT_mapping_linear (
1677
+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1678
+ verbose = self .verbose , verbose2 = self .verbose2 ,
1679
+ numItermax = self .max_iter ,
1680
+ numInnerItermax = self .max_inner_iter , stopThr = self .tol ,
1681
+ stopInnerThr = self .inner_tol , log = self .log )
1682
+
1683
+ elif self .kernel == "gaussian" :
1684
+ returned_ = joint_OT_mapping_kernel (
1685
+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1686
+ sigma = self .sigma , verbose = self .verbose ,
1687
+ verbose2 = self .verbose , numItermax = self .max_iter ,
1688
+ numInnerItermax = self .max_inner_iter ,
1689
+ stopInnerThr = self .inner_tol , stopThr = self .tol ,
1690
+ log = self .log )
1664
1691
1665
- elif self .kernel == "gaussian" :
1666
- returned_ = joint_OT_mapping_kernel (
1667
- Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1668
- sigma = self .sigma , verbose = self .verbose , verbose2 = self .verbose ,
1669
- numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1670
- stopInnerThr = self .inner_tol , stopThr = self .tol , log = self .log )
1671
-
1672
- # deal with the value of log
1673
- if self .log :
1674
- self .coupling_ , self .mapping_ , self .log_ = returned_
1675
- else :
1676
- self .coupling_ , self .mapping_ = returned_
1677
- self .log_ = dict ()
1692
+ # deal with the value of log
1693
+ if self .log :
1694
+ self .coupling_ , self .mapping_ , self .log_ = returned_
1695
+ else :
1696
+ self .coupling_ , self .mapping_ = returned_
1697
+ self .log_ = dict ()
1678
1698
1679
1699
return self
1680
1700
@@ -1692,22 +1712,26 @@ def transform(self, Xs):
1692
1712
The transport source samples.
1693
1713
"""
1694
1714
1695
- if np .array_equal (self .Xs , Xs ):
1696
- # perform standard barycentric mapping
1697
- transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
1715
+ # check that the necessary input parameter (Xs) is provided
1716
+ if check_params (Xs = Xs ):
1698
1717
1699
- # set nans to 0
1700
- transp [~ np .isfinite (transp )] = 0
1718
+ if np .array_equal (self .xs_ , Xs ):
1719
+ # perform standard barycentric mapping
1720
+ transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
1701
1721
1702
- # compute transported samples
1703
- transp_Xs = np .dot (transp , self .Xt )
1704
- else :
1705
- if self .kernel == "gaussian" :
1706
- K = kernel (Xs , self .Xs , method = self .kernel , sigma = self .sigma )
1707
- elif self .kernel == "linear" :
1708
- K = Xs
1709
- if self .bias :
1710
- K = np .hstack ((K , np .ones ((Xs .shape [0 ], 1 ))))
1711
- transp_Xs = K .dot (self .mapping_ )
1722
+ # set nans to 0
1723
+ transp [~ np .isfinite (transp )] = 0
1712
1724
1713
- return transp_Xs
1725
+ # compute transported samples
1726
+ transp_Xs = np .dot (transp , self .xt_ )
1727
+ else :
1728
+ if self .kernel == "gaussian" :
1729
+ K = kernel (Xs , self .xs_ , method = self .kernel ,
1730
+ sigma = self .sigma )
1731
+ elif self .kernel == "linear" :
1732
+ K = Xs
1733
+ if self .bias :
1734
+ K = np .hstack ((K , np .ones ((Xs .shape [0 ], 1 ))))
1735
+ transp_Xs = K .dot (self .mapping_ )
1736
+
1737
+ return transp_Xs
0 commit comments