7
7
from pymanopt .manifolds import Stiefel
8
8
from pymanopt import Problem
9
9
from pymanopt .solvers import SteepestDescent , TrustRegions
10
+ import scipy .linalg as la
10
11
11
12
def dist (x1 ,x2 ):
12
13
""" Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
32
33
"""
33
34
lstsclass = np .unique (y )
34
35
return [X [y == i ,:].astype (np .float32 ) for i in lstsclass ]
36
+
37
+
38
+ def fda (X ,y ,p = 2 ,reg = 1e-16 ):
39
+ """
40
+ Fisher Discriminant Analysis
41
+
35
42
43
+ Parameters
44
+ ----------
45
+ X : numpy.ndarray (n,d)
46
+ Training samples
47
+ y : np.ndarray (n,)
48
+ labels for training samples
49
+ p : int, optional
50
+ size of dimensionnality reduction
51
+ reg : float, optional
52
+ Regularization term >0 (ridge regularization)
36
53
37
54
55
+ Returns
56
+ -------
57
+ P : (d x p) ndarray
58
+ Optimal transportation matrix for the given parameters
59
+ proj : fun
60
+ projection function including mean centering
61
+
62
+
63
+ """
64
+
65
+ mx = np .mean (X )
66
+ X -= mx .reshape ((1 ,- 1 ))
67
+
68
+ # data split between classes
69
+ d = X .shape [1 ]
70
+ xc = split_classes (X ,y )
71
+ nc = len (xc )
72
+
73
+ p = min (nc - 1 ,p )
74
+
75
+ Cw = 0
76
+ for x in xc :
77
+ Cw += np .cov (x ,rowvar = False )
78
+ Cw /= nc
79
+
80
+ mxc = np .zeros ((d ,nc ))
81
+
82
+ for i in range (nc ):
83
+ mxc [:,i ]= np .mean (xc [i ])
84
+
85
+ mx0 = np .mean (mxc ,1 )
86
+ Cb = 0
87
+ for i in range (nc ):
88
+ Cb += (mxc [:,i ]- mx0 ).reshape ((- 1 ,1 ))* (mxc [:,i ]- mx0 ).reshape ((1 ,- 1 ))
89
+
90
+ w ,V = la .eig (Cb ,Cw + reg * np .eye (d ))
91
+
92
+ idx = np .argsort (w .real )
93
+
94
+ Popt = V [:,idx [- p :]]
95
+
96
+
97
+
98
+ def proj (X ):
99
+ return (X - mx .reshape ((1 ,- 1 ))).dot (Popt )
100
+
101
+ return Popt , proj
102
+
38
103
def wda (X ,y ,p = 2 ,reg = 1 ,k = 10 ,solver = None ,maxiter = 100 ,verbose = 0 ):
39
104
"""
40
105
Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
73
138
P : (d x p) ndarray
74
139
Optimal transportation matrix for the given parameters
75
140
proj : fun
76
- projectiuon function including mean centering
141
+ projection function including mean centering
77
142
78
143
79
144
References
80
145
----------
81
146
82
147
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
83
-
84
-
85
-
86
148
87
149
"""
88
150
@@ -131,3 +193,6 @@ def proj(X):
131
193
return (X - mx .reshape ((1 ,- 1 ))).dot (Popt )
132
194
133
195
return Popt , proj
196
+
197
+
198
+
0 commit comments