Skip to content

Commit c5d7c40

Browse files
committed
check input parameters with helper functions
1 parent a8fa91b commit c5d7c40

File tree

2 files changed

+120
-75
lines changed

2 files changed

+120
-75
lines changed

ot/da.py

Lines changed: 99 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .bregman import sinkhorn
1515
from .lp import emd
1616
from .utils import unif, dist, kernel
17-
from .utils import deprecated, BaseEstimator
17+
from .utils import check_params, deprecated, BaseEstimator
1818
from .optim import cg
1919
from .optim import gcg
2020

@@ -954,6 +954,26 @@ def distribution_estimation_uniform(X):
954954

955955

956956
class BaseTransport(BaseEstimator):
957+
"""Base class for OTDA objects
958+
959+
Notes
960+
-----
961+
All estimators should specify all the parameters that can be set
962+
at the class level in their ``__init__`` as explicit keyword
963+
arguments (no ``*args`` or ``**kwargs``).
964+
965+
fit method should:
966+
- estimate a cost matrix and store it in a `cost_` attribute
967+
- estimate a coupling matrix and store it in a `coupling_`
968+
attribute
969+
- estimate distributions from source and target data and store them in
970+
mu_s and mu_t attributes
971+
- store Xs and Xt in attributes to be used later on in transform and
972+
inverse_transform methods
973+
974+
The transform method should always take an Xs parameter as input
975+
The inverse_transform method should always take an Xt parameter as input
976+
"""
957977

958978
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
959979
"""Build a coupling matrix from source and target sets of samples
@@ -976,7 +996,9 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
976996
Returns self.
977997
"""
978998

979-
if Xs is not None and Xt is not None:
999+
# check that the necessary input parameters are provided
1000+
if check_params(Xs=Xs, Xt=Xt):
1001+
9801002
# pairwise distance
9811003
self.cost_ = dist(Xs, Xt, metric=self.metric)
9821004

@@ -1003,14 +1025,10 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10031025
self.mu_t = self.distribution_estimation(Xt)
10041026

10051027
# store arrays of samples
1006-
self.Xs = Xs
1007-
self.Xt = Xt
1028+
self.xs_ = Xs
1029+
self.xt_ = Xt
10081030

1009-
return self
1010-
else:
1011-
print("POT-Warning")
1012-
print("Please provide both Xs and Xt arguments when calling")
1013-
print("fit method")
1031+
return self
10141032

10151033
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10161034
"""Build a coupling matrix from source and target sets of samples
@@ -1058,16 +1076,19 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10581076
The transport source samples.
10591077
"""
10601078

1061-
if Xs is not None:
1062-
if np.array_equal(self.Xs, Xs):
1079+
# check that the necessary input parameters are provided
1080+
if check_params(Xs=Xs):
1081+
1082+
if np.array_equal(self.xs_, Xs):
1083+
10631084
# perform standard barycentric mapping
10641085
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
10651086

10661087
# set nans to 0
10671088
transp[~ np.isfinite(transp)] = 0
10681089

10691090
# compute transported samples
1070-
transp_Xs = np.dot(transp, self.Xt)
1091+
transp_Xs = np.dot(transp, self.xt_)
10711092
else:
10721093
# perform out of sample mapping
10731094
indices = np.arange(Xs.shape[0])
@@ -1079,26 +1100,23 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10791100
for bi in batch_ind:
10801101

10811102
# get the nearest neighbor in the source domain
1082-
D0 = dist(Xs[bi], self.Xs)
1103+
D0 = dist(Xs[bi], self.xs_)
10831104
idx = np.argmin(D0, axis=1)
10841105

10851106
# transport the source samples
10861107
transp = self.coupling_ / np.sum(
10871108
self.coupling_, 1)[:, None]
10881109
transp[~ np.isfinite(transp)] = 0
1089-
transp_Xs_ = np.dot(transp, self.Xt)
1110+
transp_Xs_ = np.dot(transp, self.xt_)
10901111

10911112
# define the transported points
1092-
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.Xs[idx, :]
1113+
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
10931114

10941115
transp_Xs.append(transp_Xs_)
10951116

10961117
transp_Xs = np.concatenate(transp_Xs, axis=0)
10971118

