Skip to content

Commit 0928668

Browse files
committed
Added Unbalaced transport to domain adaptation methods. Corrected small bug related to warnings in unbalaced.py . Added an error message when user wants to normalize with other than expected cost normalization functions.
1 parent b2157e9 commit 0928668

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

ot/da.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Author: Remi Flamary <[email protected]>
77
# Nicolas Courty <[email protected]>
88
# Michael Perrot <[email protected]>
9+
# Nathalie Gayraud <[email protected]>
910
#
1011
# License: MIT License
1112

@@ -16,6 +17,7 @@
1617
from .lp import emd
1718
from .utils import unif, dist, kernel, cost_normalization
1819
from .utils import check_params, BaseEstimator
20+
from .unbalanced import sinkhorn_unbalanced
1921
from .optim import cg
2022
from .optim import gcg
2123

@@ -1793,3 +1795,122 @@ def transform(self, Xs):
17931795
transp_Xs = K.dot(self.mapping_)
17941796

17951797
return transp_Xs
1798+
1799+
1800+
class UnbalancedSinkhornTransport(BaseTransport):
1801+
1802+
"""Domain Adapatation unbalanced OT method based on sinkhorn algorithm
1803+
1804+
Parameters
1805+
----------
1806+
reg_e : float, optional (default=1)
1807+
Entropic regularization parameter
1808+
reg_m : float, optional (default=0.1)
1809+
Mass regularization parameter
1810+
method : str
1811+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
1812+
'sinkhorn_epsilon_scaling', see those function for specific parameters
1813+
max_iter : int, float, optional (default=10)
1814+
The minimum number of iteration before stopping the optimization
1815+
algorithm if no it has not converged
1816+
tol : float, optional (default=10e-9)
1817+
Stop threshold on error (inner sinkhorn solver) (>0)
1818+
verbose : bool, optional (default=False)
1819+
Controls the verbosity of the optimization algorithm
1820+
log : bool, optional (default=False)
1821+
Controls the logs of the optimization algorithm
1822+
metric : string, optional (default="sqeuclidean")
1823+
The ground metric for the Wasserstein problem
1824+
norm : string, optional (default=None)
1825+
If given, normalize the ground metric to avoid numerical errors that
1826+
can occur with large metric values.
1827+
distribution_estimation : callable, optional (defaults to the uniform)
1828+
The kind of distribution estimation to employ
1829+
out_of_sample_map : string, optional (default="ferradans")
1830+
The kind of out of sample mapping to apply to transport samples
1831+
from a domain into another one. Currently the only possible option is
1832+
"ferradans" which uses the method proposed in [6].
1833+
limit_max: float, optional (default=10)
1834+
Controls the semi supervised mode. Transport between labeled source
1835+
and target samples of different classes will exhibit an infinite cost
1836+
(10 times the maximum value of the cost matrix)
1837+
1838+
Attributes
1839+
----------
1840+
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1841+
The optimal coupling
1842+
log_ : dictionary
1843+
The dictionary of log, empty dic if parameter log is not True
1844+
1845+
References
1846+
----------
1847+
1848+
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
1849+
Scaling algorithms for unbalanced transport problems. arXiv preprint
1850+
arXiv:1607.05816.
1851+
1852+
"""
1853+
1854+
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
1855+
max_iter=10, tol=10e-9, verbose=False, log=False,
1856+
metric="sqeuclidean", norm=None,
1857+
distribution_estimation=distribution_estimation_uniform,
1858+
out_of_sample_map='ferradans', limit_max=10):
1859+
1860+
self.reg_e = reg_e
1861+
self.reg_m = reg_m
1862+
self.method = method
1863+
self.max_iter = max_iter
1864+
self.tol = tol
1865+
self.verbose = verbose
1866+
self.log = log
1867+
self.metric = metric
1868+
self.norm = norm
1869+
self.distribution_estimation = distribution_estimation
1870+
self.out_of_sample_map = out_of_sample_map
1871+
self.limit_max = limit_max
1872+
1873+
def fit(self, Xs, ys=None, Xt=None, yt=None):
1874+
"""Build a coupling matrix from source and target sets of samples
1875+
(Xs, ys) and (Xt, yt)
1876+
1877+
Parameters
1878+
----------
1879+
Xs : array-like, shape (n_source_samples, n_features)
1880+
The training input samples.
1881+
ys : array-like, shape (n_source_samples,)
1882+
The class labels
1883+
Xt : array-like, shape (n_target_samples, n_features)
1884+
The training input samples.
1885+
yt : array-like, shape (n_target_samples,)
1886+
The class labels. If some target samples are unlabeled, fill the
1887+
yt's elements with -1.
1888+
1889+
Warning: Note that, due to this convention -1 cannot be used as a
1890+
class label
1891+
1892+
Returns
1893+
-------
1894+
self : object
1895+
Returns self.
1896+
"""
1897+
1898+
# check the necessary inputs parameters are here
1899+
if check_params(Xs=Xs, Xt=Xt):
1900+
1901+
super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
1902+
1903+
returned_ = sinkhorn_unbalanced(
1904+
a=self.mu_s, b=self.mu_t, M=self.cost_,
1905+
reg=self.reg_e, alpha=self.reg_m, method=self.method,
1906+
numItermax=self.max_iter, stopThr=self.tol,
1907+
verbose=self.verbose, log=self.log)
1908+
1909+
# deal with the value of log
1910+
if self.log:
1911+
self.coupling_, self.log_ = returned_
1912+
else:
1913+
self.coupling_ = returned_
1914+
self.log_ = dict()
1915+
1916+
return self

ot/unbalanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
364364
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
365365
# we have reached the machine precision
366366
# come back to previous solution and quit loop
367-
warnings.warn('Numerical errors at iteration', cpt)
367+
warnings.warn('Numerical errors at iteration %s' % cpt)
368368
u = uprev
369369
v = vprev
370370
break

ot/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,10 @@ def cost_normalization(C, norm=None):
186186
C = np.log(1 + C)
187187
elif norm == "loglog":
188188
C = np.log1p(np.log1p(C))
189-
189+
else:
190+
raise ValueError(f'Norm {norm} is not a valid option. '
191+
f'Valid options are:\n'
192+
f'median, max, log, loglog')
190193
return C
191194

192195

0 commit comments

Comments
 (0)