Skip to content

Commit 6ba837f

Browse files
author
John Halloran
committed
Make logic for n_components and Y0 more rigid
1 parent 7f8e33d commit 6ba837f

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

src/diffpy/snmf/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Y0 = np.loadtxt("input/W0.txt", dtype=float)
88
N, M = MM.shape
99

10-
my_model = snmf_class.SNMFOptimizer(MM=MM, Y0=Y0, X0=X0, A=A0, n_components=2)
10+
my_model = snmf_class.SNMFOptimizer(MM=MM, Y0=Y0, X0=X0, A0=A0)
1111
print("Done")
1212
np.savetxt("my_norm_X.txt", my_model.X, fmt="%.6g", delimiter=" ")
1313
np.savetxt("my_norm_Y.txt", my_model.Y, fmt="%.6g", delimiter=" ")

src/diffpy/snmf/snmf_class.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
MM,
2121
Y0=None,
2222
X0=None,
23-
A=None,
23+
A0=None,
2424
rho=1e12,
2525
eta=610,
2626
max_iter=500,
@@ -36,12 +36,12 @@ def __init__(
3636
The data to be decomposed. Shape is (length_of_signal, number_of_conditions).
3737
Y0 : ndarray
3838
The initial guesses for the component weights at each stretching condition.
39-
Shape is (number of components, number ofconditions) Must be provided if
40-
n_components is not provided. Will override n_components if both are provided.
39+
Shape is (number_of_components, number_of_conditions) Must provide exactly one
40+
of this or n_components.
4141
X0 : ndarray
4242
The initial guesses for the intensities of each component per
4343
row/sample/angle. Shape is (length_of_signal, number_of_components).
44-
A : ndarray
44+
A0 : ndarray
4545
The initial guesses for the stretching factor for each component, at each
4646
condition. Shape is (number_of_components, number_of_conditions).
4747
rho : float
@@ -60,41 +60,48 @@ def __init__(
6060
objective function to allow without terminating the optimization. Note that
6161
a minimum of 20 updates are run before this parameter is checked.
6262
n_components : int
63-
The number of components to extract from MM. Note that this will
64-
be overridden by Y0 if that is provided, but must be provided if no Y0 is
65-
provided.
63+
The number of components to extract from MM. Must be provided when and only when
64+
Y0 is not provided.
6665
random_state : int
6766
The seed for the initial guesses at the matrices (A, X, and Y) created by
6867
the decomposition.
6968
"""
7069

7170
self.MM = MM
72-
self.X0 = X0
73-
self.Y0 = Y0
74-
self.A = A
7571
self.rho = rho
7672
self.eta = eta
7773
# Capture matrix dimensions
7874
self.N, self.M = MM.shape
7975
self.num_updates = 0
8076
self._rng = np.random.default_rng(random_state)
8177

78+
# Enforce exclusive specification of n_components or Y0
79+
if (n_components is None) == (Y0 is not None):
80+
raise ValueError("Must provide exactly one of Y0 or n_components, but not both.")
81+
82+
# Initialize Y0 and determine number of components
8283
if Y0 is None:
83-
if n_components is None:
84-
raise ValueError("Must provide either Y0 or n_components.")
85-
else:
86-
self.K = n_components
87-
self.Y0 = self._rng.beta(a=2.5, b=1.5, size=(self.K, self.M))
84+
self.K = n_components
85+
self.Y = self._rng.beta(a=2.5, b=1.5, size=(self.K, self.M))
8886
else:
8987
self.K = Y0.shape[0]
88+
self.Y = Y0
9089

90+
# Initialize A if not provided
9191
if self.A is None:
9292
self.A = np.ones((self.K, self.M)) + self._rng.normal(0, 1e-3, size=(self.K, self.M))
93-
if self.X0 is None:
94-
self.X0 = self._rng.random((self.N, self.K))
93+
else:
94+
self.A = A0
95+
96+
# Initialize X0 if not provided
97+
if self.X is None:
98+
self.X = self._rng.random((self.N, self.K))
99+
else:
100+
self.X = X0
95101

96-
self.X = np.maximum(0, self.X0)
97-
self.Y = np.maximum(0, self.Y0)
102+
# Enforce non-negativity
103+
self.X = np.maximum(0, self.X)
104+
self.Y = np.maximum(0, self.Y)
98105

99106
# Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals)
100107
self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self.M - 2, self.M))

0 commit comments

Comments
 (0)