
Commit a9427a1
Merge branch 'master' into autonb
2 parents: 2324b1f + e70d542

19 files changed: +1520 −357 lines

README.md

Lines changed: 4 additions & 2 deletions
@@ -16,7 +16,7 @@ It provides the following solvers:
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
+* Gromov-Wasserstein distances and barycenters [12]

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

@@ -138,12 +138,12 @@ The contributors to this library are:
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1)
* [Stanislas Chambon](https://slasnista.github.io/)
+* [Antoine Rolet](https://arolet.github.io/)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
-* [Antoine Rolet](https://arolet.github.io/) ( Mex file for EMD )
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)

@@ -184,3 +184,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.

+[12] Peyré, G., Cuturi, M., & Solomon, J. (2016). [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html). International Conference on Machine Learning (ICML).

data/cross.png (230 Bytes)

data/square.png (168 Bytes)

data/star.png (225 Bytes)

data/triangle.png (254 Bytes)
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
============================================
OTDA unsupervised vs semi-supervised setting
============================================

This example introduces semi-supervised domain adaptation in a 2D setting.
It states the problem of semi-supervised domain adaptation and introduces
some optimal transport approaches to solve it.

Quantities such as optimal couplings, the largest coupling coefficients and
transported samples are represented in order to give a visual understanding
of what the transport methods are doing.
"""

# Authors: Remi Flamary <[email protected]>
#          Stanislas Chambon <[email protected]>
#
# License: MIT License
import matplotlib.pylab as pl
import ot


##############################################################################
# generate data
##############################################################################

n_samples_source = 150
n_samples_target = 150

Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)


##############################################################################
# Transport source samples onto target samples
##############################################################################
# unsupervised domain adaptation
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

# semi-supervised domain adaptation
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)

# Semi-supervised DA uses the available labeled target samples to modify the
# cost matrix involved in the OT problem: the cost of transporting a source
# sample of class A onto a target sample of class B != A is set to infinity,
# or to a very large value.

# Note that in the present case we consider that all the target samples are
# labeled. In real applications, some target samples might not have labels;
# in that case the elements of yt corresponding to these samples should be
# set to -1.

# Warning: -1 therefore cannot be used as a regular class label.
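
# A minimal sketch (not part of the original example) of the masking rule
# described above, assuming a squared Euclidean ground cost; the names M and
# big_cost are hypothetical and the masked matrix is not reused below.
M = ot.dist(Xs, Xt)                   # pairwise source/target costs
big_cost = 1e10                       # stand-in for an "infinite" cost
labeled_target = (yt != -1)           # only labeled target samples are used
cross_class = (ys[:, None] != yt[None, :]) & labeled_target[None, :]
M[cross_class] = big_cost             # forbid cross-class transport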
##############################################################################
# Fig 1 : plot source and target samples + cost matrices
##############################################################################

pl.figure(1, figsize=(10, 10))
pl.subplot(2, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Source samples')

pl.subplot(2, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Target samples')

pl.subplot(2, 2, 3)
pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Cost matrix - unsupervised DA')

pl.subplot(2, 2, 4)
pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Cost matrix - semi-supervised DA')

pl.tight_layout()

# The optimal coupling in the semi-supervised DA case will exhibit a shape
# similar to the cost matrix (a block-diagonal structure).
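
# Illustrative check (added, not in the original example): with labeled target
# samples the semi-supervised coupling puts (almost) no mass on cross-class
# pairs, which is what produces the block-diagonal look mentioned above.
print('cross-class transported mass:',
      ot_sinkhorn_semi.coupling_[ys[:, None] != yt[None, :]].sum())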

##############################################################################
# Fig 2 : plots optimal couplings for the different methods
##############################################################################

pl.figure(2, figsize=(8, 4))

pl.subplot(1, 2, 1)
pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nUnsupervised DA')

pl.subplot(1, 2, 2)
pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSemi-supervised DA')

pl.tight_layout()

##############################################################################
# Fig 3 : plot transported samples
##############################################################################

# display transported samples
pl.figure(3, figsize=(8, 4))
pl.subplot(1, 2, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nUnsupervised DA')
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nSemi-supervised DA')
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()

examples/plot_gromov.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""
==========================
Gromov-Wasserstein example
==========================
This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
"""

# Author: Erwan Vautier <[email protected]>
#         Nicolas Courty <[email protected]>
#
# License: MIT License

import scipy as sp
import scipy.linalg  # makes sp.linalg.sqrtm available explicitly
import scipy.spatial  # makes sp.spatial.distance.cdist available explicitly
import numpy as np
import matplotlib.pylab as pl
# The '3d' projection used below requires this import on older matplotlib versions.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

import ot

"""
23+
Sample two Gaussian distributions (2D and 3D)
24+
=============================================
25+
The Gromov-Wasserstein distance allows to compute distances with samples that
26+
do not belong to the same metric space. For demonstration purpose, we sample
27+
two Gaussian distributions in 2- and 3-dimensional spaces.
28+
"""
29+
30+
n_samples = 30  # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])


xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t

"""
45+
Plotting the distributions
46+
==========================
47+
"""
48+
fig = pl.figure()
49+
ax1 = fig.add_subplot(121)
50+
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
51+
ax2 = fig.add_subplot(122, projection='3d')
52+
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
53+
pl.show()
54+
55+
56+
"""
57+
Compute distance kernels, normalize them and then display
58+
=========================================================
59+
"""
60+
61+
C1 = sp.spatial.distance.cdist(xs, xs)
62+
C2 = sp.spatial.distance.cdist(xt, xt)
63+
64+
C1 /= C1.max()
65+
C2 /= C2.max()
66+
67+
pl.figure()
68+
pl.subplot(121)
69+
pl.imshow(C1)
70+
pl.subplot(122)
71+
pl.imshow(C2)
72+
pl.show()
73+
74+
"""
75+
Compute Gromov-Wasserstein plans and distance
76+
=============================================
77+
"""
78+
79+
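
# A brief reminder (added note; the notation below is ours, not taken from
# POT's documentation): given intra-domain distance matrices C1, C2 and
# weights p, q, the entropic Gromov-Wasserstein solver searches for a
# coupling T that (approximately) minimizes
#     sum_{i,j,k,l} L(C1[i, k], C2[j, l]) * T[i, j] * T[k, l] - epsilon * H(T)
# subject to T having marginals p and q, where L is the 'square_loss' chosen
# below and H is the entropy of T.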

p = ot.unif(n_samples)
q = ot.unif(n_samples)

# gromov_wasserstein returns the optimal coupling, gromov_wasserstein2 the
# corresponding Gromov-Wasserstein distance
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

print('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))

pl.figure()
pl.imshow(gw, cmap='jet')
pl.colorbar()
pl.show()
