Skip to content

Commit abfe183

Browse files
authored
Merge pull request #100 from ngayraud/add_unbalanced_da
[MRG] Adds Unbalanced transport to domain adaptation methods + bugfixes
2 parents b2157e9 + ce86d14 commit abfe183

File tree

5 files changed

+197
-5
lines changed

5 files changed

+197
-5
lines changed

.travis.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@ matrix:
1313
python: 3.5
1414
- os: linux
1515
sudo: required
16-
python: 3.6
16+
python: 3.6
1717
- os: linux
1818
sudo: required
1919
python: 2.7
2020
before_install:
2121
- ./.travis/before_install.sh
2222
before_script: # configure a headless display to test plot generation
2323
- "export DISPLAY=:99.0"
24-
- "sh -e /etc/init.d/xvfb start"
2524
- sleep 3 # give xvfb some time to start
2625
# command to install dependencies
2726
install:
@@ -30,6 +29,8 @@ install:
3029
- pip install flake8 pytest "pytest-cov<2.6"
3130
- pip install .
3231
# command to run tests + check syntax style
32+
services:
33+
- xvfb
3334
script:
3435
- python setup.py develop
3536
- flake8 examples/ ot/ test/

ot/da.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Author: Remi Flamary <[email protected]>
77
# Nicolas Courty <[email protected]>
88
# Michael Perrot <[email protected]>
9+
# Nathalie Gayraud <[email protected]>
910
#
1011
# License: MIT License
1112

@@ -16,6 +17,7 @@
1617
from .lp import emd
1718
from .utils import unif, dist, kernel, cost_normalization
1819
from .utils import check_params, BaseEstimator
20+
from .unbalanced import sinkhorn_unbalanced
1921
from .optim import cg
2022
from .optim import gcg
2123

@@ -1793,3 +1795,122 @@ def transform(self, Xs):
17931795
transp_Xs = K.dot(self.mapping_)
17941796

17951797
return transp_Xs
1798+
1799+
1800+
class UnbalancedSinkhornTransport(BaseTransport):
1801+
1802+
"""Domain Adaptation unbalanced OT method based on sinkhorn algorithm
1803+
1804+
Parameters
1805+
----------
1806+
reg_e : float, optional (default=1)
1807+
Entropic regularization parameter
1808+
reg_m : float, optional (default=0.1)
1809+
Mass regularization parameter
1810+
method : str
1811+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
1812+
'sinkhorn_epsilon_scaling', see those functions for specific parameters
1813+
max_iter : int, float, optional (default=10)
1814+
The maximum number of iterations before stopping the optimization
1815+
algorithm if it has not converged
1816+
tol : float, optional (default=1e-9)
1817+
Stop threshold on error (inner sinkhorn solver) (>0)
1818+
verbose : bool, optional (default=False)
1819+
Controls the verbosity of the optimization algorithm
1820+
log : bool, optional (default=False)
1821+
Controls the logs of the optimization algorithm
1822+
metric : string, optional (default="sqeuclidean")
1823+
The ground metric for the Wasserstein problem
1824+
norm : string, optional (default=None)
1825+
If given, normalize the ground metric to avoid numerical errors that
1826+
can occur with large metric values.
1827+
distribution_estimation : callable, optional (defaults to the uniform)
1828+
The kind of distribution estimation to employ
1829+
out_of_sample_map : string, optional (default="ferradans")
1830+
The kind of out of sample mapping to apply to transport samples
1831+
from a domain into another one. Currently the only possible option is
1832+
"ferradans" which uses the method proposed in [6].
1833+
limit_max: float, optional (default=10)
1834+
Controls the semi supervised mode. Transport between labeled source
1835+
and target samples of different classes will exhibit an infinite cost
1836+
(10 times the maximum value of the cost matrix)
1837+
1838+
Attributes
1839+
----------
1840+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1841+
The optimal coupling
1842+
log_ : dictionary
1843+
The dictionary of log, empty dict if parameter log is not True
1844+
1845+
References
1846+
----------
1847+
1848+
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
1849+
Scaling algorithms for unbalanced transport problems. arXiv preprint
1850+
arXiv:1607.05816.
1851+
1852+
"""
1853+
1854+
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
1855+
max_iter=10, tol=1e-9, verbose=False, log=False,
1856+
metric="sqeuclidean", norm=None,
1857+
distribution_estimation=distribution_estimation_uniform,
1858+
out_of_sample_map='ferradans', limit_max=10):
1859+
1860+
self.reg_e = reg_e
1861+
self.reg_m = reg_m
1862+
self.method = method
1863+
self.max_iter = max_iter
1864+
self.tol = tol
1865+
self.verbose = verbose
1866+
self.log = log
1867+
self.metric = metric
1868+
self.norm = norm
1869+
self.distribution_estimation = distribution_estimation
1870+
self.out_of_sample_map = out_of_sample_map
1871+
self.limit_max = limit_max
1872+
1873+
def fit(self, Xs, ys=None, Xt=None, yt=None):
1874+
"""Build a coupling matrix from source and target sets of samples
1875+
(Xs, ys) and (Xt, yt)
1876+
1877+
Parameters
1878+
----------
1879+
Xs : array-like, shape (n_source_samples, n_features)
1880+
The training input samples.
1881+
ys : array-like, shape (n_source_samples,)
1882+
The class labels
1883+
Xt : array-like, shape (n_target_samples, n_features)
1884+
The training input samples.
1885+
yt : array-like, shape (n_target_samples,)
1886+
The class labels. If some target samples are unlabeled, fill the
1887+
yt's elements with -1.
1888+
1889+
Warning: Note that, due to this convention -1 cannot be used as a
1890+
class label
1891+
1892+
Returns
1893+
-------
1894+
self : object
1895+
Returns self.
1896+
"""
1897+
1898+
# check that the necessary input parameters are here
1899+
if check_params(Xs=Xs, Xt=Xt):
1900+
1901+
super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
1902+
1903+
returned_ = sinkhorn_unbalanced(
1904+
a=self.mu_s, b=self.mu_t, M=self.cost_,
1905+
reg=self.reg_e, alpha=self.reg_m, method=self.method,
1906+
numItermax=self.max_iter, stopThr=self.tol,
1907+
verbose=self.verbose, log=self.log)
1908+
1909+
# deal with the value of log
1910+
if self.log:
1911+
self.coupling_, self.log_ = returned_
1912+
else:
1913+
self.coupling_ = returned_
1914+
self.log_ = dict()
1915+
1916+
return self

