|
3 | 3 | from diffpy.snmf.containers import ComponentSignal
|
4 | 4 | from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
|
5 | 5 | update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
|
6 |
| - construct_component_matrix, construct_weight_matrix |
| 6 | + construct_component_matrix, construct_weight_matrix, update_weights |
7 | 7 |
|
8 | 8 | to = [
|
9 | 9 | ([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
|
@@ -228,7 +228,23 @@ def test_construct_weight_matrix(tcwm):
|
228 | 228 | for component in tcwm:
|
229 | 229 | np.testing.assert_allclose(actual[component.id], component.weights)
|
230 | 230 |
|
231 |
| -tuw = [] |
232 |
| -@pytest.mark.parametrize('tuw',tuw) |
| 231 | + |
| 232 | +tuw = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 233 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], None), |
| 234 | + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 235 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], "align"), |
| 236 | + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 237 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], None), |
| 238 | + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 239 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], "align"), |
| 240 | + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 241 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], None), |
| 242 | + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), |
| 243 | + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], "align"), |
| 244 | + ] |
| 245 | +@pytest.mark.parametrize('tuw', tuw) |
233 | 246 | def test_update_weights(tuw):
|
234 |
| - assert False |
| 247 | + for component in tuw[0]: |
| 248 | + print(component.weights) |
| 249 | + actual = update_weights(tuw[0], tuw[1], tuw[2]) |
| 250 | + assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights)) |
0 commit comments