Skip to content

Commit 5992b14

Browse files
committed
add demo mapping
1 parent 405f352 commit 5992b14

File tree

3 files changed

+117
-2
lines changed

3 files changed

+117
-2
lines changed

docs/source/examples.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,8 @@ Color transfer in images
3232
------------------------
3333

3434
.. literalinclude:: ../../examples/demo_OTDA_color_images.py
35+
36+
OT mapping estimation for domain adaptation
37+
-------------------------------------------
38+
39+
.. literalinclude:: ../../examples/demo_OTDA_mapping.py

examples/demo_OTDA_mapping.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Demo of OT mapping estimation for somain adaptation
4+
"""
5+
6+
import numpy as np
7+
import matplotlib.pylab as pl
8+
import ot
9+
10+
11+
12+
#%% dataset generation
13+
14+
np.random.seed(0)
15+
16+
n=100 # nb samples in source and target datasets
17+
theta=2*np.pi/20
18+
nz=0.1
19+
xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)
20+
xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)
21+
22+
# one of the target mode changes its variance (no linear mapping)
23+
xt[yt==2]*=3
24+
xt=xt+4
25+
26+
27+
#%% plot samples
28+
29+
pl.figure(1,(8,5))
30+
pl.clf()
31+
32+
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
33+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
34+
35+
pl.legend(loc=0)
36+
pl.title('Source and target distributions')
37+
38+
39+
40+
#%% OT linear mapping estimation
41+
42+
eta=1e-8 # quadratic regularization for regression
43+
mu=1e0 # weight of the OT linear term
44+
bias=True # estimate a bias
45+
46+
ot_mapping=ot.da.OTDA_mapping_linear()
47+
ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
48+
49+
xst=ot_mapping.predict(xs) # use the estimated mapping
50+
xst0=ot_mapping.interp() # use barycentric mapping
51+
52+
53+
pl.figure(2,(10,7))
54+
pl.clf()
55+
pl.subplot(2,2,1)
56+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
57+
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')
58+
pl.title("barycentric mapping")
59+
60+
pl.subplot(2,2,2)
61+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
62+
pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
63+
pl.title("Learned mapping")
64+
65+
66+
67+
#%% Kernel mapping estimation
68+
69+
eta=1e-5 # quadratic regularization for regression
70+
mu=1e-1 # weight of the OT linear term
71+
bias=True # estimate a bias
72+
sigma=1 # sigma bandwidth fot gaussian kernel
73+
74+
75+
ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
76+
ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
77+
78+
xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping
79+
xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping
80+
81+
82+
#%% Plotting the mapped samples
83+
84+
pl.figure(2,(10,7))
85+
pl.clf()
86+
pl.subplot(2,2,1)
87+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
88+
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')
89+
pl.title("Bary. mapping (linear)")
90+
pl.legend(loc=0)
91+
92+
pl.subplot(2,2,2)
93+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
94+
pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
95+
pl.title("Estim. mapping (linear)")
96+
97+
pl.subplot(2,2,3)
98+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
99+
pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')
100+
pl.title("Bary. mapping (kernel)")
101+
102+
pl.subplot(2,2,4)
103+
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
104+
pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')
105+
pl.title("Estim. mapping (kernel)")
106+
107+
108+
109+
110+

ot/da.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def df(G):
203203
if it>=numItermax:
204204
loop=0
205205

206-
if abs(vloss[-1]-vloss[-2])<stopThr:
206+
if abs(vloss[-1]-vloss[-2])/abs(vloss[-2])<stopThr:
207207
loop=0
208208

209209
if verbose:
@@ -323,7 +323,7 @@ def df(G):
323323
if it>=numItermax:
324324
loop=0
325325

326-
if abs(vloss[-1]-vloss[-2])<stopThr:
326+
if abs(vloss[-1]-vloss[-2])/abs(vloss[-2])<stopThr:
327327
loop=0
328328

329329
if verbose:

0 commit comments

Comments
 (0)