Skip to content

Commit 20547ee

Browse files
committed
update wda example
1 parent 05765e2 commit 20547ee

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

examples/plot_WDA.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,51 +18,79 @@
1818

1919
n=1000 # nb samples in source and target datasets
2020
nz=0.2
21-
xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
22-
xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
21+
22+
# generate circle dataset
23+
t=np.random.rand(n)*2*np.pi
24+
ys=np.floor((np.arange(n)*1.0/n*3))+1
25+
xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
26+
xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2)
27+
28+
t=np.random.rand(n)*2*np.pi
29+
yt=np.floor((np.arange(n)*1.0/n*3))+1
30+
xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
31+
xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2)
2332

2433
nbnoise=8
2534

2635
xs=np.hstack((xs,np.random.randn(n,nbnoise)))
2736
xt=np.hstack((xt,np.random.randn(n,nbnoise)))
2837

2938
#%% plot samples
39+
pl.figure(1,(10,5))
3040

31-
pl.figure(1)
32-
33-
41+
pl.subplot(1,2,1)
3442
pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
3543
pl.legend(loc=0)
3644
pl.title('Discriminant dimensions')
3745

46+
pl.subplot(1,2,2)
47+
pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples')
48+
pl.legend(loc=0)
49+
pl.title('Other dimensions')
50+
pl.show()
3851

39-
#%% Comlpute FDA
52+
#%% Compute FDA
4053
p=2
4154

4255
Pfda,projfda = fda(xs,ys,p)
4356

4457
#%% Compute WDA
4558
p=2
46-
reg=1
59+
reg=1e-1
4760
k=10
4861
maxiter=100
4962

50-
P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
63+
Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter)
5164

5265
#%% plot samples
5366

5467
xsp=projfda(xs)
5568
xtp=projfda(xt)
5669

57-
pl.figure(1,(10,5))
70+
xspw=projwda(xs)
71+
xtpw=projwda(xt)
5872

59-
pl.subplot(1,2,1)
73+
pl.figure(1,(10,10))
74+
75+
pl.subplot(2,2,1)
6076
pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
6177
pl.legend(loc=0)
62-
pl.title('Projected training samples')
78+
pl.title('Projected training samples FDA')
6379

6480

65-
pl.subplot(1,2,2)
81+
pl.subplot(2,2,2)
6682
pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
6783
pl.legend(loc=0)
68-
pl.title('Projected test samples')
84+
pl.title('Projected test samples FDA')
85+
86+
87+
pl.subplot(2,2,3)
88+
pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples')
89+
pl.legend(loc=0)
90+
pl.title('Projected training samples WDA')
91+
92+
93+
pl.subplot(2,2,4)
94+
pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples')
95+
pl.legend(loc=0)
96+
pl.title('Projected test samples WDA')

0 commit comments

Comments
 (0)