Skip to content

Commit eb5b771

Browse files
fix: update diffraction objects in functions
1 parent f9a9eb1 commit eb5b771

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import pandas as pd
66
from scipy.interpolate import interp1d
77

8-
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
8+
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
99

1010
RADIUS_MM = 1
1111
N_POINTS_ON_DIAMETER = 300
1212
TTH_GRID = np.arange(1, 180.1, 0.1)
13+
TTH_GRID = np.round(TTH_GRID, 1)
1314
CVE_METHODS = ["brute_force", "polynomial_interpolation"]
1415

1516
# pre-computed datasets for polynomial interpolation (fast calculation)
@@ -191,14 +192,14 @@ def _cve_brute_force(diffraction_data, mud):
191192
muls = np.array(muls) / abs_correction.total_points_in_grid
192193
cve = 1 / muls
193194

194-
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
195-
cve_do.insert_scattering_quantity(
196-
TTH_GRID,
197-
cve,
198-
"tth",
199-
metadata=diffraction_data.metadata,
200-
name=f"absorption correction, cve, for {diffraction_data.name}",
195+
cve_do = DiffractionObject(
196+
xarray=TTH_GRID,
197+
yarray=cve,
198+
xtype="tth",
199+
wavelength=diffraction_data.wavelength,
201200
scat_quantity="cve",
201+
name=f"absorption correction, cve, for {diffraction_data.name}",
202+
metadata=diffraction_data.metadata,
202203
)
203204
return cve_do
204205

@@ -211,22 +212,22 @@ def _cve_polynomial_interpolation(diffraction_data, mud):
211212
if mud > 6 or mud < 0.5:
212213
raise ValueError(
213214
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
214-
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
215+
f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }."
215216
)
216217
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
217218
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
218219
]
219220
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
220221
cve = 1 / muls
221222

222-
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
223-
cve_do.insert_scattering_quantity(
224-
TTH_GRID,
225-
cve,
226-
"tth",
227-
metadata=diffraction_data.metadata,
228-
name=f"absorption correction, cve, for {diffraction_data.name}",
223+
cve_do = DiffractionObject(
224+
xarray=TTH_GRID,
225+
yarray=cve,
226+
xtype="tth",
227+
wavelength=diffraction_data.wavelength,
229228
scat_quantity="cve",
229+
name=f"absorption correction, cve, for {diffraction_data.name}",
230+
metadata=diffraction_data.metadata,
230231
)
231232
return cve_do
232233

@@ -257,7 +258,7 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype=
257258
xtype str
258259
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
259260
method str
260-
the method used to calculate cve, must be one of {* CVE_METHODS, }
261+
the method used to calculate cve, must be one of {*CVE_METHODS, }
261262
262263
Returns
263264
-------
@@ -270,14 +271,14 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype=
270271
global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0]
271272
cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1]
272273
newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype)
273-
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
274-
cve_do.insert_scattering_quantity(
275-
orig_grid,
276-
newcve,
277-
xtype,
278-
metadata=diffraction_data.metadata,
279-
name=f"absorption correction, cve, for {diffraction_data.name}",
274+
cve_do = DiffractionObject(
275+
xarray=orig_grid,
276+
yarray=newcve,
277+
xtype=xtype,
278+
wavelength=diffraction_data.wavelength,
280279
scat_quantity="cve",
280+
name=f"absorption correction, cve, for {diffraction_data.name}",
281+
metadata=diffraction_data.metadata,
281282
)
282283
return cve_do
283284

tests/test_functions.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve
7-
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
7+
from diffpy.utils.diffraction_objects import DiffractionObject
88

99
params1 = [
1010
([0.5, 3, 1], {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}),
@@ -59,11 +59,11 @@ def test_set_muls_at_angle(inputs, expected):
5959

6060

6161
def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"):
62-
test_do = Diffraction_object(wavelength=1.54)
63-
test_do.insert_scattering_quantity(
64-
xarray,
65-
yarray,
66-
xtype,
62+
test_do = DiffractionObject(
63+
xarray=xarray,
64+
yarray=yarray,
65+
xtype=xtype,
66+
wavelength=1.54,
6767
scat_quantity=scat_quantity,
6868
name=name,
6969
metadata={"thing1": 1, "thing2": "thing2"},
@@ -81,14 +81,13 @@ def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity
8181
def test_compute_cve(inputs, expected, mocker):
8282
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
8383
expected_cve = np.array([0.5, 0.5, 0.5])
84-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
8584
mocker.patch("numpy.interp", return_value=expected_cve)
8685
input_pattern = _instantiate_test_do(xarray, yarray)
8786
actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0])
8887
expected_cve_do = _instantiate_test_do(
89-
expected[0],
90-
expected[1],
91-
expected[2],
88+
xarray=expected[0],
89+
yarray=expected[1],
90+
xtype=expected[2],
9291
name="absorption correction, cve, for test",
9392
scat_quantity="cve",
9493
)
@@ -126,8 +125,8 @@ def test_apply_corr(mocker):
126125
mocker.patch("numpy.interp", return_value=expected_cve)
127126
input_pattern = _instantiate_test_do(xarray, yarray)
128127
absorption_correction = _instantiate_test_do(
129-
xarray,
130-
expected_cve,
128+
xarray=xarray,
129+
yarray=expected_cve,
131130
name="absorption correction, cve, for test",
132131
scat_quantity="cve",
133132
)

0 commit comments

Comments
 (0)