Skip to content

Commit 7d0d574

Browse files
remove duplicated functions in fast cve
1 parent 2487ae6 commit 7d0d574

File tree

4 files changed

+39
-127
lines changed

4 files changed

+39
-127
lines changed

src/diffpy/labpdfproc/fast_cve.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,26 @@
44
import pandas as pd
55
from scipy.interpolate import interp1d
66

7-
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
8-
9-
TTH_GRID = np.arange(1, 180.1, 0.1)
7+
FAST_TTH_GRID = np.arange(1, 180.1, 0.1)
108
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6]
119
CWD = os.path.dirname(os.path.abspath(__file__))
1210
INVERSE_CVE_DATA = np.loadtxt(CWD + "/data/inverse_cve.xy")
1311
COEFFICIENT_LIST = np.array(pd.read_csv(CWD + "/data/coefficient_list.csv", header=None))
1412
INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST]
1513

1614

17-
def fast_compute_cve(diffraction_data, mud, wavelength):
15+
def fast_compute_cve(mud):
1816
"""
19-
use precomputed datasets to compute the cve for given diffraction data, mud and wavelength
17+
use precomputed datasets to compute the cve for given mud
2018
2119
Parameters
2220
----------
23-
diffraction_data Diffraction_object
24-
the diffraction pattern
2521
mud float
2622
the mu*D of the diffraction object, where D is the diameter of the circle
27-
wavelength float
28-
the wavelength of the diffraction object
2923
3024
Returns
3125
-------
32-
the diffraction object with cve curves
26+
the array of tth grid and the corresponding cve
3327
"""
3428

3529
coefficient_a, coefficient_b, coefficient_c, coefficient_d, coefficient_e = [
@@ -43,39 +37,4 @@ def fast_compute_cve(diffraction_data, mud, wavelength):
4337
+ coefficient_e
4438
)
4539
cve = 1 / np.array(inverse_cve)
46-
47-
orig_grid = diffraction_data.on_tth[0]
48-
newcve = np.interp(orig_grid, TTH_GRID, cve)
49-
abdo = Diffraction_object(wavelength=wavelength)
50-
abdo.insert_scattering_quantity(
51-
orig_grid,
52-
newcve,
53-
"tth",
54-
metadata=diffraction_data.metadata,
55-
name=f"absorption correction, cve, for {diffraction_data.name}",
56-
wavelength=diffraction_data.wavelength,
57-
scat_quantity="cve",
58-
)
59-
60-
return abdo
61-
62-
63-
def apply_fast_corr(diffraction_pattern, absorption_correction):
64-
"""
65-
Apply absorption correction to the given diffraction object modo with the correction diffraction object abdo
66-
67-
Parameters
68-
----------
69-
diffraction_pattern Diffraction_object
70-
the input diffraction object to which the cve will be applied
71-
absorption_correction Diffraction_object
72-
the diffraction object that contains the cve to be applied
73-
74-
Returns
75-
-------
76-
a corrected diffraction object with the correction applied through multiplication
77-
78-
"""
79-
80-
corrected_pattern = diffraction_pattern * absorption_correction
81-
return corrected_pattern
40+
return FAST_TTH_GRID, cve

src/diffpy/labpdfproc/functions.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from diffpy.labpdfproc.fast_cve import fast_compute_cve
56
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
67

78
RADIUS_MM = 1
@@ -27,7 +28,7 @@ def _get_grid_points(self):
2728
self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2}
2829
self.total_points_in_grid = len(self.grid)
2930

30-
# def get_coordinate_index(self, coordinate): # I think we probably dont need this function?
31+
# def get_coordinate_index(self, coordinate):
3132
# count = 0
3233
# for i, target in enumerate(self.grid):
3334
# if coordinate == target:
@@ -172,9 +173,9 @@ def get_path_length(self, grid_point, angle):
172173
return total_distance, primary_distance, secondary_distance
173174

174175

