Skip to content

Commit ad2c50e

Browse files
move changes from previous PR
1 parent 5dc5ae5 commit ad2c50e

File tree

6 files changed

+104
-156
lines changed

6 files changed

+104
-156
lines changed

src/diffpy/labpdfproc/fast_cve.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

src/diffpy/labpdfproc/functions.py

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
import math
2+
from pathlib import Path
23

34
import numpy as np
5+
import pandas as pd
6+
from scipy.interpolate import interp1d
47

58
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
69

710
RADIUS_MM = 1
811
N_POINTS_ON_DIAMETER = 300
9-
TTH_GRID = np.arange(1, 141, 1)
12+
TTH_GRID = np.arange(1, 180.1, 0.1)
13+
CVE_METHODS = ["brute_force", "polynomial_interpolation"]
14+
15+
# pre-computed datasets for polynomial interpolation (fast calculation)
16+
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6]
17+
CWD = Path(__file__).parent.resolve()
18+
MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy")
19+
COEFFICIENT_LIST = np.array(pd.read_csv(CWD / "data" / "coefficient_list.csv", header=None))
20+
INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST]
1021

1122

1223
class Gridded_circle:
@@ -172,28 +183,10 @@ def get_path_length(self, grid_point, angle):
172183
return total_distance, primary_distance, secondary_distance
173184

174185

175-
def compute_cve(diffraction_data, mud, wavelength):
186+
def _cve_brute_force(diffraction_data, mud):
176187
"""
177-
compute the cve for given diffraction data, mud and wavelength
178-
179-
Parameters
180-
----------
181-
diffraction_data Diffraction_object
182-
the diffraction pattern
183-
mud float
184-
the mu*D of the diffraction object, where D is the diameter of the circle
185-
wavelength float
186-
the wavelength of the diffraction object
187-
188-
Returns
189-
-------
190-
the diffraction object with cve curves
191-
192-
it is computed as follows:
193-
We first resample data and absorption correction to a more reasonable grid,
194-
then calculate corresponding cve for the given mud in the resample grid
195-
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
196-
and finally interpolate cve to the original grid in diffraction_data.
188+
compute cve for the given mud on a global grid using the brute-force method
189+
assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1
197190
"""
198191

199192
mu_sample_invmm = mud / 2
@@ -208,9 +201,85 @@ def compute_cve(diffraction_data, mud, wavelength):
208201
muls = np.array(muls) / abs_correction.total_points_in_grid
209202
cve = 1 / muls
210203

204+
abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
205+
abdo.insert_scattering_quantity(
206+
TTH_GRID,
207+
cve,
208+
"tth",
209+
metadata=diffraction_data.metadata,
210+
name=f"absorption correction, cve, for {diffraction_data.name}",
211+
wavelength=diffraction_data.wavelength,
212+
scat_quantity="cve",
213+
)
214+
return abdo
215+
216+
217+
def _cve_polynomial_interpolation(diffraction_data, mud):
218+
"""
219+
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6)
220+
"""
221+
222+
if mud > 6 or mud < 0.5:
223+
raise ValueError(
224+
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
225+
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
226+
)
227+
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
228+
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
229+
]
230+
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
231+
cve = 1 / muls
232+
233+
abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
234+
abdo.insert_scattering_quantity(
235+
TTH_GRID,
236+
cve,
237+
"tth",
238+
metadata=diffraction_data.metadata,
239+
name=f"absorption correction, cve, for {diffraction_data.name}",
240+
wavelength=diffraction_data.wavelength,
241+
scat_quantity="cve",
242+
)
243+
return abdo
244+
245+
246+
def _cve_method(method):
247+
"""
248+
retrieve the cve computation function for the given method
249+
"""
250+
methods = {
251+
"brute_force": _cve_brute_force,
252+
"polynomial_interpolation": _cve_polynomial_interpolation,
253+
}
254+
if method not in CVE_METHODS:
255+
raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.")
256+
return methods[method]
257+
258+
259+
def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
260+
"""
261+
compute and interpolate the cve for the given diffraction data and mud using the selected method
262+
Parameters
263+
----------
264+
diffraction_data Diffraction_object
265+
the diffraction pattern
266+
mud float
267+
the mu*D of the diffraction object, where D is the diameter of the circle
268+
method str
269+
the method used to calculate cve
270+
271+
Returns
272+
-------
273+
the diffraction object with cve curves
274+
"""
275+
276+
cve_function = _cve_method(method)
277+
abdo_on_global_tth = cve_function(diffraction_data, mud)
278+
global_tth = abdo_on_global_tth.on_tth[0]
279+
cve_on_global_tth = abdo_on_global_tth.on_tth[1]
211280
orig_grid = diffraction_data.on_tth[0]
212-
newcve = np.interp(orig_grid, TTH_GRID, cve)
213-
abdo = Diffraction_object(wavelength=wavelength)
281+
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth)
282+
abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
214283
abdo.insert_scattering_quantity(
215284
orig_grid,
216285
newcve,

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
from argparse import ArgumentParser
33

4-
from diffpy.labpdfproc.functions import apply_corr, compute_cve
4+
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve
55
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
66
from diffpy.utils.parsers.loaddata import loadData
77
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
@@ -72,6 +72,13 @@ def get_args(override_cli_inputs=None):
7272
action="store_true",
7373
help="Outputs will not overwrite existing file unless --force is specified.",
7474
)
75+
p.add_argument(
76+
"-m",
77+
"--method",
78+
help=f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. "
79+
f"Default method is polynomial interpolation if not specified. ",
80+
default="polynomial_interpolation",
81+
)
7582
p.add_argument(
7683
"-u",
7784
"--user-metadata",
@@ -134,7 +141,7 @@ def main():
134141
metadata=load_metadata(args, filepath),
135142
)
136143

137-
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
144+
absorption_correction = compute_cve(input_pattern, args.mud, args.method)
138145
corrected_data = apply_corr(input_pattern, absorption_correction)
139146
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
140147
corrected_data.dump(f"{outfile}", xtype="tth")

src/diffpy/labpdfproc/tests/test_fast_cve.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/diffpy/labpdfproc/tests/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_compute_cve(mocker):
7575
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
7676
mocker.patch("numpy.interp", return_value=expected_cve)
7777
input_pattern = _instantiate_test_do(xarray, yarray)
78-
actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54)
78+
actual_abdo = compute_cve(input_pattern, mud=1)
7979
expected_abdo = _instantiate_test_do(
8080
xarray,
8181
expected_cve,

src/diffpy/labpdfproc/tests/test_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def test_load_metadata(mocker, user_filesystem):
312312
"wavelength": 0.71,
313313
"output_directory": str(Path.cwd().resolve()),
314314
"xtype": "tth",
315+
"method": "polynomial_interpolation",
315316
"key": "value",
316317
"username": "cli_username",
317318
"email": "[email protected]",

0 commit comments

Comments
 (0)