|
6 | 6 | # Author: Remi Flamary <[email protected]>
|
7 | 7 | # Nicolas Courty <[email protected]>
|
8 | 8 | # Michael Perrot <[email protected]>
|
| 9 | +# Nathalie Gayraud <[email protected]> |
9 | 10 | #
|
10 | 11 | # License: MIT License
|
11 | 12 |
|
|
16 | 17 | from .lp import emd
|
17 | 18 | from .utils import unif, dist, kernel, cost_normalization
|
18 | 19 | from .utils import check_params, BaseEstimator
|
| 20 | +from .unbalanced import sinkhorn_unbalanced |
19 | 21 | from .optim import cg
|
20 | 22 | from .optim import gcg
|
21 | 23 |
|
@@ -1793,3 +1795,122 @@ def transform(self, Xs):
|
1793 | 1795 | transp_Xs = K.dot(self.mapping_)
|
1794 | 1796 |
|
1795 | 1797 | return transp_Xs
|
| 1798 | + |
| 1799 | + |
class UnbalancedSinkhornTransport(BaseTransport):

    """Domain Adaptation unbalanced OT method based on the sinkhorn algorithm

    The coupling is estimated by calling ``sinkhorn_unbalanced`` on the
    source/target distributions and the cost matrix prepared by the parent
    class ``fit``.

    Parameters
    ----------
    reg_e : float, optional (default=1)
        Entropic regularization parameter (forwarded to the solver as
        ``reg``)
    reg_m : float, optional (default=0.1)
        Mass regularization parameter (forwarded to the solver as ``alpha``)
    method : str, optional (default='sinkhorn')
        Method used for the solver, either 'sinkhorn', 'sinkhorn_stabilized'
        or 'sinkhorn_epsilon_scaling'. The value is forwarded unchanged to
        ``sinkhorn_unbalanced``; see that function for the methods it
        actually supports and their specific parameters.
    max_iter : int, optional (default=10)
        The maximum number of iterations before stopping the optimization
        algorithm if it has not converged
    tol : float, optional (default=1e-9)
        Stop threshold on error (inner sinkhorn solver) (>0)
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional (default=False)
        Controls the logs of the optimization algorithm
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    norm : string, optional (default=None)
        If given, normalize the ground metric to avoid numerical errors that
        can occur with large metric values.
    distribution_estimation : callable, optional (defaults to the uniform)
        The kind of distribution estimation to employ
    out_of_sample_map : string, optional (default="ferradans")
        The kind of out of sample mapping to apply to transport samples
        from a domain into another one. Currently the only possible option
        is "ferradans" which uses the method proposed in [6].
    limit_max : float, optional (default=10)
        Controls the semi supervised mode. Transport between labeled source
        and target samples of different classes will exhibit an infinite
        cost (10 times the maximum value of the cost matrix with the
        default value)

    Attributes
    ----------
    coupling_ : array-like, shape (n_source_samples, n_target_samples)
        The optimal coupling
    log_ : dictionary
        The dictionary of log, empty dict if parameter log is not True

    References
    ----------

    .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
           Scaling algorithms for unbalanced transport problems. arXiv
           preprint arXiv:1607.05816.

    .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
           Regularized discrete optimal transport. SIAM Journal on Imaging
           Sciences, 7(3), 1853-1882.

    """

    def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
                 max_iter=10, tol=1e-9, verbose=False, log=False,
                 metric="sqeuclidean", norm=None,
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans', limit_max=10):

        # regularization strengths
        self.reg_e = reg_e
        self.reg_m = reg_m

        # solver selection and stopping criteria
        self.method = method
        self.max_iter = max_iter
        self.tol = tol

        # solver reporting
        self.verbose = verbose
        self.log = log

        # ground cost construction
        self.metric = metric
        self.norm = norm
        self.limit_max = limit_max

        # marginals estimation and out-of-sample behaviour
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Estimate the unbalanced coupling between source and target
        samples (Xs, ys) and (Xt, yt)

        Parameters
        ----------
        Xs : array-like, shape (n_source_samples, n_features)
            The training input samples.
        ys : array-like, shape (n_source_samples,)
            The class labels
        Xt : array-like, shape (n_target_samples, n_features)
            The training input samples.
        yt : array-like, shape (n_target_samples,)
            The class labels. If some target samples are unlabeled, fill the
            yt's elements with -1.

            Warning: Note that, due to this convention -1 cannot be used as a
            class label

        Returns
        -------
        self : object
            Returns self.
        """

        # nothing can be fitted when a mandatory input is missing; the
        # estimator is handed back untouched in that case
        if not check_params(Xs=Xs, Xt=Xt):
            return self

        # the parent fit is relied upon to set mu_s, mu_t and cost_, which
        # are read right below
        super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)

        solver_output = sinkhorn_unbalanced(
            a=self.mu_s,
            b=self.mu_t,
            M=self.cost_,
            reg=self.reg_e,
            alpha=self.reg_m,
            method=self.method,
            numItermax=self.max_iter,
            stopThr=self.tol,
            verbose=self.verbose,
            log=self.log)

        # with log enabled the solver hands back a (coupling, log) pair,
        # otherwise only the coupling
        if not self.log:
            self.coupling_ = solver_output
            self.log_ = {}
        else:
            self.coupling_, self.log_ = solver_output

        return self
0 commit comments