1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8]
4
+
5
+ [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized
6
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
7
+ [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
8
+ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
9
+
10
+
11
+ """
12
+
13
+ import numpy as np
14
+ import scipy .ndimage as spi
15
+ import matplotlib .pylab as pl
16
+ import ot
17
+
18
+
19
+ #%% Loading images
20
+
21
+ I1 = spi .imread ('../data/ocean_day.jpg' ).astype (np .float64 )/ 256
22
+ I2 = spi .imread ('../data/ocean_sunset.jpg' ).astype (np .float64 )/ 256
23
+
24
+ #%% Plot images
25
+
26
+ pl .figure (1 )
27
+
28
+ pl .subplot (1 ,2 ,1 )
29
+ pl .imshow (I1 )
30
+ pl .title ('Image 1' )
31
+
32
+ pl .subplot (1 ,2 ,2 )
33
+ pl .imshow (I2 )
34
+ pl .title ('Image 2' )
35
+
36
+ pl .show ()
37
+
38
+ #%% Image conversion and dataset generation
39
+
40
+ def im2mat (I ):
41
+ """Converts and image to matrix (one pixel per line)"""
42
+ return I .reshape ((I .shape [0 ]* I .shape [1 ],I .shape [2 ]))
43
+
44
+ def mat2im (X ,shape ):
45
+ """Converts back a matrix to an image"""
46
+ return X .reshape (shape )
47
+
48
+ X1 = im2mat (I1 )
49
+ X2 = im2mat (I2 )
50
+
51
+ # training samples
52
+ nb = 1000
53
+ idx1 = np .random .randint (X1 .shape [0 ],size = (nb ,))
54
+ idx2 = np .random .randint (X2 .shape [0 ],size = (nb ,))
55
+
56
+ xs = X1 [idx1 ,:]
57
+ xt = X2 [idx2 ,:]
58
+
59
+ #%% Plot image distributions
60
+
61
+
62
+ pl .figure (2 ,(10 ,5 ))
63
+
64
+ pl .subplot (1 ,2 ,1 )
65
+ pl .scatter (xs [:,0 ],xs [:,2 ],c = xs )
66
+ pl .axis ([0 ,1 ,0 ,1 ])
67
+ pl .xlabel ('Red' )
68
+ pl .ylabel ('Blue' )
69
+ pl .title ('Image 1' )
70
+
71
+ pl .subplot (1 ,2 ,2 )
72
+ #pl.imshow(I2)
73
+ pl .scatter (xt [:,0 ],xt [:,2 ],c = xt )
74
+ pl .axis ([0 ,1 ,0 ,1 ])
75
+ pl .xlabel ('Red' )
76
+ pl .ylabel ('Blue' )
77
+ pl .title ('Image 2' )
78
+
79
+ pl .show ()
80
+
81
+
82
+
83
+ #%% domain adaptation between images
84
+ def minmax (I ):
85
+ return np .minimum (np .maximum (I ,0 ),1 )
86
+ # LP problem
87
+ da_emd = ot .da .OTDA () # init class
88
+ da_emd .fit (xs ,xt ) # fit distributions
89
+
90
+ X1t = da_emd .predict (X1 ) # out of sample
91
+ I1t = minmax (mat2im (X1t ,I1 .shape ))
92
+
93
+ # sinkhorn regularization
94
+ lambd = 1e-1
95
+ da_entrop = ot .da .OTDA_sinkhorn ()
96
+ da_entrop .fit (xs ,xt ,reg = lambd )
97
+
98
+ X1te = da_entrop .predict (X1 )
99
+ I1te = minmax (mat2im (X1te ,I1 .shape ))
100
+
101
+ # linear mapping estimation
102
+ eta = 1e-8 # quadratic regularization for regression
103
+ mu = 1e0 # weight of the OT linear term
104
+ bias = True # estimate a bias
105
+
106
+ ot_mapping = ot .da .OTDA_mapping_linear ()
107
+ ot_mapping .fit (xs ,xt ,mu = mu ,eta = eta ,bias = bias ,numItermax = 20 ,verbose = True )
108
+
109
+ X1tl = ot_mapping .predict (X1 ) # use the estimated mapping
110
+ I1tl = minmax (mat2im (X1tl ,I1 .shape ))
111
+
112
+ # nonlinear mapping estimation
113
+ eta = 1e-2 # quadratic regularization for regression
114
+ mu = 1e0 # weight of the OT linear term
115
+ bias = False # estimate a bias
116
+ sigma = 1 # sigma bandwidth fot gaussian kernel
117
+
118
+
119
+ ot_mapping_kernel = ot .da .OTDA_mapping_kernel ()
120
+ ot_mapping_kernel .fit (xs ,xt ,mu = mu ,eta = eta ,sigma = sigma ,bias = bias ,numItermax = 10 ,verbose = True )
121
+
122
+ X1tn = ot_mapping_kernel .predict (X1 ) # use the estimated mapping
123
+ I1tn = minmax (mat2im (X1tn ,I1 .shape ))
124
+ #%% plot images
125
+
126
+
127
+ pl .figure (2 ,(10 ,8 ))
128
+
129
+ pl .subplot (2 ,3 ,1 )
130
+
131
+ pl .imshow (I1 )
132
+ pl .title ('Im. 1' )
133
+
134
+ pl .subplot (2 ,3 ,2 )
135
+
136
+ pl .imshow (I2 )
137
+ pl .title ('Im. 2' )
138
+
139
+
140
+ pl .subplot (2 ,3 ,3 )
141
+ pl .imshow (I1t )
142
+ pl .title ('Im. 1 Interp LP' )
143
+
144
+ pl .subplot (2 ,3 ,4 )
145
+ pl .imshow (I1te )
146
+ pl .title ('Im. 1 Interp Entrop' )
147
+
148
+
149
+ pl .subplot (2 ,3 ,5 )
150
+ pl .imshow (I1tl )
151
+ pl .title ('Im. 1 Linear mapping' )
152
+
153
+ pl .subplot (2 ,3 ,6 )
154
+ pl .imshow (I1tn )
155
+ pl .title ('Im. 1 nonlinear mapping' )
156
+
157
+ pl .show ()
0 commit comments