10981119
return transp_Xs
1099-
else:
1100-
print("POT-Warning")
1101-
print("Please provide Xs argument when calling transform method")
11021120

11031121
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11041122
batch_size=128):
@@ -1123,16 +1141,19 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11231141
The transported target samples.
11241142
"""
11251143

1126-
if Xt is not None:
1127-
if np.array_equal(self.Xt, Xt):
1144+
# check that the necessary input parameters are provided
1145+
if check_params(Xt=Xt):
1146+
1147+
if np.array_equal(self.xt_, Xt):
1148+
11281149
# perform standard barycentric mapping
11291150
transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
11301151

11311152
# set nans to 0
11321153
transp_[~ np.isfinite(transp_)] = 0
11331154

11341155
# compute transported samples
1135-
transp_Xt = np.dot(transp_, self.Xs)
1156+
transp_Xt = np.dot(transp_, self.xs_)
11361157
else:
11371158
# perform out of sample mapping
11381159
indices = np.arange(Xt.shape[0])
@@ -1143,26 +1164,23 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11431164
transp_Xt = []
11441165
for bi in batch_ind:
11451166

1146-
D0 = dist(Xt[bi], self.Xt)
1167+
D0 = dist(Xt[bi], self.xt_)
11471168
idx = np.argmin(D0, axis=1)
11481169

11491170
# transport the target samples
11501171
transp_ = self.coupling_.T / np.sum(
11511172
self.coupling_, 0)[:, None]
11521173
transp_[~ np.isfinite(transp_)] = 0
1153-
transp_Xt_ = np.dot(transp_, self.Xs)
1174+
transp_Xt_ = np.dot(transp_, self.xs_)
11541175

11551176
# define the transported points
1156-
transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.Xt[idx, :]
1177+
transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
11571178

11581179
transp_Xt.append(transp_Xt_)
11591180

11601181
transp_Xt = np.concatenate(transp_Xt, axis=0)
11611182

11621183
return transp_Xt
1163-
else:
1164-
print("POT-Warning")
1165-
print("Please provide Xt argument when calling inverse_transform")
11661184

11671185

11681186
class SinkhornTransport(BaseTransport):
@@ -1428,7 +1446,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14281446
Returns self.
14291447
"""
14301448

1431-
if Xs is not None and Xt is not None and ys is not None:
1449+
# check that the necessary input parameters are provided
1450+
if check_params(Xs=Xs, Xt=Xt, ys=ys):
14321451

14331452
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
14341453

@@ -1438,10 +1457,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14381457
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
14391458
verbose=self.verbose)
14401459

1441-
return self
1442-
else:
1443-
print("POT-Warning")
1444-
print("Please provide both Xs, Xt, ys arguments to fit method")
1460+
return self
14451461

14461462

14471463
class SinkhornL1l2Transport(BaseTransport):
@@ -1537,7 +1553,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15371553
Returns self.
15381554
"""
15391555

1540-
if Xs is not None and Xt is not None and ys is not None:
1556+
# check that the necessary input parameters are provided
1557+
if check_params(Xs=Xs, Xt=Xt, ys=ys):
15411558

15421559
super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
15431560

@@ -1554,10 +1571,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15541571
self.coupling_ = returned_
15551572
self.log_ = dict()
15561573

1557-
return self
1558-
else:
1559-
print("POT-Warning")
1560-
print("Please, provide both Xs, Xt and ys argument to fit method")
1574+
return self
15611575

15621576

15631577
class MappingTransport(BaseEstimator):
@@ -1652,29 +1666,35 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
16521666
Returns self
16531667
"""
16541668

