Skip to content

Commit 43b2190

Browse files
authored
Merge pull request #140 from ievred/laplace_da
[MRG] Laplace regularized OTDA
2 parents f10f323 + a4426fd commit 43b2190

File tree

5 files changed

+473
-22
lines changed

5 files changed

+473
-22
lines changed

README.md

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
[![PyPI version](https://badge.fury.io/py/POT.svg)](https://badge.fury.io/py/POT)
44
[![Anaconda Cloud](https://anaconda.org/conda-forge/pot/badges/version.svg)](https://anaconda.org/conda-forge/pot)
5-
[![Build Status](https://travis-ci.org/rflamary/POT.svg?branch=master)](https://travis-ci.org/rflamary/POT)
5+
[![Build Status](https://travis-ci.org/rflamary/POT.svg?branch=master)](https://travis-ci.org/PythonOT/POT)
66
[![Documentation Status](https://readthedocs.org/projects/pot/badge/?version=latest)](http://pot.readthedocs.io/en/latest/?badge=latest)
77
[![Downloads](https://pepy.tech/badge/pot)](https://pepy.tech/project/pot)
88
[![Anaconda downloads](https://anaconda.org/conda-forge/pot/badges/downloads.svg)](https://anaconda.org/conda-forge/pot)
9-
[![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/rflamary/POT/blob/master/LICENSE)
9+
[![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/PythonOT/POT/blob/master/LICENSE)
1010

1111

1212

@@ -20,7 +20,7 @@ It provides the following solvers:
2020
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
2121
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
2222
* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
23-
* Optimal transport for domain adaptation with group lasso regularization [5]
23+
* Optimal transport for domain adaptation with group lasso regularization and Laplacian regularization [5][30]
2424
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
2525
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
2626
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
@@ -140,26 +140,26 @@ ba=ot.barycenter(A,M,reg) # reg is regularization parameter
140140
The examples folder contain several examples and use case for the library. The full documentation is available on [Readthedocs](http://pot.readthedocs.io/).
141141

142142

143-
Here is a list of the Python notebooks available [here](https://github.com/rflamary/POT/blob/master/notebooks/) if you want a quick look:
143+
Here is a list of the Python notebooks available [here](https://github.com/PythonOT/POT/blob/master/notebooks/) if you want a quick look:
144144

145-
* [1D optimal transport](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_1D.ipynb)
146-
* [OT Ground Loss](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb)
147-
* [Multiple EMD computation](https://github.com/rflamary/POT/blob/master/notebooks/plot_compute_emd.ipynb)
148-
* [2D optimal transport on empirical distributions](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb)
149-
* [1D Wasserstein barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_1D.ipynb)
150-
* [OT with user provided regularization](https://github.com/rflamary/POT/blob/master/notebooks/plot_optim_OTreg.ipynb)
151-
* [Domain adaptation with optimal transport](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_d2.ipynb)
152-
* [Color transfer in images](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_color_images.ipynb)
153-
* [OT mapping estimation for domain adaptation](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping.ipynb)
154-
* [OT mapping estimation for color transfer in images](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb)
155-
* [Wasserstein Discriminant Analysis](https://github.com/rflamary/POT/blob/master/notebooks/plot_WDA.ipynb)
156-
* [Gromov Wasserstein](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb)
157-
* [Gromov Wasserstein Barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov_barycenter.ipynb)
158-
* [Fused Gromov Wasserstein](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb)
159-
* [Fused Gromov Wasserstein Barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb)
145+
* [1D optimal transport](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_1D.ipynb)
146+
* [OT Ground Loss](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb)
147+
* [Multiple EMD computation](https://github.com/PythonOT/POT/blob/master/notebooks/plot_compute_emd.ipynb)
148+
* [2D optimal transport on empirical distributions](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb)
149+
* [1D Wasserstein barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_barycenter_1D.ipynb)
150+
* [OT with user provided regularization](https://github.com/PythonOT/POT/blob/master/notebooks/plot_optim_OTreg.ipynb)
151+
* [Domain adaptation with optimal transport](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_d2.ipynb)
152+
* [Color transfer in images](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_color_images.ipynb)
153+
* [OT mapping estimation for domain adaptation](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_mapping.ipynb)
154+
* [OT mapping estimation for color transfer in images](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb)
155+
* [Wasserstein Discriminant Analysis](https://github.com/PythonOT/POT/blob/master/notebooks/plot_WDA.ipynb)
156+
* [Gromov Wasserstein](https://github.com/PythonOT/POT/blob/master/notebooks/plot_gromov.ipynb)
157+
* [Gromov Wasserstein Barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_gromov_barycenter.ipynb)
158+
* [Fused Gromov Wasserstein](https://github.com/PythonOT/POT/blob/master/notebooks/plot_fgw.ipynb)
159+
* [Fused Gromov Wasserstein Barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb)
160160

161161

162-
You can also see the notebooks with [Jupyter nbviewer](https://nbviewer.jupyter.org/github/rflamary/POT/tree/master/notebooks/).
162+
You can also see the notebooks with [Jupyter nbviewer](https://nbviewer.jupyter.org/github/PythonOT/POT/tree/master/notebooks/).
163163

164164
## Acknowledgements
165165

@@ -184,6 +184,7 @@ The contributors to this library are
184184
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
185185
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
186186
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
187+
* [Ievgen Redko](https://ievred.github.io/)
187188

188189
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
189190

@@ -264,4 +265,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
264265

265266
[28] Caffarelli, L. A., McCann, R. J. (2020). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730.
266267

267-
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
268+
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
269+
270+
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

examples/plot_otda_laplacian.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
======================================================
4+
OT with Laplacian regularization for domain adaptation
5+
======================================================
6+
7+
This example introduces a domain adaptation in a 2D setting and OTDA
8+
approach with Laplacian regularization.
9+
10+
"""
11+
12+
# Authors: Ievgen Redko <[email protected]>
13+
14+
# License: MIT License
15+
16+
import matplotlib.pylab as pl
17+
import ot
18+
19+
##############################################################################
20+
# Generate data
21+
# -------------
22+
23+
n_source_samples = 150
24+
n_target_samples = 150
25+
26+
Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
27+
Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
28+
29+
30+
##############################################################################
31+
# Instantiate the different transport algorithms and fit them
32+
# -----------------------------------------------------------
33+
34+
# EMD Transport
35+
ot_emd = ot.da.EMDTransport()
36+
ot_emd.fit(Xs=Xs, Xt=Xt)
37+
38+
# Sinkhorn Transport
39+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01)
40+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
41+
42+
# EMD Transport with Laplacian regularization
43+
ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)
44+
ot_emd_laplace.fit(Xs=Xs, Xt=Xt)
45+
46+
# transport source samples onto target samples
47+
transp_Xs_emd = ot_emd.transform(Xs=Xs)
48+
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
49+
transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs)
50+
51+
##############################################################################
52+
# Fig 1 : plots source and target samples
53+
# ---------------------------------------
54+
55+
pl.figure(1, figsize=(10, 5))
56+
pl.subplot(1, 2, 1)
57+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
58+
pl.xticks([])
59+
pl.yticks([])
60+
pl.legend(loc=0)
61+
pl.title('Source samples')
62+
63+
pl.subplot(1, 2, 2)
64+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
65+
pl.xticks([])
66+
pl.yticks([])
67+
pl.legend(loc=0)
68+
pl.title('Target samples')
69+
pl.tight_layout()
70+
71+
72+
##############################################################################
73+
# Fig 2 : plot optimal couplings and transported samples
74+
# ------------------------------------------------------
75+
76+
param_img = {'interpolation': 'nearest'}
77+
78+
pl.figure(2, figsize=(15, 8))
79+
pl.subplot(2, 3, 1)
80+
pl.imshow(ot_emd.coupling_, **param_img)
81+
pl.xticks([])
82+
pl.yticks([])
83+
pl.title('Optimal coupling\nEMDTransport')
84+
85+
pl.figure(2, figsize=(15, 8))
86+
pl.subplot(2, 3, 2)
87+
pl.imshow(ot_sinkhorn.coupling_, **param_img)
88+
pl.xticks([])
89+
pl.yticks([])
90+
pl.title('Optimal coupling\nSinkhornTransport')
91+
92+
pl.subplot(2, 3, 3)
93+
pl.imshow(ot_emd_laplace.coupling_, **param_img)
94+
pl.xticks([])
95+
pl.yticks([])
96+
pl.title('Optimal coupling\nEMDLaplaceTransport')
97+
98+
pl.subplot(2, 3, 4)
99+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
100+
label='Target samples', alpha=0.3)
101+
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
102+
marker='+', label='Transp samples', s=30)
103+
pl.xticks([])
104+
pl.yticks([])
105+
pl.title('Transported samples\nEmdTransport')
106+
pl.legend(loc="lower left")
107+
108+
pl.subplot(2, 3, 5)
109+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
110+
label='Target samples', alpha=0.3)
111+
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
112+
marker='+', label='Transp samples', s=30)
113+
pl.xticks([])
114+
pl.yticks([])
115+
pl.title('Transported samples\nSinkhornTransport')
116+
117+
pl.subplot(2, 3, 6)
118+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
119+
label='Target samples', alpha=0.3)
120+
pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
121+
marker='+', label='Transp samples', s=30)
122+
pl.xticks([])
123+
pl.yticks([])
124+
pl.title('Transported samples\nEMDLaplaceTransport')
125+
pl.tight_layout()
126+
127+
pl.show()

0 commit comments

Comments
 (0)