Skip to content

Commit 7be6477

Browse files
committed
initial commit
1 parent c644543 commit 7be6477

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

diffpy/snmf/subroutines.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,35 @@ def construct_weight_matrix(components):
134134
raise ValueError(f"Number of components = {number_of_components}. Number of components must be >= 1")
135135
if number_of_signals == 0:
136136
raise ValueError(f"Number of signals = {number_of_signals}. Number_of_signals must be >= 1.")
137-
weights_matrix = np.zeros((number_of_components,number_of_signals))
137+
weights_matrix = np.zeros((number_of_components, number_of_signals))
138138
for i, component in enumerate(components):
139139
weights_matrix[i] = component.weights
140140
return weights_matrix
141141

142+
143+
def update_weights(components, data_input, method=None):
144+
"""Updates the weights matrix.
145+
146+
Updates the weights matrix and the weights vector for each ComponentSignal object.
147+
148+
Parameters
149+
----------
150+
components: tuple of ComponentSignal objects
151+
The tuple containing the component signals.
152+
method: str
153+
The string specifying which method should be used to find a new weight matrix: non-negative least squares or a
154+
quadratic program.
155+
data_input: 2d array
156+
The 2d array containing the user-provided signals.
157+
158+
Returns
159+
-------
160+
2d array
161+
The 2d array containing the weight factors for each component for each signal from `data_input`. Has dimensions
162+
K x M where K is the number of components and M is the number of signals in `data_input.`
163+
"""
164+
165+
142166
def initialize_arrays(number_of_components, number_of_moments, signal_length):
143167
"""Generates the initial guesses for the weight, stretching, and component matrices
144168

diffpy/snmf/tests/test_subroutines.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import numpy as np
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
5-
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, construct_component_matrix, construct_weight_matrix
5+
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
6+
construct_component_matrix, construct_weight_matrix
67

78
to = [
89
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -207,6 +208,7 @@ def test_construct_component_matrix(tccm):
207208
for component in tccm:
208209
np.testing.assert_allclose(actual[component.id], component.iq)
209210

211+
210212
tcwm = [
211213
([ComponentSignal([0,.25,.5,.75,1],20,0)]),
212214
# ([ComponentSignal([0,.25,.5,.75,1],0,0)]), # 0 signal length. Failure expected
@@ -225,3 +227,8 @@ def test_construct_weight_matrix(tcwm):
225227
actual = construct_weight_matrix(tcwm)
226228
for component in tcwm:
227229
np.testing.assert_allclose(actual[component.id], component.weights)
230+
231+
tuw = []
232+
@pytest.mark.parametrize('tuw',tuw)
233+
def test_update_weights(tuw):
234+
assert False

0 commit comments

Comments
 (0)