Skip to content

Commit 05765e2

Browse files
committed
add FDA for comparison
1 parent 315d812 commit 05765e2

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

examples/plot_WDA.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import matplotlib.pylab as pl
1212
import ot
1313
from ot.datasets import get_1D_gauss as gauss
14-
from ot.dr import wda
14+
from ot.dr import wda, fda
1515

1616

1717
#%% parameters
@@ -36,7 +36,12 @@
3636
pl.title('Discriminant dimensions')
3737

3838

39-
#%% plot distributions and loss matrix
39+
#%% Compute FDA
40+
p=2
41+
42+
Pfda,projfda = fda(xs,ys,p)
43+
44+
#%% Compute WDA
4045
p=2
4146
reg=1
4247
k=10
@@ -46,8 +51,8 @@
4651

4752
#%% plot samples
4853

49-
xsp=proj(xs)
50-
xtp=proj(xt)
54+
xsp=projfda(xs)
55+
xtp=projfda(xt)
5156

5257
pl.figure(1,(10,5))
5358

ot/dr.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pymanopt.manifolds import Stiefel
88
from pymanopt import Problem
99
from pymanopt.solvers import SteepestDescent, TrustRegions
10+
import scipy.linalg as la
1011

1112
def dist(x1,x2):
1213
""" Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
3233
"""
3334
lstsclass=np.unique(y)
3435
return [X[y==i,:].astype(np.float32) for i in lstsclass]
36+
37+
38+
def fda(X, y, p=2, reg=1e-16):
    """
    Fisher Discriminant Analysis

    Parameters
    ----------
    X : numpy.ndarray (n,d)
        Training samples
    y : np.ndarray (n,)
        labels for training samples
    p : int, optional
        size of dimensionality reduction
    reg : float, optional
        Regularization term >0 (ridge regularization of the within-class
        covariance, keeps the generalized eigenproblem well-posed)

    Returns
    -------
    P : (d x p) ndarray
        Optimal projection matrix (p leading discriminant directions)
    proj : fun
        projection function including mean centering

    """

    # Per-feature mean (axis 0). np.mean(X) would give a single scalar grand
    # mean, which breaks the per-feature centering below.
    mx = np.mean(X, 0)
    # Center a copy: do not mutate the caller's array with `X -= ...`.
    X = X - mx.reshape((1, -1))

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    nc = len(xc)

    # FDA yields at most nc-1 discriminant directions
    p = min(nc - 1, p)

    # within-class covariance (averaged over classes)
    Cw = 0
    for x in xc:
        Cw += np.cov(x, rowvar=False)
    Cw /= nc

    # per-class mean vectors, one column per class. The mean must be taken
    # along axis 0: np.mean(xc[i]) returns a scalar that would be broadcast
    # over the column, making the between-class scatter Cb meaningless.
    mxc = np.zeros((d, nc))
    for i in range(nc):
        mxc[:, i] = np.mean(xc[i], 0)

    # between-class scatter: sum of outer products of centered class means
    mx0 = np.mean(mxc, 1)
    Cb = 0
    for i in range(nc):
        Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * (mxc[:, i] - mx0).reshape((1, -1))

    # generalized eigenproblem Cb v = w (Cw + reg I) v: maximize the
    # between/within variance ratio
    w, V = la.eig(Cb, Cw + reg * np.eye(d))

    idx = np.argsort(w.real)

    # keep the p eigenvectors with largest generalized eigenvalues
    Popt = V[:, idx[-p:]]

    def proj(X):
        # project new samples with the same mean centering as training data
        return (X - mx.reshape((1, -1))).dot(Popt)

    return Popt, proj
38103
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
39104
"""
40105
Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
73138
P : (d x p) ndarray
74139
Optimal transportation matrix for the given parameters
75140
proj : fun
76-
projectiuon function including mean centering
141+
projection function including mean centering
77142
78143
79144
References
80145
----------
81146
82147
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
83-
84-
85-
86148
87149
"""
88150

@@ -131,3 +193,6 @@ def proj(X):
131193
return (X-mx.reshape((1,-1))).dot(Popt)
132194

133195
return Popt, proj
196+
197+
198+

0 commit comments

Comments
 (0)