Skip to content

Commit 956408d

Browse files
add private function _cve_method in compute_cve to select cve computers _cve_brute_force or _cve_interp_polynomial
1 parent fc4d876 commit 956408d

File tree

1 file changed

+55
-42
lines changed

1 file changed

+55
-42
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,6 @@ def _get_grid_points(self):
3737
self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2}
3838
self.total_points_in_grid = len(self.grid)
3939

40-
# def get_coordinate_index(self, coordinate):
41-
# count = 0
42-
# for i, target in enumerate(self.grid):
43-
# if coordinate == target:
44-
# return i
45-
# else:
46-
# count += 1
47-
# if count >= len(self.grid):
48-
# raise IndexError(f"WARNING: no coordinate {coordinate} found in coordinates list")
49-
5040
def set_distances_at_angle(self, angle):
5141
"""
5242
given an angle, set the distances from the grid points to the entry and exit coordinates
@@ -182,9 +172,62 @@ def get_path_length(self, grid_point, angle):
182172
return total_distance, primary_distance, secondary_distance
183173

184174

175+
def _cve_brute_force(diffraction_data, mud, wavelength):
176+
"""
177+
compute cve using brute-force method
178+
179+
it is computed as follows:
180+
We first resample data and absorption correction to a more reasonable grid,
181+
then calculate corresponding cve for the given mud in the resample grid
182+
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
183+
and finally interpolate cve to the original grid in diffraction_data.
184+
"""
185+
186+
mu_sample_invmm = mud / 2
187+
abs_correction = Gridded_circle(mu=mu_sample_invmm)
188+
distances, muls = [], []
189+
for angle in TTH_GRID:
190+
abs_correction.set_distances_at_angle(angle)
191+
abs_correction.set_muls_at_angle(angle)
192+
distances.append(sum(abs_correction.distances))
193+
muls.append(sum(abs_correction.muls))
194+
distances = np.array(distances) / abs_correction.total_points_in_grid
195+
muls = np.array(muls) / abs_correction.total_points_in_grid
196+
cve = 1 / muls
197+
return cve
198+
199+
200+
def _cve_interp_polynomial(diffraction_data, mud, wavelength):
201+
"""
202+
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6)
203+
"""
204+
205+
if mud > 6 or mud < 0.5:
206+
raise ValueError(
207+
"mu*D is out of the acceptable range (0.5 to 6) for fast calculation. "
208+
"Please rerun with a value within this range or use -b to enable brute-force calculation. "
209+
)
210+
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
211+
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
212+
]
213+
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
214+
cve = 1 / muls
215+
return cve
216+
217+
218+
def _cve_method(diffraction_data, mud, wavelength, brute_force=False):
219+
"""
220+
selects the appropriate CVE calculation method
221+
"""
222+
if brute_force:
223+
return _cve_brute_force(diffraction_data, mud, wavelength)
224+
else:
225+
return _cve_interp_polynomial(diffraction_data, mud, wavelength)
226+
227+
185228
def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
186229
"""
187-
compute the cve for given diffraction data, mud and wavelength, and a boolean to determine the way to compute
230+
compute the cve for given diffraction data, mud, and wavelength, using the selected method
188231
189232
Parameters
190233
----------
@@ -195,42 +238,13 @@ def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
195238
wavelength float
196239
the wavelength of the diffraction object
197240
198-
the brute-force method is computed as follows:
199-
We first resample data and absorption correction to a more reasonable grid,
200-
then calculate corresponding cve for the given mud in the resample grid
201-
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
202-
and finally interpolate cve to the original grid in diffraction_data.
203-
204241
Returns
205242
-------
206243
the diffraction object with cve curves
207244
208245
"""
209246

210-
if brute_force:
211-
mu_sample_invmm = mud / 2
212-
abs_correction = Gridded_circle(mu=mu_sample_invmm)
213-
distances, muls = [], []
214-
for angle in TTH_GRID:
215-
abs_correction.set_distances_at_angle(angle)
216-
abs_correction.set_muls_at_angle(angle)
217-
distances.append(sum(abs_correction.distances))
218-
muls.append(sum(abs_correction.muls))
219-
distances = np.array(distances) / abs_correction.total_points_in_grid
220-
muls = np.array(muls) / abs_correction.total_points_in_grid
221-
cve = 1 / muls
222-
else:
223-
if mud > 6 or mud < 0.5:
224-
raise ValueError(
225-
"mu*D is out of the acceptable range (0.5 to 6) for fast calculation. "
226-
"Please rerun with a value within this range or use -b to enable brute-force calculation. "
227-
)
228-
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
229-
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
230-
]
231-
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
232-
cve = 1 / muls
233-
247+
cve = _cve_method(diffraction_data, mud, wavelength, brute_force)
234248
orig_grid = diffraction_data.on_tth[0]
235249
newcve = np.interp(orig_grid, TTH_GRID, cve)
236250
abdo = Diffraction_object(wavelength=wavelength)
@@ -243,7 +257,6 @@ def compute_cve(diffraction_data, mud, wavelength, brute_force=False):
243257
wavelength=diffraction_data.wavelength,
244258
scat_quantity="cve",
245259
)
246-
247260
return abdo
248261

249262

0 commit comments

Comments
 (0)