ot/unbalanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
364364
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
365365
# we have reached the machine precision
366366
# come back to previous solution and quit loop
367-
warnings.warn('Numerical errors at iteration', cpt)
367+
warnings.warn('Numerical errors at iteration %s' % cpt)
368368
u = uprev
369369
v = vprev
370370
break

ot/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,20 @@ def cost_normalization(C, norm=None):
178178
The input cost matrix normalized according to given norm.
179179
"""
180180

181-
if norm == "median":
181+
if norm is None:
182+
pass
183+
elif norm == "median":
182184
C /= float(np.median(C))
183185
elif norm == "max":
184186
C /= float(np.max(C))
185187
elif norm == "log":
186188
C = np.log(1 + C)
187189
elif norm == "loglog":
188190
C = np.log1p(np.log1p(C))
189-
191+
else:
192+
raise ValueError('Norm %s is not a valid option.\n'
193+
'Valid options are:\n'
194+
'median, max, log, loglog' % norm)
190195
return C
191196

192197

test/test_da.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,71 @@ def test_sinkhorn_transport_class():
245245
assert len(otda.log_.keys()) != 0
246246

247247

248+
def test_unbalanced_sinkhorn_transport_class():
249+
"""test_unbalanced_sinkhorn_transport
250+
"""
251+
252+
ns = 150
253+
nt = 200
254+
255+
Xs, ys = make_data_classif('3gauss', ns)
256+
Xt, yt = make_data_classif('3gauss2', nt)
257+
258+
otda = ot.da.UnbalancedSinkhornTransport()
259+
260+
# test that it is computed
261+
otda.fit(Xs=Xs, Xt=Xt)
262+
assert hasattr(otda, "cost_")
263+
assert hasattr(otda, "coupling_")
264+
assert hasattr(otda, "log_")
265+
266+
# test dimensions of coupling
267+
assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
268+
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
269+
270+
# test transform
271+
transp_Xs = otda.transform(Xs=Xs)
272+
assert_equal(transp_Xs.shape, Xs.shape)
273+
274+
Xs_new, _ = make_data_classif('3gauss', ns + 1)
275+
transp_Xs_new = otda.transform(Xs_new)
276+
277+
# check that the oos method is working
278+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
279+
280+
# test inverse transform
281+
transp_Xt = otda.inverse_transform(Xt=Xt)
282+
assert_equal(transp_Xt.shape, Xt.shape)
283+
284+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
285+
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
286+
287+
# check that the oos method is working
288+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
289+
290+
# test fit_transform
291+
transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
292+
assert_equal(transp_Xs.shape, Xs.shape)
293+
294+
# test unsupervised vs semi-supervised mode
295+
otda_unsup = ot.da.SinkhornTransport()
296+
otda_unsup.fit(Xs=Xs, Xt=Xt)
297+
n_unsup = np.sum(otda_unsup.cost_)
298+
299+
otda_semi = ot.da.SinkhornTransport()
300+
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
301+
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
302+
n_semisup = np.sum(otda_semi.cost_)
303+
304+
# check that the cost matrix norms are indeed different
305+
assert n_unsup != n_semisup, "semisupervised mode not working"
306+
307+
# check everything runs well with log=True
308+
otda = ot.da.SinkhornTransport(log=True)
309+
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
310+
assert len(otda.log_.keys()) != 0
311+
312+
248313
def test_emd_transport_class():
249314
"""test_sinkhorn_transport
250315
"""

0 commit comments

Comments
 (0)