Skip to content

Commit d8d4e11

Browse files
author
John Halloran
committed
feat: Add random state feature.
1 parent 8613ea0 commit d8d4e11

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,20 @@
44

55

66
class SNMFOptimizer:
7-
def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500, tol=5e-7, components=None):
8-
print("Initializing SNMF Optimizer")
7+
def __init__(
8+
self,
9+
MM,
10+
Y0=None,
11+
X0=None,
12+
A=None,
13+
rho=1e12,
14+
eta=610,
15+
max_iter=500,
16+
tol=5e-7,
17+
components=None,
18+
random_state=None,
19+
):
20+
921
self.MM = MM
1022
self.X0 = X0
1123
self.Y0 = Y0
@@ -15,23 +27,22 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500
1527
# Capture matrix dimensions
1628
self.N, self.M = MM.shape
1729
self.num_updates = 0
30+
self.rng = np.random.default_rng(random_state)
1831

1932
if Y0 is None:
2033
if components is None:
2134
raise ValueError("Must provide either Y0 or a number of components.")
2235
else:
2336
self.K = components
24-
self.Y0 = np.random.beta(a=2.5, b=1.5, size=(self.K, self.M)) # This is untested
37+
self.Y0 = self.rng.beta(a=2.5, b=1.5, size=(self.K, self.M))
2538
else:
2639
self.K = Y0.shape[0]
2740

28-
# Initialize A, X0 if not provided
2941
if self.A is None:
30-
self.A = np.ones((self.K, self.M)) + np.random.randn(self.K, self.M) * 1e-3 # Small perturbation
42+
self.A = np.ones((self.K, self.M)) + self.rng.normal(0, 1e-3, size=(self.K, self.M))
3143
if self.X0 is None:
32-
self.X0 = np.random.rand(self.N, self.K) # Ensures values in [0,1]
44+
self.X0 = self.rng.random((self.N, self.K))
3345

34-
# Initialize solution matrices to be iterated on
3546
self.X = np.maximum(0, self.X0)
3647
self.Y = np.maximum(0, self.Y0)
3748

0 commit comments

Comments
 (0)