@@ -13,6 +13,52 @@ class SNMFOptimizer:
    For more information on sNMF, please reference:
        Gu, R., Rakita, Y., Lan, L. et al. Stretched non-negative matrix factorization.
        npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5
+
+    Attributes
+    ----------
+    MM : ndarray
+        The original, unmodified data to be decomposed and, later, compared against.
+        Shape is (length_of_signal, number_of_conditions).
+    Y : ndarray
+        The best guess (or, while running, the current guess) for the stretching
+        factor matrix.
+    X : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component intensities.
+    A : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component weights.
+    rho : float
+        The stretching factor that influences the decomposition. Zero corresponds to no
+        stretching present. Relatively insensitive and typically adjusted in powers of 10.
+    eta : float
+        The sparsity factor that influences the decomposition. Should be set to zero for
+        non-sparse data such as PDF. Can be used to improve results for sparse data such
+        as XRD, but due to instability, should be used only after first selecting the
+        best value for rho. Suggested adjustment is by powers of 2.
+    max_iter : int
+        The maximum number of times to update each of A, X, and Y before stopping
+        the optimization.
+    tol : float
+        The convergence threshold: the minimum fractional improvement in the objective
+        function required to continue the optimization. Note that a minimum of 20 updates
+        are run before this parameter is checked.
+    n_components : int
+        The number of components to extract from MM. Must be provided when and only when
+        Y0 is not provided.
+    random_state : int
+        The seed for the initial guesses at the matrices (A, X, and Y) created by
+        the decomposition.
+    num_updates : int
+        The total number of times that any of (A, X, and Y) have had their values changed.
+        If not terminated by other means, this value is used to stop when reaching max_iter.
+    objective_function : float
+        The value of the quantity being minimized: the misfit between MM and the
+        reconstruction built from A, X, and Y. For full details see the sNMF paper.
+        Smaller corresponds to better agreement and is desirable.
+    objective_difference : float
+        The change in the objective function value since the last update. A positive value
+        means that the result improved.

    """

    def __init__(
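The Attributes block above doubles as the user-facing contract for the constructor, which (per its docstring below) both initializes the instance and runs the optimization. A minimal, hypothetical usage sketch follows; the import path, data file, and argument values are assumptions, while the keyword names mirror the attributes documented above.

import numpy as np

# Import path is an assumption; adjust to wherever SNMFOptimizer lives in the package.
# from diffpy.snmf.snmf_class import SNMFOptimizer

MM = np.load("patterns.npy")  # hypothetical data, shape (length_of_signal, number_of_conditions)

# Decompose into two components; eta is left at 0 for non-sparse (PDF-like) data,
# and rho is tuned coarsely in powers of 10, as the docstring suggests.
model = SNMFOptimizer(MM, n_components=2, rho=1e12, eta=0, random_state=42)

print(model.X.shape)             # (length_of_signal, n_components): component intensities
print(model.A.shape)             # (n_components, number_of_conditions): component weights
print(model.Y.shape)             # (n_components, number_of_conditions): stretching factors
print(model.objective_function)  # final objective value; smaller is better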
@@ -28,7 +74,7 @@ def __init__(
        n_components=None,
        random_state=None,
    ):
-        """Initialize an instance of SNMF and run the optimization
+        """Initialize an instance of SNMF and run the optimization.

        Parameters
        ----------
@@ -71,7 +117,7 @@ def __init__(
        self.rho = rho
        self.eta = eta
        # Capture matrix dimensions
-        self.N, self.M = MM.shape
+        self._N, self._M = MM.shape
        self.num_updates = 0
        self._rng = np.random.default_rng(random_state)
@@ -81,21 +127,21 @@ def __init__(

        # Initialize Y0 and determine number of components
        if Y0 is None:
-            self.K = n_components
-            self.Y = self._rng.beta(a=2.5, b=1.5, size=(self.K, self.M))
+            self._K = n_components
+            self.Y = self._rng.beta(a=2.5, b=1.5, size=(self._K, self._M))
        else:
-            self.K = Y0.shape[0]
+            self._K = Y0.shape[0]
            self.Y = Y0

        # Initialize A if not provided
        if self.A is None:
-            self.A = np.ones((self.K, self.M)) + self._rng.normal(0, 1e-3, size=(self.K, self.M))
+            self.A = np.ones((self._K, self._M)) + self._rng.normal(0, 1e-3, size=(self._K, self._M))
        else:
            self.A = A0

        # Initialize X0 if not provided
        if self.X is None:
-            self.X = self._rng.random((self.N, self.K))
+            self.X = self._rng.random((self._N, self._K))
        else:
            self.X = X0
@@ -104,19 +150,19 @@ def __init__(
        self.Y = np.maximum(0, self.Y)

        # Second-order spline (second-difference) operator: bands of [1, -2, 1] at offsets 0, 1, 2
-        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self.M - 2, self.M))
+        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self._M - 2, self._M))
        self.PP = self.P.T @ self.P

        # Set up residual matrix, objective function, and history
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]

        # Set up tracking variables for updateX()
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)

        regularization_term = 0.5 * rho * np.linalg.norm(self.P @ self.A.T, "fro") ** 2
        sparsity_term = eta * np.sum(np.sqrt(self.X))  # Square root penalty
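As constructed in __init__ above, P is a scaled second-difference operator: entry i of P @ v equals 0.25 * (v[i] - 2 * v[i+1] + v[i+2]), so the rho term penalizes curvature along each row of A (each component's weight profile across conditions), while the eta term is a square-root sparsity penalty on X. A self-contained sketch of the same construction, using stand-in values rather than the class attributes:

import numpy as np
from scipy.sparse import diags

M = 6
v = np.arange(M, dtype=float) ** 2  # quadratic test profile: its second difference is constant

# Same operator as in __init__: bands [1, -2, 1] at offsets 0, 1, 2 on an (M - 2, M) matrix
P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(M - 2, M))
print(P @ v)  # [0.5 0.5 0.5 0.5], i.e. 0.25 * (v[i] - 2 * v[i+1] + v[i+2]) for each row

# The two penalty terms mirror the lines above; rho, eta, A, and X are stand-ins here.
rho, eta = 1e3, 0.5
A = np.random.rand(2, M)   # (n_components, number_of_conditions)
X = np.random.rand(10, 2)  # (length_of_signal, n_components)
regularization_term = 0.5 * rho * np.linalg.norm(P @ A.T, "fro") ** 2
sparsity_term = eta * np.sum(np.sqrt(X))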
@@ -151,53 +197,53 @@ def __init__(
        # loop to normalize X
        # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
        # reset difference trackers and initialize
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]
        for norm_iter in range(100):
            self.updateX()
            self.R = self.get_residual_matrix()
            self.objective_function = self.get_objective_function()
            print(f"Objective function after normX: {self.objective_function:.5e}")
-            self.objective_history.append(self.objective_function)
-            self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+            self._objective_history.append(self.objective_function)
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
            if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
                break
        # end of normalization (and program)
        # note that objective function may not fully recover after normalization, this is okay
        print("Finished optimization.")

    def optimize_loop(self):
-        self.preGraX = self.GraX.copy()
+        self._preGraX = self._GraX.copy()
        self.updateX()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateX: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)
        if self.objective_difference is None:
-            self.objective_difference = self.objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-1] - self.objective_function

        # Now we update Y
        self.updateY2()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateY2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)

        self.updateA2()

        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateA2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
-        self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+        self._objective_history.append(self.objective_function)
+        self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

    def apply_interpolation(self, a, x, return_derivatives=False):
        """
@@ -469,36 +515,38 @@ def updateX(self):
        # Compute `AX` using the interpolation function
        AX, _, _ = self.apply_interpolation_matrix()  # Skip the other two outputs
        # Compute RA and RR
-        intermediate_RA = AX.flatten(order="F").reshape((self.N * self.M, self.K), order="F")
-        RA = intermediate_RA.sum(axis=1).reshape((self.N, self.M), order="F")
+        intermediate_RA = AX.flatten(order="F").reshape((self._N * self._M, self._K), order="F")
+        RA = intermediate_RA.sum(axis=1).reshape((self._N, self._M), order="F")
        RR = RA - self.MM
        # Compute gradient `GraX`
-        self.GraX = self.apply_transformation_matrix(R=RR).toarray()  # toarray equivalent of full, make non-sparse
+        self._GraX = self.apply_transformation_matrix(
+            R=RR
+        ).toarray()  # toarray equivalent of full, make non-sparse

        # Compute initial step size `L0`
        L0 = np.linalg.eigvalsh(self.Y.T @ self.Y).max() * np.max([self.A.max(), 1 / self.A.min()])
        # Compute adaptive step size `L`
-        if self.preX is None:
+        if self._preX is None:
            L = L0
        else:
-            num = np.sum((self.GraX - self.preGraX) * (self.X - self.preX))  # Element-wise multiplication
-            denom = np.linalg.norm(self.X - self.preX, "fro") ** 2  # Frobenius norm squared
+            num = np.sum((self._GraX - self._preGraX) * (self.X - self._preX))  # Element-wise multiplication
+            denom = np.linalg.norm(self.X - self._preX, "fro") ** 2  # Frobenius norm squared
            L = num / denom if denom > 0 else L0
            if L <= 0:
                L = L0

        # Store our old X before updating because it is used in step selection
-        self.preX = self.X.copy()
+        self._preX = self.X.copy()

        while True:  # iterate updating X
-            x_step = self.preX - self.GraX / L
+            x_step = self._preX - self._GraX / L
            # Solve x^3 + p*x + q = 0 for the largest real root
            self.X = np.square(cubic_largest_real_root(-x_step, self.eta / (2 * L)))
            # Mask values that should be set to zero
            mask = self.X**2 * L / 2 - L * self.X * x_step + self.eta * np.sqrt(self.X) < 0
            self.X = mask * self.X

-            objective_improvement = self.objective_history[-1] - self.get_objective_function(
+            objective_improvement = self._objective_history[-1] - self.get_objective_function(
                R=self.get_residual_matrix()
            )
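The adaptive step size above is a secant-based curvature estimate (the reciprocal of a Barzilai-Borwein step length): the inner product of the change in gradient with the change in iterate, divided by the squared Frobenius norm of the change in iterate, falling back to the spectral bound L0 whenever the estimate is unavailable or non-positive. The resulting L then scales the gradient step x_step, and the square-root penalty is handled by taking the largest real root of a cubic. A standalone restatement of the step-size rule with illustrative names (not part of the class):

import numpy as np

def adaptive_step(X, preX, GraX, preGraX, L0):
    """Secant-based step size mirroring updateX, with L0 as the safe fallback."""
    if preX is None:
        return L0
    num = np.sum((GraX - preGraX) * (X - preX))    # <dGrad, dX>
    denom = np.linalg.norm(X - preX, "fro") ** 2   # ||dX||_F ** 2
    L = num / denom if denom > 0 else L0
    return L0 if L <= 0 else L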
@@ -515,9 +563,9 @@ def updateY2(self):
        Updates Y using matrix operations, solving a quadratic program via `solve_mkr_box`.
        """

-        K = self.K
-        N = self.N
-        M = self.M
+        K = self._K
+        N = self._N
+        M = self._M

        for m in range(M):
            T = np.zeros((N, K))  # Initialize T as an (N, K) zero matrix
@@ -544,9 +592,9 @@ def regularize_function(self, A=None):
        if A is None:
            A = self.A

-        K = self.K
-        M = self.M
-        N = self.N
+        K = self._K
+        M = self._M
+        N = self._N

        # Compute interpolated matrices
        AX, TX, HX = self.apply_interpolation_matrix(A=A, return_derivatives=True)