Skip to content

Commit 94bc743

Browse files
committed
new example for mapping
1 parent a65ff1b commit 94bc743

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8]
4+
5+
[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized
6+
discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
7+
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
8+
discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
9+
10+
11+
"""
12+
13+
import numpy as np
14+
import scipy.ndimage as spi
15+
import matplotlib.pylab as pl
16+
import ot
17+
18+
19+
#%% Loading images
20+
21+
I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
22+
I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
23+
24+
#%% Plot images
25+
26+
pl.figure(1)
27+
28+
pl.subplot(1,2,1)
29+
pl.imshow(I1)
30+
pl.title('Image 1')
31+
32+
pl.subplot(1,2,2)
33+
pl.imshow(I2)
34+
pl.title('Image 2')
35+
36+
pl.show()
37+
38+
#%% Image conversion and dataset generation
39+
40+
def im2mat(I):
41+
"""Converts and image to matrix (one pixel per line)"""
42+
return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
43+
44+
def mat2im(X,shape):
45+
"""Converts back a matrix to an image"""
46+
return X.reshape(shape)
47+
48+
X1=im2mat(I1)
49+
X2=im2mat(I2)
50+
51+
# training samples
52+
nb=1000
53+
idx1=np.random.randint(X1.shape[0],size=(nb,))
54+
idx2=np.random.randint(X2.shape[0],size=(nb,))
55+
56+
xs=X1[idx1,:]
57+
xt=X2[idx2,:]
58+
59+
#%% Plot image distributions
60+
61+
62+
pl.figure(2,(10,5))
63+
64+
pl.subplot(1,2,1)
65+
pl.scatter(xs[:,0],xs[:,2],c=xs)
66+
pl.axis([0,1,0,1])
67+
pl.xlabel('Red')
68+
pl.ylabel('Blue')
69+
pl.title('Image 1')
70+
71+
pl.subplot(1,2,2)
72+
#pl.imshow(I2)
73+
pl.scatter(xt[:,0],xt[:,2],c=xt)
74+
pl.axis([0,1,0,1])
75+
pl.xlabel('Red')
76+
pl.ylabel('Blue')
77+
pl.title('Image 2')
78+
79+
pl.show()
80+
81+
82+
83+
#%% domain adaptation between images
84+
def minmax(I):
85+
return np.minimum(np.maximum(I,0),1)
86+
# LP problem
87+
da_emd=ot.da.OTDA() # init class
88+
da_emd.fit(xs,xt) # fit distributions
89+
90+
X1t=da_emd.predict(X1) # out of sample
91+
I1t=minmax(mat2im(X1t,I1.shape))
92+
93+
# sinkhorn regularization
94+
lambd=1e-1
95+
da_entrop=ot.da.OTDA_sinkhorn()
96+
da_entrop.fit(xs,xt,reg=lambd)
97+
98+
X1te=da_entrop.predict(X1)
99+
I1te=minmax(mat2im(X1te,I1.shape))
100+
101+
# linear mapping estimation
102+
eta=1e-8 # quadratic regularization for regression
103+
mu=1e0 # weight of the OT linear term
104+
bias=True # estimate a bias
105+
106+
ot_mapping=ot.da.OTDA_mapping_linear()
107+
ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
108+
109+
X1tl=ot_mapping.predict(X1) # use the estimated mapping
110+
I1tl=minmax(mat2im(X1tl,I1.shape))
111+
112+
# nonlinear mapping estimation
113+
eta=1e-2 # quadratic regularization for regression
114+
mu=1e0 # weight of the OT linear term
115+
bias=False # estimate a bias
116+
sigma=1 # sigma bandwidth fot gaussian kernel
117+
118+
119+
ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
120+
ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
121+
122+
X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping
123+
I1tn=minmax(mat2im(X1tn,I1.shape))
124+
#%% plot images
125+
126+
127+
pl.figure(2,(10,8))
128+
129+
pl.subplot(2,3,1)
130+
131+
pl.imshow(I1)
132+
pl.title('Im. 1')
133+
134+
pl.subplot(2,3,2)
135+
136+
pl.imshow(I2)
137+
pl.title('Im. 2')
138+
139+
140+
pl.subplot(2,3,3)
141+
pl.imshow(I1t)
142+
pl.title('Im. 1 Interp LP')
143+
144+
pl.subplot(2,3,4)
145+
pl.imshow(I1te)
146+
pl.title('Im. 1 Interp Entrop')
147+
148+
149+
pl.subplot(2,3,5)
150+
pl.imshow(I1tl)
151+
pl.title('Im. 1 Linear mapping')
152+
153+
pl.subplot(2,3,6)
154+
pl.imshow(I1tn)
155+
pl.title('Im. 1 nonlinear mapping')
156+
157+
pl.show()

0 commit comments

Comments
 (0)