Skip to content

Commit 4bbe4a2

Browse files
add method to argparser and correct functions and docstrings
1 parent 956408d commit 4bbe4a2

File tree

3 files changed

+30
-31
lines changed

3 files changed

+30
-31
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
RADIUS_MM = 1
1111
N_POINTS_ON_DIAMETER = 300
1212
TTH_GRID = np.arange(1, 180.1, 0.1)
13+
CVE_METHODS = ["brute_force", "polynomial_interpolation"]
1314

1415
# pre-computed datasets for fast calculation
1516
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6]
@@ -172,15 +173,10 @@ def get_path_length(self, grid_point, angle):
172173
return total_distance, primary_distance, secondary_distance
173174

174175

175-
def _cve_brute_force(diffraction_data, mud, wavelength):
176+
def _cve_brute_force(mud):
176177
"""
177-
compute cve using brute-force method
178-
179-
it is computed as follows:
180-
We first resample data and absorption correction to a more reasonable grid,
181-
then calculate corresponding cve for the given mud in the resample grid
182-
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
183-
and finally interpolate cve to the original grid in diffraction_data.
178+
compute cve for the given mud on a global grid using the brute-force method
179+
assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1
184180
"""
185181

186182
mu_sample_invmm = mud / 2
@@ -197,15 +193,15 @@ def _cve_brute_force(diffraction_data, mud, wavelength):
197193
return cve
198194

199195

200-
def _cve_interp_polynomial(diffraction_data, mud, wavelength):
196+
def _cve_polynomial_interpolation(mud):
201197
"""
202198
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6)
203199
"""
204200

205201
if mud > 6 or mud < 0.5:
206202
raise ValueError(
207-
"mu*D is out of the acceptable range (0.5 to 6) for fast calculation. "
208-
"Please rerun with a value within this range or use -b to enable brute-force calculation. "
203+
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
204+
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
209205
)
210206
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
211207
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
@@ -215,19 +211,20 @@ def _cve_interp_polynomial(diffraction_data, mud, wavelength):
215211
return cve
216212

217213

218-
def _cve_method(diffraction_data, mud, wavelength, brute_force=False):
214+
def _compute_cve(method, mud):
219215
"""
220-
selects the appropriate CVE calculation method
216+
compute cve for the given mud on a global grid using the specified method
221217
"""
222-
if brute_force:
223-
return _cve_brute_force(diffraction_data, mud, wavelength)
224-
else:
225-
return _cve_interp_polynomial(diffraction_data, mud, wavelength)
218+
methods = {
219+
"brute_force": _cve_brute_force,
220+
"polynomial_interpolation": _cve_polynomial_interpolation,
221+
}
222+
return methods[method](mud)
226223

227224

228-
def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
225+
def interpolate_cve(diffraction_data, mud, wavelength, method="polynomial_interpolation"):
229226
"""
230-
compute the cve for given diffraction data, mud, and wavelength, using the selected method
227+
compute and interpolate the cve for the given diffraction data, mud, and wavelength, using the selected method
231228
232229
Parameters
233230
----------
@@ -237,14 +234,16 @@ def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
237234
the mu*D of the diffraction object, where D is the diameter of the circle
238235
wavelength float
239236
the wavelength of the diffraction object
237+
method str
238+
the method used to calculate cve
240239
241240
Returns
242241
-------
243242
the diffraction object with cve curves
244243
245244
"""
246245

247-
cve = _cve_method(diffraction_data, mud, wavelength, brute_force)
246+
cve = _compute_cve(method, mud)
248247
orig_grid = diffraction_data.on_tth[0]
249248
newcve = np.interp(orig_grid, TTH_GRID, cve)
250249
abdo = Diffraction_object(wavelength=wavelength)

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
from argparse import ArgumentParser
33

4-
from diffpy.labpdfproc.functions import apply_corr, compute_cve
4+
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, interpolate_cve
55
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
66
from diffpy.utils.parsers.loaddata import loadData
77
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
@@ -45,7 +45,7 @@ def get_args(override_cli_inputs=None):
4545
"-o",
4646
"--output-directory",
4747
help="The name of the output directory. If not specified "
48-
"then corrected files will be written to the current directory."
48+
"then corrected files will be written to the current directory. "
4949
"If the specified directory doesn't exist it will be created.",
5050
default=None,
5151
)
@@ -72,11 +72,11 @@ def get_args(override_cli_inputs=None):
7272
help="Outputs will not overwrite existing file unless --force is specified.",
7373
)
7474
p.add_argument(
75-
"-b",
76-
"--brute-force",
77-
action="store_true",
78-
help="The absorption correction will be computed using brute-force calculation "
79-
"if this flag is set. Default is using fast calculation. ",
75+
"-m",
76+
"--method",
77+
help=f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. "
78+
f"Default method is polynomial interpolation if not specified. ",
79+
default="polynomial_interpolation",
8080
)
8181
p.add_argument(
8282
"-u",
@@ -140,7 +140,7 @@ def main():
140140
metadata=load_metadata(args, filepath),
141141
)
142142

143-
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength, args.brute_force)
143+
absorption_correction = interpolate_cve(input_pattern, args.mud, args.wavelength, args.method)
144144
corrected_data = apply_corr(input_pattern, absorption_correction)
145145
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
146146
corrected_data.dump(f"{outfile}", xtype="tth")

src/diffpy/labpdfproc/tests/test_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, compute_cve
4+
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, interpolate_cve
55
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
66

77
params1 = [
@@ -69,13 +69,13 @@ def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
6969
return test_do
7070

7171

72-
def test_compute_cve(mocker):
72+
def test_interpolate_cve(mocker):
7373
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
7474
expected_cve = np.array([0.5, 0.5, 0.5])
7575
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
7676
mocker.patch("numpy.interp", return_value=expected_cve)
7777
input_pattern = _instantiate_test_do(xarray, yarray)
78-
actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54)
78+
actual_abdo = interpolate_cve(input_pattern, mud=1, wavelength=1.54)
7979
expected_abdo = _instantiate_test_do(
8080
xarray,
8181
expected_cve,

0 commit comments

Comments
 (0)