Skip to content

Commit bb28204

Browse files
change compute cve function to call cve methods and interpolate tth
1 parent e7d3d05 commit bb28204

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,20 @@ def _cve_polynomial_interpolation(mud):
211211
return cve
212212

213213

214-
def _compute_cve(method, mud):
214+
def _cve_method(method):
215215
"""
216-
compute cve for the given mud on a global grid using the specified method
216+
retrieve the cve computation function for the given method
217217
"""
218218
methods = {
219219
"brute_force": _cve_brute_force,
220220
"polynomial_interpolation": _cve_polynomial_interpolation,
221221
}
222-
return methods[method](mud)
222+
if method not in CVE_METHODS:
223+
raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.")
224+
return methods[method]
223225

224226

225-
def interpolate_cve(diffraction_data, mud, wavelength, method="polynomial_interpolation"):
227+
def compute_cve(diffraction_data, mud, wavelength, method="polynomial_interpolation"):
226228
"""
227229
compute and interpolate the cve for the given diffraction data, mud, and wavelength, using the selected method
228230
@@ -243,7 +245,8 @@ def interpolate_cve(diffraction_data, mud, wavelength, method="polynomial_interp
243245
244246
"""
245247

246-
cve = _compute_cve(method, mud)
248+
cve_function = _cve_method(method)
249+
cve = cve_function(mud)
247250
orig_grid = diffraction_data.on_tth[0]
248251
newcve = np.interp(orig_grid, TTH_GRID, cve)
249252
abdo = Diffraction_object(wavelength=wavelength)

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
from argparse import ArgumentParser
33

4-
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, interpolate_cve
4+
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve
55
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
66
from diffpy.utils.parsers.loaddata import loadData
77
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
@@ -140,7 +140,7 @@ def main():
140140
metadata=load_metadata(args, filepath),
141141
)
142142

143-
absorption_correction = interpolate_cve(input_pattern, args.mud, args.wavelength, args.method)
143+
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength, args.method)
144144
corrected_data = apply_corr(input_pattern, absorption_correction)
145145
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
146146
corrected_data.dump(f"{outfile}", xtype="tth")

src/diffpy/labpdfproc/tests/test_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, interpolate_cve
4+
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, compute_cve
55
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
66

77
params1 = [
@@ -69,13 +69,13 @@ def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
6969
return test_do
7070

7171

72-
def test_interpolate_cve(mocker):
72+
def test_compute_cve(mocker):
7373
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
7474
expected_cve = np.array([0.5, 0.5, 0.5])
7575
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
7676
mocker.patch("numpy.interp", return_value=expected_cve)
7777
input_pattern = _instantiate_test_do(xarray, yarray)
78-
actual_abdo = interpolate_cve(input_pattern, mud=1, wavelength=1.54)
78+
actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54)
7979
expected_abdo = _instantiate_test_do(
8080
xarray,
8181
expected_cve,

0 commit comments

Comments
 (0)