Skip to content

Commit ddcc758

Browse files
add functionality to support xtype=q
1 parent d3f3161 commit ddcc758

File tree

2 files changed

+106
-16
lines changed

2 files changed

+106
-16
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import pandas as pd
66
from scipy.interpolate import interp1d
77

8-
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
8+
from diffpy.utils.scattering_objects.diffraction_objects import (
9+
ANGLEQUANTITIES,
10+
DQUANTITIES,
11+
XQUANTITIES,
12+
Diffraction_object,
13+
)
914

1015
RADIUS_MM = 1
1116
N_POINTS_ON_DIAMETER = 300
@@ -198,7 +203,6 @@ def _cve_brute_force(diffraction_data, mud):
198203
"tth",
199204
metadata=diffraction_data.metadata,
200205
name=f"absorption correction, cve, for {diffraction_data.name}",
201-
wavelength=diffraction_data.wavelength,
202206
scat_quantity="cve",
203207
)
204208
return cve_do
@@ -227,7 +231,6 @@ def _cve_polynomial_interpolation(diffraction_data, mud):
227231
"tth",
228232
metadata=diffraction_data.metadata,
229233
name=f"absorption correction, cve, for {diffraction_data.name}",
230-
wavelength=diffraction_data.wavelength,
231234
scat_quantity="cve",
232235
)
233236
return cve_do
@@ -246,15 +249,54 @@ def _cve_method(method):
246249
return methods[method]
247250

248251

249-
def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
252+
def interpolate_to_xtype_grid(cve_do, xtype):
253+
f"""
254+
interpolates the cve grid to the xtype user specifies, raise an error if xtype is invalid
255+
256+
Parameters
257+
----------
258+
cve_do Diffraction_object
259+
the diffraction object that contains the cve to be applied
260+
xtype str
261+
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
262+
263+
Returns
264+
-------
265+
the new diffraction object with interpolated cve curves
266+
"""
267+
268+
if xtype.lower() not in XQUANTITIES:
269+
raise ValueError(f"Unknown xtype: {xtype}. Allowed xtypes are {*XQUANTITIES, }.")
270+
if xtype.lower() in ANGLEQUANTITIES or xtype.lower() in DQUANTITIES:
271+
return cve_do
272+
273+
orig_grid, orig_cve = cve_do.on_tth[0], cve_do.on_tth[1]
274+
new_grid = cve_do.tth_to_q()
275+
new_cve = np.interp(new_grid, orig_grid, orig_cve)
276+
new_cve_do = Diffraction_object(wavelength=cve_do.wavelength)
277+
new_cve_do.insert_scattering_quantity(
278+
new_grid,
279+
new_cve,
280+
xtype,
281+
metadata=cve_do.metadata,
282+
name=cve_do.name,
283+
scat_quantity="cve",
284+
)
285+
return new_cve_do
286+
287+
288+
def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype="tth"):
250289
f"""
251290
compute and interpolate the cve for the given diffraction data and mud using the selected method
291+
252292
Parameters
253293
----------
254294
diffraction_data Diffraction_object
255295
the diffraction pattern
256296
mud float
257297
the mu*D of the diffraction object, where D is the diameter of the circle
298+
xtype str
299+
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
258300
method str
259301
the method used to calculate cve, must be one of {* CVE_METHODS, }
260302
@@ -265,21 +307,20 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
265307

266308
cve_function = _cve_method(method)
267309
cve_do_on_global_tth = cve_function(diffraction_data, mud)
268-
global_tth = cve_do_on_global_tth.on_tth[0]
269-
cve_on_global_tth = cve_do_on_global_tth.on_tth[1]
270-
orig_grid = diffraction_data.on_tth[0]
271-
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth)
310+
cve_do_on_global_xtype = interpolate_to_xtype_grid(cve_do_on_global_tth, xtype)
311+
orig_grid = diffraction_data.on_xtype(xtype)[0]
312+
global_xtype = cve_do_on_global_xtype.on_xtype(xtype)[0]
313+
cve_on_global_xtype = cve_do_on_global_xtype.on_xtype(xtype)[1]
314+
newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype)
272315
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
273316
cve_do.insert_scattering_quantity(
274317
orig_grid,
275318
newcve,
276-
"tth",
319+
xtype,
277320
metadata=diffraction_data.metadata,
278321
name=f"absorption correction, cve, for {diffraction_data.name}",
279-
wavelength=diffraction_data.wavelength,
280322
scat_quantity="cve",
281323
)
282-
283324
return cve_do
284325

285326

tests/test_functions.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
import numpy as np
44
import pytest
55

6-
from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve
7-
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
6+
from diffpy.labpdfproc.functions import (
7+
CVE_METHODS,
8+
Gridded_circle,
9+
apply_corr,
10+
compute_cve,
11+
interpolate_to_xtype_grid,
12+
)
13+
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
814

915
params1 = [
1016
([0.5, 3, 1], {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}),
@@ -58,19 +64,62 @@ def test_set_muls_at_angle(inputs, expected):
5864
assert actual_muls_sorted == pytest.approx(expected_muls_sorted, rel=1e-4, abs=1e-6)
5965

6066

61-
def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
67+
def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"):
6268
test_do = Diffraction_object(wavelength=1.54)
6369
test_do.insert_scattering_quantity(
6470
xarray,
6571
yarray,
66-
"tth",
72+
xtype,
6773
scat_quantity=scat_quantity,
6874
name=name,
6975
metadata={"thing1": 1, "thing2": "thing2"},
7076
)
7177
return test_do
7278

7379

80+
params4 = [
81+
([np.array([30, 60, 90]), np.array([1, 2, 3]), "tth"], [np.array([30, 60, 90]), np.array([1, 2, 3]), "tth"]),
82+
(
83+
[np.array([30, 60, 90]), np.array([1, 2, 3]), "q"],
84+
[np.array([2.11195, 4.07999, 5.76998]), np.array([1, 1, 1]), "q"],
85+
),
86+
]
87+
88+
89+
@pytest.mark.parametrize("inputs, expected", params4)
90+
def test_interpolate_xtype(inputs, expected, mocker):
91+
expected_cve_do = _instantiate_test_do(
92+
expected[0],
93+
expected[1],
94+
xtype=expected[2],
95+
name="absorption correction, cve, for test",
96+
scat_quantity="cve",
97+
)
98+
input_cve_do = _instantiate_test_do(
99+
inputs[0],
100+
inputs[1],
101+
xtype="tth",
102+
name="absorption correction, cve, for test",
103+
scat_quantity="cve",
104+
)
105+
actual_cve_do = interpolate_to_xtype_grid(input_cve_do, xtype=inputs[2])
106+
assert actual_cve_do == expected_cve_do
107+
108+
109+
def test_interpolate_xtype_bad():
110+
input_cve_do = _instantiate_test_do(
111+
np.array([30, 60, 90]),
112+
np.array([1, 2, 3]),
113+
xtype="tth",
114+
name="absorption correction, cve, for test",
115+
scat_quantity="cve",
116+
)
117+
with pytest.raises(
118+
ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.")
119+
):
120+
interpolate_to_xtype_grid(input_cve_do, xtype="invalid")
121+
122+
74123
def test_compute_cve(mocker):
75124
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
76125
expected_cve = np.array([0.5, 0.5, 0.5])
@@ -92,7 +141,7 @@ def test_compute_cve(mocker):
92141
[7, "polynomial_interpolation"],
93142
[
94143
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
95-
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
144+
f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }."
96145
],
97146
),
98147
([1, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]),

0 commit comments

Comments
 (0)