Skip to content

Commit a2ec6e5

Browse files
authored
Merge pull request #22 from Slasnista/domain_adaptation
Fixes #17
2 parents 7638d01 + 65de6fc commit a2ec6e5

14 files changed

+2353
-730
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ The contributors to this library are:
136136
* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/)
137137
* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
138138
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
139+
* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1)
140+
* [Stanislas Chambon](https://slasnista.github.io/)
139141

140142
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
141143

examples/da/plot_otda_classes.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================
4+
OT for domain adaptation
5+
========================
6+
7+
This example introduces a domain adaptation in a 2D setting and the 4 OTDA
8+
approaches currently supported in POT.
9+
10+
"""
11+
12+
# Authors: Remi Flamary <[email protected]>
13+
# Stanislas Chambon <[email protected]>
14+
#
15+
# License: MIT License
16+
17+
import matplotlib.pylab as pl
18+
import ot
19+
20+
21+
##############################################################################
22+
# generate data
23+
##############################################################################
24+
25+
n_source_samples = 150
26+
n_target_samples = 150
27+
28+
Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
29+
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)
30+
31+
32+
##############################################################################
33+
# Instantiate the different transport algorithms and fit them
34+
##############################################################################
35+
36+
# EMD Transport
37+
ot_emd = ot.da.EMDTransport()
38+
ot_emd.fit(Xs=Xs, Xt=Xt)
39+
40+
# Sinkhorn Transport
41+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
42+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
43+
44+
# Sinkhorn Transport with Group lasso regularization
45+
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
46+
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
47+
48+
# Sinkhorn Transport with Group lasso regularization l1l2
49+
ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
50+
verbose=True)
51+
ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
52+
53+
# transport source samples onto target samples
54+
transp_Xs_emd = ot_emd.transform(Xs=Xs)
55+
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
56+
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
57+
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
58+
59+
60+
##############################################################################
61+
# Fig 1 : plots source and target samples
62+
##############################################################################
63+
64+
pl.figure(1, figsize=(10, 5))
65+
pl.subplot(1, 2, 1)
66+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
67+
pl.xticks([])
68+
pl.yticks([])
69+
pl.legend(loc=0)
70+
pl.title('Source samples')
71+
72+
pl.subplot(1, 2, 2)
73+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
74+
pl.xticks([])
75+
pl.yticks([])
76+
pl.legend(loc=0)
77+
pl.title('Target samples')
78+
pl.tight_layout()
79+
80+
81+
##############################################################################
82+
# Fig 2 : plot optimal couplings and transported samples
83+
##############################################################################
84+
85+
param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}
86+
87+
pl.figure(2, figsize=(15, 8))
88+
pl.subplot(2, 4, 1)
89+
pl.imshow(ot_emd.coupling_, **param_img)
90+
pl.xticks([])
91+
pl.yticks([])
92+
pl.title('Optimal coupling\nEMDTransport')
93+
94+
pl.subplot(2, 4, 2)
95+
pl.imshow(ot_sinkhorn.coupling_, **param_img)
96+
pl.xticks([])
97+
pl.yticks([])
98+
pl.title('Optimal coupling\nSinkhornTransport')
99+
100+
pl.subplot(2, 4, 3)
101+
pl.imshow(ot_lpl1.coupling_, **param_img)
102+
pl.xticks([])
103+
pl.yticks([])
104+
pl.title('Optimal coupling\nSinkhornLpl1Transport')
105+
106+
pl.subplot(2, 4, 4)
107+
pl.imshow(ot_l1l2.coupling_, **param_img)
108+
pl.xticks([])
109+
pl.yticks([])
110+
pl.title('Optimal coupling\nSinkhornL1l2Transport')
111+
112+
pl.subplot(2, 4, 5)
113+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
114+
label='Target samples', alpha=0.3)
115+
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
116+
marker='+', label='Transp samples', s=30)
117+
pl.xticks([])
118+
pl.yticks([])
119+
pl.title('Transported samples\nEmdTransport')
120+
pl.legend(loc="lower left")
121+
122+
pl.subplot(2, 4, 6)
123+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
124+
label='Target samples', alpha=0.3)
125+
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
126+
marker='+', label='Transp samples', s=30)
127+
pl.xticks([])
128+
pl.yticks([])
129+
pl.title('Transported samples\nSinkhornTransport')
130+
131+
pl.subplot(2, 4, 7)
132+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
133+
label='Target samples', alpha=0.3)
134+
pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
135+
marker='+', label='Transp samples', s=30)
136+
pl.xticks([])
137+
pl.yticks([])
138+
pl.title('Transported samples\nSinkhornLpl1Transport')
139+
140+
pl.subplot(2, 4, 8)
141+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
142+
label='Target samples', alpha=0.3)
143+
pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
144+
marker='+', label='Transp samples', s=30)
145+
pl.xticks([])
146+
pl.yticks([])
147+
pl.title('Transported samples\nSinkhornL1l2Transport')
148+
pl.tight_layout()
149+
150+
pl.show()

examples/da/plot_otda_color_images.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================================
4+
OT for domain adaptation with image color adaptation [6]
5+
========================================================
6+
7+
This example presents a way of transferring colors between two image
8+
with Optimal Transport as introduced in [6]
9+
10+
[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
11+
Regularized discrete optimal transport.
12+
SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
13+
"""
14+
15+
# Authors: Remi Flamary <[email protected]>
16+
# Stanislas Chambon <[email protected]>
17+
#
18+
# License: MIT License
19+
20+
import numpy as np
21+
from scipy import ndimage
22+
import matplotlib.pylab as pl
23+
import ot
24+
25+
26+
r = np.random.RandomState(42)
27+
28+
29+
def im2mat(I):
30+
"""Converts and image to matrix (one pixel per line)"""
31+
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
32+
33+
34+
def mat2im(X, shape):
35+
"""Converts back a matrix to an image"""
36+
return X.reshape(shape)
37+
38+
39+
def minmax(I):
40+
return np.clip(I, 0, 1)
41+
42+
43+
##############################################################################
44+
# generate data
45+
##############################################################################
46+
47+
# Loading images
48+
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
49+
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
50+
51+
X1 = im2mat(I1)
52+
X2 = im2mat(I2)
53+
54+
# training samples
55+
nb = 1000
56+
idx1 = r.randint(X1.shape[0], size=(nb,))
57+
idx2 = r.randint(X2.shape[0], size=(nb,))
58+
59+
Xs = X1[idx1, :]
60+
Xt = X2[idx2, :]
61+
62+
63+
##############################################################################
64+
# Instantiate the different transport algorithms and fit them
65+
##############################################################################
66+
67+
# EMDTransport
68+
ot_emd = ot.da.EMDTransport()
69+
ot_emd.fit(Xs=Xs, Xt=Xt)
70+
71+
# SinkhornTransport
72+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
73+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
74+
75+
# prediction between images (using out of sample prediction as in [6])
76+
transp_Xs_emd = ot_emd.transform(Xs=X1)
77+
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
78+
79+
transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
80+
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)
81+
82+
I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
83+
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
84+
85+
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
86+
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
87+
88+
89+
##############################################################################
90+
# plot original image
91+
##############################################################################
92+
93+
pl.figure(1, figsize=(6.4, 3))
94+
95+
pl.subplot(1, 2, 1)
96+
pl.imshow(I1)
97+
pl.axis('off')
98+
pl.title('Image 1')
99+
100+
pl.subplot(1, 2, 2)
101+
pl.imshow(I2)
102+
pl.axis('off')
103+
pl.title('Image 2')
104+
105+
106+
##############################################################################
107+
# scatter plot of colors
108+
##############################################################################
109+
110+
pl.figure(2, figsize=(6.4, 3))
111+
112+
pl.subplot(1, 2, 1)
113+
pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
114+
pl.axis([0, 1, 0, 1])
115+
pl.xlabel('Red')
116+
pl.ylabel('Blue')
117+
pl.title('Image 1')
118+
119+
pl.subplot(1, 2, 2)
120+
pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
121+
pl.axis([0, 1, 0, 1])
122+
pl.xlabel('Red')
123+
pl.ylabel('Blue')
124+
pl.title('Image 2')
125+
pl.tight_layout()
126+
127+
128+
##############################################################################
129+
# plot new images
130+
##############################################################################
131+
132+
pl.figure(3, figsize=(8, 4))
133+
134+
pl.subplot(2, 3, 1)
135+
pl.imshow(I1)
136+
pl.axis('off')
137+
pl.title('Image 1')
138+
139+
pl.subplot(2, 3, 2)
140+
pl.imshow(I1t)
141+
pl.axis('off')
142+
pl.title('Image 1 Adapt')
143+
144+
pl.subplot(2, 3, 3)
145+
pl.imshow(I1te)
146+
pl.axis('off')
147+
pl.title('Image 1 Adapt (reg)')
148+
149+
pl.subplot(2, 3, 4)
150+
pl.imshow(I2)
151+
pl.axis('off')
152+
pl.title('Image 2')
153+
154+
pl.subplot(2, 3, 5)
155+
pl.imshow(I2t)
156+
pl.axis('off')
157+
pl.title('Image 2 Adapt')
158+
159+
pl.subplot(2, 3, 6)
160+
pl.imshow(I2te)
161+
pl.axis('off')
162+
pl.title('Image 2 Adapt (reg)')
163+
pl.tight_layout()
164+
165+
pl.show()

0 commit comments

Comments
 (0)