1655-
self.Xs = Xs
1656-
self.Xt = Xt
1657-
1658-
if self.kernel == "linear":
1659-
returned_ = joint_OT_mapping_linear(
1660-
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
1661-
verbose=self.verbose, verbose2=self.verbose2,
1662-
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
1663-
stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)
1669+
# check that the necessary input parameters are provided
1670+
if check_params(Xs=Xs, Xt=Xt):
1671+
1672+
self.xs_ = Xs
1673+
self.xt_ = Xt
1674+
1675+
if self.kernel == "linear":
1676+
returned_ = joint_OT_mapping_linear(
1677+
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
1678+
verbose=self.verbose, verbose2=self.verbose2,
1679+
numItermax=self.max_iter,
1680+
numInnerItermax=self.max_inner_iter, stopThr=self.tol,
1681+
stopInnerThr=self.inner_tol, log=self.log)
1682+
1683+
elif self.kernel == "gaussian":
1684+
returned_ = joint_OT_mapping_kernel(
1685+
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
1686+
sigma=self.sigma, verbose=self.verbose,
1687+
verbose2=self.verbose, numItermax=self.max_iter,
1688+
numInnerItermax=self.max_inner_iter,
1689+
stopInnerThr=self.inner_tol, stopThr=self.tol,
1690+
log=self.log)
16641691

1665-
elif self.kernel == "gaussian":
1666-
returned_ = joint_OT_mapping_kernel(
1667-
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
1668-
sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
1669-
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
1670-
stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)
1671-
1672-
# deal with the value of log
1673-
if self.log:
1674-
self.coupling_, self.mapping_, self.log_ = returned_
1675-
else:
1676-
self.coupling_, self.mapping_ = returned_
1677-
self.log_ = dict()
1692+
# deal with the value of log
1693+
if self.log:
1694+
self.coupling_, self.mapping_, self.log_ = returned_
1695+
else:
1696+
self.coupling_, self.mapping_ = returned_
1697+
self.log_ = dict()
16781698

16791699
return self
16801700

@@ -1692,22 +1712,26 @@ def transform(self, Xs):
16921712
The transport source samples.
16931713
"""
16941714

1695-
if np.array_equal(self.Xs, Xs):
1696-
# perform standard barycentric mapping
1697-
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
1715+
# check that the necessary input parameters are provided
1716+
if check_params(Xs=Xs):
16981717

1699-
# set nans to 0
1700-
transp[~ np.isfinite(transp)] = 0
1718+
if np.array_equal(self.xs_, Xs):
1719+
# perform standard barycentric mapping
1720+
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
17011721

1702-
# compute transported samples
1703-
transp_Xs = np.dot(transp, self.Xt)
1704-
else:
1705-
if self.kernel == "gaussian":
1706-
K = kernel(Xs, self.Xs, method=self.kernel, sigma=self.sigma)
1707-
elif self.kernel == "linear":
1708-
K = Xs
1709-
if self.bias:
1710-
K = np.hstack((K, np.ones((Xs.shape[0], 1))))
1711-
transp_Xs = K.dot(self.mapping_)
1722+
# set nans to 0
1723+
transp[~ np.isfinite(transp)] = 0
17121724

1713-
return transp_Xs
1725+
# compute transported samples
1726+
transp_Xs = np.dot(transp, self.xt_)
1727+
else:
1728+
if self.kernel == "gaussian":
1729+
K = kernel(Xs, self.xs_, method=self.kernel,
1730+
sigma=self.sigma)
1731+
elif self.kernel == "linear":
1732+
K = Xs
1733+
if self.bias:
1734+
K = np.hstack((K, np.ones((Xs.shape[0], 1))))
1735+
transp_Xs = K.dot(self.mapping_)
1736+
1737+
return transp_Xs

ot/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,27 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
168168
return [x for i, x in sorted(res)]
169169

170170

171+
def check_params(**kwargs):
    """Report whether every required parameter was supplied.

    Each keyword argument is treated as a required parameter. A value of
    ``None`` marks the parameter as missing; any other value, including
    falsy ones such as ``0`` or an empty list, counts as provided.

    Parameters
    ----------
    **kwargs
        Parameter names mapped to the values received by the caller.

    Returns
    -------
    check : bool
        True when no value is ``None`` (also when called without any
        argument), False otherwise.

    Notes
    -----
    Missing parameters are reported on standard output instead of being
    raised, so callers have to test the returned flag themselves.
    """

    # identity test on purpose: falsy values other than None are valid inputs
    missing_params = [name for name, value in kwargs.items() if value is None]

    if not missing_params:
        return True

    print("POT - Warning: following necessary parameters are missing")
    for name in missing_params:
        print("\n", name)

    return False
190+
191+
171192
class deprecated(object):
172193
"""Decorator to mark a function or class as deprecated.
173194

0 commit comments

Comments
 (0)