Skip to content

Commit c644543

Browse files
authored
Merge pull request #45 from aajayi-21/construct_weights
function construct_weight_matrix
2 parents 8cf7163 + 2f3e2d9 commit c644543

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

diffpy/snmf/subroutines.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,34 @@ def construct_component_matrix(components):
111111
return component_matrix
112112

113113

114+
def construct_weight_matrix(components):
115+
"""Constructs the weights matrix
116+
117+
Constructs a Ķ x M matrix where K is the number of components and M is the
118+
number of signals. Each element is the stretching factor for a specific
119+
weights for a specific signal from the data input.
120+
121+
Parameters
122+
----------
123+
components: tuple of ComponentSignal objects
124+
The tuple containing the component signals.
125+
126+
Returns
127+
-------
128+
2d array like
129+
The 2d array containing the weightings for each component for each signal.
130+
"""
131+
number_of_components = len(components)
132+
number_of_signals = len(components[0].weights)
133+
if number_of_components == 0:
134+
raise ValueError(f"Number of components = {number_of_components}. Number of components must be >= 1")
135+
if number_of_signals == 0:
136+
raise ValueError(f"Number of signals = {number_of_signals}. Number_of_signals must be >= 1.")
137+
weights_matrix = np.zeros((number_of_components,number_of_signals))
138+
for i, component in enumerate(components):
139+
weights_matrix[i] = component.weights
140+
return weights_matrix
141+
114142
def initialize_arrays(number_of_components, number_of_moments, signal_length):
115143
"""Generates the initial guesses for the weight, stretching, and component matrices
116144

diffpy/snmf/tests/test_subroutines.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
5-
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, construct_component_matrix
5+
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, construct_component_matrix, construct_weight_matrix
66

77
to = [
88
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -206,3 +206,22 @@ def test_construct_component_matrix(tccm):
206206
actual = construct_component_matrix(tccm)
207207
for component in tccm:
208208
np.testing.assert_allclose(actual[component.id], component.iq)
209+
210+
tcwm = [
211+
([ComponentSignal([0,.25,.5,.75,1],20,0)]),
212+
# ([ComponentSignal([0,.25,.5,.75,1],0,0)]), # 0 signal length. Failure expected
213+
([ComponentSignal([0,.25,.5,.75,1],20,0),ComponentSignal([0,.25,.5,.75,1],20,1),ComponentSignal([0,.25,.5,.75,1],20,2)]),
214+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, .75, 1], 20, 1),
215+
ComponentSignal([0, .25, .5, .75, 1], 20, 2)]),
216+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, 2.75, 1], 20, 1),
217+
ComponentSignal([0, .25, .5, .75, 1], 20, 2)]),
218+
([ComponentSignal([.25], 20, 0), ComponentSignal([.25], 20, 1), ComponentSignal([.25], 20, 2)]),
219+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, .75, 1], 20, 1)]),
220+
#(ComponentSignal([], 20, 0)), # Expected to fail
221+
#([]), #Expected to fail
222+
]
223+
@pytest.mark.parametrize('tcwm',tcwm)
224+
def test_construct_weight_matrix(tcwm):
225+
actual = construct_weight_matrix(tcwm)
226+
for component in tcwm:
227+
np.testing.assert_allclose(actual[component.id], component.weights)

0 commit comments

Comments
 (0)