Skip to content

Commit f353884

Browse files
committed
added function contents and tests
1 parent 7be6477 commit f353884

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

diffpy/snmf/subroutines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,22 @@ def update_weights(components, data_input, method=None):
161161
The 2d array containing the weight factors for each component for each signal from `data_input`. Has dimensions
162162
K x M where K is the number of components and M is the number of signals in `data_input.`
163163
"""
164+
data_input = np.asarray(data_input)
165+
weight_matrix = construct_weight_matrix(components)
166+
number_of_signals = len(components[0].weights)
167+
number_of_components = len(components)
168+
signal_length = len(components[0].grid)
169+
for signal in range(number_of_signals):
170+
stretched_components = np.zeros((signal_length, number_of_components))
171+
for i, component in enumerate(components):
172+
stretched_components[:, i] = component.apply_stretch(signal)[0]
173+
if method == 'align':
174+
weights = lsqnonneg(stretched_components, data_input[:,signal])
175+
else:
176+
weights = get_weights(stretched_components.T @ stretched_components,
177+
-stretched_components.T @ data_input[:, signal], 0, 1)
178+
weight_matrix[:, signal] = weights
179+
return weight_matrix
164180

165181

166182
def initialize_arrays(number_of_components, number_of_moments, signal_length):

diffpy/snmf/tests/test_subroutines.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
55
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
6-
construct_component_matrix, construct_weight_matrix
6+
construct_component_matrix, construct_weight_matrix, update_weights
77

88
to = [
99
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -228,7 +228,23 @@ def test_construct_weight_matrix(tcwm):
228228
for component in tcwm:
229229
np.testing.assert_allclose(actual[component.id], component.weights)
230230

231-
tuw = []
232-
@pytest.mark.parametrize('tuw',tuw)
231+
232+
tuw = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
233+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], None),
234+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
235+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], "align"),
236+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
237+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], None),
238+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
239+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], "align"),
240+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
241+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], None),
242+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
243+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], "align"),
244+
]
245+
@pytest.mark.parametrize('tuw', tuw)
233246
def test_update_weights(tuw):
234-
assert False
247+
for component in tuw[0]:
248+
print(component.weights)
249+
actual = update_weights(tuw[0], tuw[1], tuw[2])
250+
assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights))

0 commit comments

Comments
 (0)