|
18 | 18 |
|
19 | 19 | n=1000 # nb samples in source and target datasets
|
20 | 20 | nz=0.2
|
21 |
| -xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) |
22 |
| -xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) |
| 21 | + |
| 22 | +# generate circle dataset |
| 23 | +t=np.random.rand(n)*2*np.pi |
| 24 | +ys=np.floor((np.arange(n)*1.0/n*3))+1 |
| 25 | +xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
| 26 | +xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2) |
| 27 | + |
| 28 | +t=np.random.rand(n)*2*np.pi |
| 29 | +yt=np.floor((np.arange(n)*1.0/n*3))+1 |
| 30 | +xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
| 31 | +xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2) |
23 | 32 |
|
24 | 33 | nbnoise=8
|
25 | 34 |
|
26 | 35 | xs=np.hstack((xs,np.random.randn(n,nbnoise)))
|
27 | 36 | xt=np.hstack((xt,np.random.randn(n,nbnoise)))
|
28 | 37 |
|
29 | 38 | #%% plot samples
|
| 39 | +pl.figure(1,(10,5)) |
30 | 40 |
|
31 |
| -pl.figure(1) |
32 |
| - |
33 |
| - |
| 41 | +pl.subplot(1,2,1) |
34 | 42 | pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
|
35 | 43 | pl.legend(loc=0)
|
36 | 44 | pl.title('Discriminant dimensions')
|
37 | 45 |
|
| 46 | +pl.subplot(1,2,2) |
| 47 | +pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples') |
| 48 | +pl.legend(loc=0) |
| 49 | +pl.title('Other dimensions') |
| 50 | +pl.show() |
38 | 51 |
|
39 |
| -#%% Comlpute FDA |
| 52 | +#%% Compute FDA |
40 | 53 | p=2
|
41 | 54 |
|
42 | 55 | Pfda,projfda = fda(xs,ys,p)
|
43 | 56 |
|
44 | 57 | #%% Compute WDA
|
45 | 58 | p=2
|
46 |
| -reg=1 |
| 59 | +reg=1e-1 |
47 | 60 | k=10
|
48 | 61 | maxiter=100
|
49 | 62 |
|
50 |
| -P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) |
| 63 | +Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter) |
51 | 64 |
|
52 | 65 | #%% plot samples
|
53 | 66 |
|
54 | 67 | xsp=projfda(xs)
|
55 | 68 | xtp=projfda(xt)
|
56 | 69 |
|
57 |
| -pl.figure(1,(10,5)) |
| 70 | +xspw=projwda(xs) |
| 71 | +xtpw=projwda(xt) |
58 | 72 |
|
59 |
| -pl.subplot(1,2,1) |
| 73 | +pl.figure(1,(10,10)) |
| 74 | + |
| 75 | +pl.subplot(2,2,1) |
60 | 76 | pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
|
61 | 77 | pl.legend(loc=0)
|
62 |
| -pl.title('Projected training samples') |
| 78 | +pl.title('Projected training samples FDA') |
63 | 79 |
|
64 | 80 |
|
65 |
| -pl.subplot(1,2,2) |
| 81 | +pl.subplot(2,2,2) |
66 | 82 | pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
|
67 | 83 | pl.legend(loc=0)
|
68 |
| -pl.title('Projected test samples') |
| 84 | +pl.title('Projected test samples FDA') |
| 85 | + |
| 86 | + |
| 87 | +pl.subplot(2,2,3) |
| 88 | +pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples') |
| 89 | +pl.legend(loc=0) |
| 90 | +pl.title('Projected training samples WDA') |
| 91 | + |
| 92 | + |
| 93 | +pl.subplot(2,2,4) |
| 94 | +pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples') |
| 95 | +pl.legend(loc=0) |
| 96 | +pl.title('Projected test samples WDA') |
0 commit comments