-
Notifications
You must be signed in to change notification settings - Fork 11
add fast_cve to labpdfprocapp.py and some other fixes #87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
329a895
6954a29
97bcc00
2487ae6
7d0d574
b7e037e
a450dda
fc4d876
956408d
4bbe4a2
b91b6ec
e7d3d05
bb28204
3c6b128
377607a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,23 @@ | ||
import math | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from scipy.interpolate import interp1d | ||
|
||
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object | ||
|
||
RADIUS_MM = 1 | ||
N_POINTS_ON_DIAMETER = 300 | ||
TTH_GRID = np.arange(1, 141, 1) | ||
TTH_GRID = np.arange(1, 180.1, 0.1) | ||
CVE_METHODS = ["brute_force", "polynomial_interpolation"] | ||
|
||
# pre-computed datasets for polynomial interpolation (fast calculation) | ||
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] | ||
CWD = Path(__file__).parent.resolve() | ||
MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy") | ||
COEFFICIENT_LIST = np.array(pd.read_csv(CWD / "data" / "coefficient_list.csv", header=None)) | ||
INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST] | ||
|
||
|
||
class Gridded_circle: | ||
|
@@ -27,16 +38,6 @@ def _get_grid_points(self): | |
self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2} | ||
self.total_points_in_grid = len(self.grid) | ||
|
||
# def get_coordinate_index(self, coordinate): # I think we probably dont need this function? | ||
# count = 0 | ||
# for i, target in enumerate(self.grid): | ||
# if coordinate == target: | ||
# return i | ||
# else: | ||
# count += 1 | ||
# if count >= len(self.grid): | ||
# raise IndexError(f"WARNING: no coordinate {coordinate} found in coordinates list") | ||
|
||
def set_distances_at_angle(self, angle): | ||
""" | ||
given an angle, set the distances from the grid points to the entry and exit coordinates | ||
|
@@ -172,28 +173,10 @@ def get_path_length(self, grid_point, angle): | |
return total_distance, primary_distance, secondary_distance | ||
|
||
|
||
def compute_cve(diffraction_data, mud, wavelength): | ||
def _cve_brute_force(diffraction_data, mud): | ||
""" | ||
compute the cve for given diffraction data, mud and wavelength | ||
|
||
Parameters | ||
---------- | ||
diffraction_data Diffraction_object | ||
the diffraction pattern | ||
mud float | ||
the mu*D of the diffraction object, where D is the diameter of the circle | ||
wavelength float | ||
the wavelength of the diffraction object | ||
|
||
Returns | ||
------- | ||
the diffraction object with cve curves | ||
|
||
it is computed as follows: | ||
We first resample data and absorption correction to a more reasonable grid, | ||
then calculate corresponding cve for the given mud in the resample grid | ||
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2), | ||
and finally interpolate cve to the original grid in diffraction_data. | ||
compute cve for the given mud on a global grid using the brute-force method | ||
assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1 | ||
""" | ||
|
||
mu_sample_invmm = mud / 2 | ||
|
@@ -208,9 +191,87 @@ def compute_cve(diffraction_data, mud, wavelength): | |
muls = np.array(muls) / abs_correction.total_points_in_grid | ||
cve = 1 / muls | ||
|
||
abdo = Diffraction_object(wavelength=diffraction_data.wavelength) | ||
abdo.insert_scattering_quantity( | ||
TTH_GRID, | ||
cve, | ||
"tth", | ||
metadata=diffraction_data.metadata, | ||
name=f"absorption correction, cve, for {diffraction_data.name}", | ||
wavelength=diffraction_data.wavelength, | ||
scat_quantity="cve", | ||
) | ||
return abdo | ||
|
||
|
||
def _cve_polynomial_interpolation(diffraction_data, mud): | ||
""" | ||
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6) | ||
""" | ||
|
||
if mud > 6 or mud < 0.5: | ||
raise ValueError( | ||
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " | ||
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }." | ||
) | ||
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [ | ||
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS | ||
] | ||
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e) | ||
cve = 1 / muls | ||
|
||
abdo = Diffraction_object(wavelength=diffraction_data.wavelength) | ||
abdo.insert_scattering_quantity( | ||
TTH_GRID, | ||
cve, | ||
"tth", | ||
metadata=diffraction_data.metadata, | ||
name=f"absorption correction, cve, for {diffraction_data.name}", | ||
wavelength=diffraction_data.wavelength, | ||
scat_quantity="cve", | ||
) | ||
return abdo | ||
|
||
|
||
def _cve_method(method): | ||
""" | ||
retrieve the cve computation function for the given method | ||
""" | ||
methods = { | ||
"brute_force": _cve_brute_force, | ||
"polynomial_interpolation": _cve_polynomial_interpolation, | ||
} | ||
if method not in CVE_METHODS: | ||
raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.") | ||
return methods[method] | ||
|
||
|
||
def compute_cve(diffraction_data, mud, method="polynomial_interpolation"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good practice here is to make the default as
|
||
""" | ||
compute and interpolate the cve for the given diffraction data and mud using the selected method | ||
|
||
Parameters | ||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
---------- | ||
diffraction_data Diffraction_object | ||
the diffraction pattern | ||
mud float | ||
the mu*D of the diffraction object, where D is the diameter of the circle | ||
method str | ||
the method used to calculate cve | ||
yucongalicechen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Returns | ||
------- | ||
the diffraction object with cve curves | ||
|
||
""" | ||
|
||
cve_function = _cve_method(method) | ||
abdo_on_global_tth = cve_function(diffraction_data, mud) | ||
global_tth = abdo_on_global_tth.on_tth[0] | ||
cve_on_global_tth = abdo_on_global_tth.on_tth[1] | ||
orig_grid = diffraction_data.on_tth[0] | ||
newcve = np.interp(orig_grid, TTH_GRID, cve) | ||
abdo = Diffraction_object(wavelength=wavelength) | ||
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth) | ||
abdo = Diffraction_object(wavelength=diffraction_data.wavelength) | ||
abdo.insert_scattering_quantity( | ||
orig_grid, | ||
newcve, | ||
|
@@ -220,7 +281,6 @@ def compute_cve(diffraction_data, mud, wavelength): | |
wavelength=diffraction_data.wavelength, | ||
scat_quantity="cve", | ||
) | ||
|
||
return abdo | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import sys | ||
from argparse import ArgumentParser | ||
|
||
from diffpy.labpdfproc.functions import apply_corr, compute_cve | ||
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve | ||
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args | ||
from diffpy.utils.parsers.loaddata import loadData | ||
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object | ||
|
@@ -45,7 +45,7 @@ def get_args(override_cli_inputs=None): | |
"-o", | ||
"--output-directory", | ||
help="The name of the output directory. If not specified " | ||
"then corrected files will be written to the current directory." | ||
"then corrected files will be written to the current directory. " | ||
"If the specified directory doesn't exist it will be created.", | ||
default=None, | ||
) | ||
|
@@ -64,14 +64,20 @@ def get_args(override_cli_inputs=None): | |
action="store_true", | ||
help="The absorption correction will be output to a file if this " | ||
"flag is set. Default is that it is not output.", | ||
default="tth", | ||
) | ||
p.add_argument( | ||
"-f", | ||
"--force-overwrite", | ||
action="store_true", | ||
help="Outputs will not overwrite existing file unless --force is specified.", | ||
) | ||
p.add_argument( | ||
"-m", | ||
"--method", | ||
help=f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. " | ||
f"Default method is polynomial interpolation if not specified. ", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there an indent problem here? |
||
default="polynomial_interpolation", | ||
) | ||
p.add_argument( | ||
"-u", | ||
"--user-metadata", | ||
|
@@ -109,8 +115,8 @@ def main(): | |
for filepath in args.input_paths: | ||
outfilestem = filepath.stem + "_corrected" | ||
corrfilestem = filepath.stem + "_cve" | ||
outfile = args.output_directory / (outfilestem + ".chi") | ||
corrfile = args.output_directory / (corrfilestem + ".chi") | ||
outfile = args.output_directory / (outfilestem + ".xy") | ||
yucongalicechen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
corrfile = args.output_directory / (corrfilestem + ".xy") | ||
|
||
if outfile.exists() and not args.force_overwrite: | ||
sys.exit( | ||
|
@@ -134,7 +140,7 @@ def main(): | |
metadata=load_metadata(args, filepath), | ||
) | ||
|
||
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength) | ||
absorption_correction = compute_cve(input_pattern, args.mud, args.method) | ||
corrected_data = apply_corr(input_pattern, absorption_correction) | ||
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" | ||
corrected_data.dump(f"{outfile}", xtype="tth") | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -312,6 +312,7 @@ def test_load_metadata(mocker, user_filesystem): | |
"wavelength": 0.71, | ||
"output_directory": str(Path.cwd().resolve()), | ||
"xtype": "tth", | ||
"method": "polynomial_interpolation", | ||
"key": "value", | ||
"username": "cli_username", | ||
"email": "[email protected]", | ||
|
Uh oh!
There was an error while loading. Please reload this page.