Skip to content

Commit 8cf7163

Browse files
authored
Merge pull request #43 from aajayi-21/construct_component
function construct_component_matrix
2 parents e2fd159 + 50ea324 commit 8cf7163

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

diffpy/snmf/subroutines.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,43 @@ def construct_stretching_matrix(components, number_of_components, number_of_sign
7474
if (len(components)) == 0:
7575
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.")
7676
number_of_components = len(components)
77-
77+
7878
if number_of_signals <= 0:
7979
raise ValueError(f"Number of signals = {number_of_signals}. Number_of_signals must be >= 1.")
80-
80+
8181
stretching_factor_matrix = np.zeros((number_of_components, number_of_signals))
8282
for i, component in enumerate(components):
8383
stretching_factor_matrix[i] = component.stretching_factors
8484
return stretching_factor_matrix
8585

8686

87+
def construct_component_matrix(components):
88+
"""Constructs the component matrix
89+
90+
Parameters
91+
----------
92+
components: tuple of ComponentSignal objects
93+
The tuple containing the component signals in ComponentSignal objects.
94+
95+
Returns
96+
-------
97+
2d array
98+
The matrix containing the component signal values. Has dimensions `signal_length` x `number_of_components`.
99+
100+
"""
101+
signal_length = len(components[0].iq)
102+
number_of_components = len(components)
103+
if signal_length == 0:
104+
raise ValueError(f"Signal length = {signal_length}. Signal length must be >= 1")
105+
if number_of_components == 0:
106+
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1")
107+
108+
component_matrix = np.zeros((number_of_components, signal_length))
109+
for i, component in enumerate(components):
110+
component_matrix[i] = component.iq
111+
return component_matrix
112+
113+
87114
def initialize_arrays(number_of_components, number_of_moments, signal_length):
88115
"""Generates the initial guesses for the weight, stretching, and component matrices
89116

diffpy/snmf/tests/test_subroutines.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
5-
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix
5+
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, construct_component_matrix
66

77
to = [
88
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -186,3 +186,23 @@ def test_construct_stretching_matrix(tcso):
186186
for component in tcso[0]:
187187
np.testing.assert_allclose(actual[component.id,:], component.stretching_factors)
188188
#assert actual[component.id, :] == component.stretching_factors
189+
190+
tccm = [
191+
([ComponentSignal([0,.25,.5,.75,1],20,0)]),
192+
([ComponentSignal([0,.25,.5,.75,1],0,0)]),
193+
([ComponentSignal([0,.25,.5,.75,1],20,0),ComponentSignal([0,.25,.5,.75,1],20,1),ComponentSignal([0,.25,.5,.75,1],20,2)]),
194+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, .75, 1], 20, 1),
195+
ComponentSignal([0, .25, .5, .75, 1], 20, 2)]),
196+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, 2.75, 1], 20, 1),
197+
ComponentSignal([0, .25, .5, .75, 1], 20, 2)]),
198+
([ComponentSignal([.25], 20, 0), ComponentSignal([.25], 20, 1), ComponentSignal([.25], 20, 2)]),
199+
([ComponentSignal([0, .25, .5, .75, 1], 20, 0), ComponentSignal([0, .25, .5, .75, 1], 20, 1)]),
200+
# ([ComponentSignal([[0, .25, .5, .75, 1],[0, .25, .5, .75, 1]], 20, 0), ComponentSignal([[0, .25, .5, .75, 1],[0, .25, .5, .75, 1]], 20, 1)]), # iq is multidimensional. Expected to fail
201+
# (ComponentSignal([], 20, 0)), # Expected to fail
202+
# ([]), #Expected to fail
203+
]
204+
@pytest.mark.parametrize('tccm',tccm)
205+
def test_construct_component_matrix(tccm):
206+
actual = construct_component_matrix(tccm)
207+
for component in tccm:
208+
np.testing.assert_allclose(actual[component.id], component.iq)

0 commit comments

Comments
 (0)