|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | 3 | from numpy.polynomial import Polynomial |
| 4 | +from scipy.interpolate import interp1d |
4 | 5 |
|
5 | 6 | from diffpy.morph.morphs.morphsqueeze import MorphSqueeze |
6 | 7 |
|
|
21 | 22 | [0.1, 0.3], |
22 | 23 | # 4th order squeeze coefficients |
23 | 24 | [0.2, -0.01, 0.001, -0.001, 0.0004], |
24 | | - # Zeros and non-zeros, expect 0 + a1x + 0 + a3x**3 |
| 25 | + # Zeros and non-zeros, the full polynomial is applied |
25 | 26 | [0, 0.03, 0, -0.001], |
26 | 27 | # Testing zeros, expect no squeezing |
27 | | - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 28 | + [0, 0, 0, 0, 0, 0], |
28 | 29 | ], |
29 | 30 | ) |
30 | 31 | def test_morphsqueeze(squeeze_coeffs): |
31 | | - x_target = np.linspace(0, 10, 1001) |
32 | | - y_target = np.sin(x_target) |
33 | | - |
34 | | - x_make = np.linspace(-3, 13, 1601) |
35 | | - lower_idx = np.where(x_make == 0.0)[0][0] |
36 | | - upper_idx = np.where(x_make == 10.0)[0][0] |
37 | | - |
| 32 | + x_expected = np.linspace(0, 10, 1001) |
| 33 | + y_expected = np.sin(x_expected) |
| 34 | + x_make = np.linspace(-3, 13, 3250) |
38 | 35 | squeeze_polynomial = Polynomial(squeeze_coeffs) |
39 | 36 | x_squeezed = x_make + squeeze_polynomial(x_make) |
40 | | - |
41 | | - x_morph = x_make.copy() |
42 | 37 | y_morph = np.sin(x_squeezed) |
43 | | - |
44 | 38 | morph = MorphSqueeze() |
45 | 39 | morph.squeeze = squeeze_coeffs |
46 | | - |
47 | | - x_actual, y_actual, x_expected, y_expected = morph( |
48 | | - x_morph, y_morph, x_target, y_target |
| 40 | + x_actual, y_actual, x_target, y_target = morph( |
| 41 | + x_make, y_morph, x_expected, y_expected |
49 | 42 | ) |
50 | | - y_actual = y_actual[lower_idx : upper_idx + 1] |
| 43 | + y_actual = interp1d(x_actual, y_actual)(x_target) |
| 44 | + x_actual = x_target |
51 | 45 | assert np.allclose(y_actual, y_expected) |
| 46 | + assert np.allclose(x_actual, x_expected) |
| 47 | + assert np.allclose(x_target, x_expected) |
| 48 | + assert np.allclose(y_target, y_expected) |
52 | 49 |
|
53 | 50 | # Plotting code used for figures in PR comments |
54 | 51 | # https://github.com/diffpy/diffpy.morph/pull/180 |
55 | 52 | # plt.figure() |
56 | 53 | # plt.scatter(x_expected, y_expected, color='black', label='Expected') |
57 | | - # plt.plot(x_morph, y_morph, color='purple', label='morph') |
| 54 | + # plt.plot(x_make, y_morph, color='purple', label='morph') |
58 | 55 | # plt.plot(x_actual, y_actual, '--', color='gold', label='Actual') |
59 | 56 | # plt.legend() |
60 | 57 | # plt.show() |
0 commit comments