Skip to content

Commit 74ca2d7

Browse files
committed
refactoring examples according to new DA classes
1 parent 2d4d0b4 commit 74ca2d7

File tree

5 files changed

+744
-0
lines changed

5 files changed

+744
-0
lines changed

examples/da/plot_otda_classes.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================
4+
OT for domain adaptation
5+
========================
6+
7+
This example introduces a domain adaptation in a 2D setting and the 4 OTDA
8+
approaches currently supported in POT.
9+
10+
"""
11+
12+
# Authors: Remi Flamary <[email protected]>
13+
# Stanilslas Chambon <[email protected]>
14+
#
15+
# License: MIT License
16+
17+
import matplotlib.pylab as pl
18+
import ot
19+
20+
21+
# number of source and target points to generate
22+
ns = 150
23+
nt = 150
24+
25+
Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
26+
Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
27+
28+
# Instantiate the different transport algorithms and fit them
29+
30+
# EMD Transport
31+
ot_emd = ot.da.EMDTransport()
32+
ot_emd.fit(Xs=Xs, Xt=Xt)
33+
34+
# Sinkhorn Transport
35+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
36+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
37+
38+
# Sinkhorn Transport with Group lasso regularization
39+
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
40+
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
41+
42+
# Sinkhorn Transport with Group lasso regularization l1l2
43+
ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
44+
verbose=True)
45+
ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
46+
47+
# transport source samples onto target samples
48+
transp_Xs_emd = ot_emd.transform(Xs=Xs)
49+
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
50+
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
51+
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
52+
53+
##############################################################################
54+
# Fig 1 : plots source and target samples
55+
##############################################################################
56+
57+
pl.figure(1, figsize=(10, 5))
58+
pl.subplot(1, 2, 1)
59+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
60+
pl.xticks([])
61+
pl.yticks([])
62+
pl.legend(loc=0)
63+
pl.title('Source samples')
64+
65+
pl.subplot(1, 2, 2)
66+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
67+
pl.xticks([])
68+
pl.yticks([])
69+
pl.legend(loc=0)
70+
pl.title('Target samples')
71+
pl.tight_layout()
72+
73+
##############################################################################
74+
# Fig 2 : plot optimal couplings and transported samples
75+
##############################################################################
76+
77+
param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}
78+
79+
pl.figure(2, figsize=(15, 8))
80+
pl.subplot(2, 4, 1)
81+
pl.imshow(ot_emd.coupling_, **param_img)
82+
pl.xticks([])
83+
pl.yticks([])
84+
pl.title('Optimal coupling\nEMDTransport')
85+
86+
pl.subplot(2, 4, 2)
87+
pl.imshow(ot_sinkhorn.coupling_, **param_img)
88+
pl.xticks([])
89+
pl.yticks([])
90+
pl.title('Optimal coupling\nSinkhornTransport')
91+
92+
pl.subplot(2, 4, 3)
93+
pl.imshow(ot_lpl1.coupling_, **param_img)
94+
pl.xticks([])
95+
pl.yticks([])
96+
pl.title('Optimal coupling\nSinkhornLpl1Transport')
97+
98+
pl.subplot(2, 4, 4)
99+
pl.imshow(ot_l1l2.coupling_, **param_img)
100+
pl.xticks([])
101+
pl.yticks([])
102+
pl.title('Optimal coupling\nSinkhornL1l2Transport')
103+
104+
pl.subplot(2, 4, 5)
105+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
106+
label='Target samples', alpha=0.3)
107+
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
108+
marker='+', label='Transp samples', s=30)
109+
pl.xticks([])
110+
pl.yticks([])
111+
pl.title('Transported samples\nEmdTransport')
112+
pl.legend(loc="lower left")
113+
114+
pl.subplot(2, 4, 6)
115+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
116+
label='Target samples', alpha=0.3)
117+
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
118+
marker='+', label='Transp samples', s=30)
119+
pl.xticks([])
120+
pl.yticks([])
121+
pl.title('Transported samples\nSinkhornTransport')
122+
123+
pl.subplot(2, 4, 7)
124+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
125+
label='Target samples', alpha=0.3)
126+
pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
127+
marker='+', label='Transp samples', s=30)
128+
pl.xticks([])
129+
pl.yticks([])
130+
pl.title('Transported samples\nSinkhornLpl1Transport')
131+
132+
pl.subplot(2, 4, 8)
133+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
134+
label='Target samples', alpha=0.3)
135+
pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
136+
marker='+', label='Transp samples', s=30)
137+
pl.xticks([])
138+
pl.yticks([])
139+
pl.title('Transported samples\nSinkhornL1l2Transport')
140+
pl.tight_layout()
141+
142+
pl.show()

examples/da/plot_otda_color_images.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================================
4+
OT for domain adaptation with image color adaptation [6]
5+
========================================================
6+
7+
This example presents a way of transferring colors between two image
8+
with Optimal Transport as introduced in [6]
9+
10+
[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
11+
Regularized discrete optimal transport.
12+
SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
13+
"""
14+
15+
# Authors: Remi Flamary <[email protected]>
16+
# Stanilslas Chambon <[email protected]>
17+
#
18+
# License: MIT License
19+
20+
import numpy as np
21+
from scipy import ndimage
22+
import matplotlib.pylab as pl
23+
24+
import ot
25+
26+
27+
def im2mat(I):
28+
"""Converts and image to matrix (one pixel per line)"""
29+
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
30+
31+
32+
def mat2im(X, shape):
33+
"""Converts back a matrix to an image"""
34+
return X.reshape(shape)
35+
36+
37+
def minmax(I):
38+
return np.clip(I, 0, 1)
39+
40+
41+
# Loading images
42+
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
43+
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
44+
45+
X1 = im2mat(I1)
46+
X2 = im2mat(I2)
47+
48+
# training samples
49+
nb = 1000
50+
idx1 = np.random.randint(X1.shape[0], size=(nb,))
51+
idx2 = np.random.randint(X2.shape[0], size=(nb,))
52+
53+
Xs = X1[idx1, :]
54+
Xt = X2[idx2, :]
55+
56+
# EMDTransport
57+
ot_emd = ot.da.EMDTransport()
58+
ot_emd.fit(Xs=Xs, Xt=Xt)
59+
60+
# SinkhornTransport
61+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
62+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
63+
64+
# prediction between images (using out of sample prediction as in [6])
65+
transp_Xs_emd = ot_emd.transform(Xs=X1)
66+
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
67+
68+
transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
69+
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)
70+
71+
I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
72+
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
73+
74+
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
75+
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
76+
77+
##############################################################################
78+
# plot original image
79+
##############################################################################
80+
81+
pl.figure(1, figsize=(6.4, 3))
82+
83+
pl.subplot(1, 2, 1)
84+
pl.imshow(I1)
85+
pl.axis('off')
86+
pl.title('Image 1')
87+
88+
pl.subplot(1, 2, 2)
89+
pl.imshow(I2)
90+
pl.axis('off')
91+
pl.title('Image 2')
92+
93+
##############################################################################
94+
# scatter plot of colors
95+
##############################################################################
96+
97+
pl.figure(2, figsize=(6.4, 3))
98+
99+
pl.subplot(1, 2, 1)
100+
pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
101+
pl.axis([0, 1, 0, 1])
102+
pl.xlabel('Red')
103+
pl.ylabel('Blue')
104+
pl.title('Image 1')
105+
106+
pl.subplot(1, 2, 2)
107+
pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
108+
pl.axis([0, 1, 0, 1])
109+
pl.xlabel('Red')
110+
pl.ylabel('Blue')
111+
pl.title('Image 2')
112+
pl.tight_layout()
113+
114+
##############################################################################
115+
# plot new images
116+
##############################################################################
117+
118+
pl.figure(3, figsize=(8, 4))
119+
120+
pl.subplot(2, 3, 1)
121+
pl.imshow(I1)
122+
pl.axis('off')
123+
pl.title('Image 1')
124+
125+
pl.subplot(2, 3, 2)
126+
pl.imshow(I1t)
127+
pl.axis('off')
128+
pl.title('Image 1 Adapt')
129+
130+
pl.subplot(2, 3, 3)
131+
pl.imshow(I1te)
132+
pl.axis('off')
133+
pl.title('Image 1 Adapt (reg)')
134+
135+
pl.subplot(2, 3, 4)
136+
pl.imshow(I2)
137+
pl.axis('off')
138+
pl.title('Image 2')
139+
140+
pl.subplot(2, 3, 5)
141+
pl.imshow(I2t)
142+
pl.axis('off')
143+
pl.title('Image 2 Adapt')
144+
145+
pl.subplot(2, 3, 6)
146+
pl.imshow(I2te)
147+
pl.axis('off')
148+
pl.title('Image 2 Adapt (reg)')
149+
pl.tight_layout()
150+
151+
pl.show()

0 commit comments

Comments
 (0)