175-
def compute_cve(diffraction_data, mud, wavelength):
176+
def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
176177
"""
177-
compute the cve for given diffraction data, mud and wavelength
178+
compute the cve for given diffraction data, mud and wavelength, and a boolean to determine the way to compute
178179
179180
Parameters
180181
----------
@@ -185,31 +186,41 @@ def compute_cve(diffraction_data, mud, wavelength):
185186
wavelength float
186187
the wavelength of the diffraction object
187188
188-
Returns
189-
-------
190-
the diffraction object with cve curves
191-
192-
it is computed as follows:
189+
the brute-force method is computed as follows:
193190
We first resample data and absorption correction to a more reasonable grid,
194191
then calculate corresponding cve for the given mud in the resample grid
195192
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
196193
and finally interpolate cve to the original grid in diffraction_data.
194+
195+
Returns
196+
-------
197+
the diffraction object with cve curves
198+
197199
"""
198200

199-
mu_sample_invmm = mud / 2
200-
abs_correction = Gridded_circle(mu=mu_sample_invmm)
201-
distances, muls = [], []
202-
for angle in TTH_GRID:
203-
abs_correction.set_distances_at_angle(angle)
204-
abs_correction.set_muls_at_angle(angle)
205-
distances.append(sum(abs_correction.distances))
206-
muls.append(sum(abs_correction.muls))
207-
distances = np.array(distances) / abs_correction.total_points_in_grid
208-
muls = np.array(muls) / abs_correction.total_points_in_grid
209-
cve = 1 / muls
201+
if brute_force:
202+
tth_grid = TTH_GRID
203+
mu_sample_invmm = mud / 2
204+
abs_correction = Gridded_circle(mu=mu_sample_invmm)
205+
distances, muls = [], []
206+
for angle in TTH_GRID:
207+
abs_correction.set_distances_at_angle(angle)
208+
abs_correction.set_muls_at_angle(angle)
209+
distances.append(sum(abs_correction.distances))
210+
muls.append(sum(abs_correction.muls))
211+
distances = np.array(distances) / abs_correction.total_points_in_grid
212+
muls = np.array(muls) / abs_correction.total_points_in_grid
213+
cve = 1 / muls
214+
else:
215+
if mud > 6 or mud < 0.5:
216+
raise ValueError(
217+
"mu*D is out of the acceptable range (0.5 to 6) for fast calculation. "
218+
"Please rerun with a value within this range or use -b to enable brute-force calculation. "
219+
)
220+
tth_grid, cve = fast_compute_cve(mud)
210221

211222
orig_grid = diffraction_data.on_tth[0]
212-
newcve = np.interp(orig_grid, TTH_GRID, cve)
223+
newcve = np.interp(orig_grid, tth_grid, cve)
213224
abdo = Diffraction_object(wavelength=wavelength)
214225
abdo.insert_scattering_quantity(
215226
orig_grid,

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import sys
22
from argparse import ArgumentParser
33

4-
from diffpy.labpdfproc.fast_cve import apply_fast_corr, fast_compute_cve
54
from diffpy.labpdfproc.functions import apply_corr, compute_cve
65
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
76
from diffpy.utils.parsers.loaddata import loadData
@@ -141,20 +140,11 @@ def main():
141140
metadata=load_metadata(args, filepath),
142141
)
143142

144-
if args.brute_force:
145-
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
146-
corrected_data = apply_corr(input_pattern, absorption_correction)
147-
else:
148-
if args.mud > 6 or args.mud < 0.5:
149-
sys.exit(
150-
"mu*D is out of the acceptable range (0.5 to 6) for fast calculation. "
151-
"Please rerun with a value within this range or use -b enable brute-force calculation. "
152-
)
153-
absorption_correction = fast_compute_cve(input_pattern, args.mud, args.wavelength)
154-
corrected_data = apply_fast_corr(input_pattern, absorption_correction)
155-
143+
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength, args.brute_force)
144+
corrected_data = apply_corr(input_pattern, absorption_correction)
156145
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
157146
corrected_data.dump(f"{outfile}", xtype="tth")
147+
158148
if args.output_correction:
159149
absorption_correction.dump(f"{corrfile}", xtype="tth")
160150

src/diffpy/labpdfproc/tests/test_fast_cve.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)