Skip to content

add fast_cve to labpdfprocapp.py and some other fixes #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ template = "{tag}"
dev_template = "{tag}"
dirty_template = "{tag}"

[project.scripts]
labpdfproc = "diffpy.labpdfproc.labpdfprocapp:main"

[tool.setuptools.packages.find]
where = ["src"] # list of folders that contain the packages (["."] by default)
include = ["*"] # package names should match these glob patterns (["*"] by default)
Expand Down
81 changes: 0 additions & 81 deletions src/diffpy/labpdfproc/fast_cve.py

This file was deleted.

130 changes: 95 additions & 35 deletions src/diffpy/labpdfproc/functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import math
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object

RADIUS_MM = 1
N_POINTS_ON_DIAMETER = 300
TTH_GRID = np.arange(1, 141, 1)
TTH_GRID = np.arange(1, 180.1, 0.1)
CVE_METHODS = ["brute_force", "polynomial_interpolation"]

# pre-computed datasets for polynomial interpolation (fast calculation)
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6]
CWD = Path(__file__).parent.resolve()
MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy")
COEFFICIENT_LIST = np.array(pd.read_csv(CWD / "data" / "coefficient_list.csv", header=None))
INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST]


class Gridded_circle:
Expand All @@ -27,16 +38,6 @@ def _get_grid_points(self):
self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2}
self.total_points_in_grid = len(self.grid)

# def get_coordinate_index(self, coordinate): # I think we probably dont need this function?
# count = 0
# for i, target in enumerate(self.grid):
# if coordinate == target:
# return i
# else:
# count += 1
# if count >= len(self.grid):
# raise IndexError(f"WARNING: no coordinate {coordinate} found in coordinates list")

def set_distances_at_angle(self, angle):
"""
given an angle, set the distances from the grid points to the entry and exit coordinates
Expand Down Expand Up @@ -172,28 +173,10 @@ def get_path_length(self, grid_point, angle):
return total_distance, primary_distance, secondary_distance


def compute_cve(diffraction_data, mud, wavelength):
def _cve_brute_force(diffraction_data, mud):
"""
compute the cve for given diffraction data, mud and wavelength

Parameters
----------
diffraction_data Diffraction_object
the diffraction pattern
mud float
the mu*D of the diffraction object, where D is the diameter of the circle
wavelength float
the wavelength of the diffraction object

Returns
-------
the diffraction object with cve curves

it is computed as follows:
We first resample data and absorption correction to a more reasonable grid,
then calculate corresponding cve for the given mud in the resample grid
(since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2),
and finally interpolate cve to the original grid in diffraction_data.
compute cve for the given mud on a global grid using the brute-force method
assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1
"""

mu_sample_invmm = mud / 2
Expand All @@ -208,9 +191,87 @@ def compute_cve(diffraction_data, mud, wavelength):
muls = np.array(muls) / abs_correction.total_points_in_grid
cve = 1 / muls

abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return abdo


def _cve_polynomial_interpolation(diffraction_data, mud):
"""
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6)
"""

if mud > 6 or mud < 0.5:
raise ValueError(
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
)
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
]
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
cve = 1 / muls

abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return abdo


def _cve_method(method):
"""
retrieve the cve computation function for the given method
"""
methods = {
"brute_force": _cve_brute_force,
"polynomial_interpolation": _cve_polynomial_interpolation,
}
if method not in CVE_METHODS:
raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.")
return methods[method]


def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good practice here is to make the default as method=None then in the code do

If method is None:
     method = "polynomial_interpolation"

"""
compute and interpolate the cve for the given diffraction data and mud using the selected method

Parameters
----------
diffraction_data Diffraction_object
the diffraction pattern
mud float
the mu*D of the diffraction object, where D is the diameter of the circle
method str
the method used to calculate cve

Returns
-------
the diffraction object with cve curves

"""

cve_function = _cve_method(method)
abdo_on_global_tth = cve_function(diffraction_data, mud)
global_tth = abdo_on_global_tth.on_tth[0]
cve_on_global_tth = abdo_on_global_tth.on_tth[1]
orig_grid = diffraction_data.on_tth[0]
newcve = np.interp(orig_grid, TTH_GRID, cve)
abdo = Diffraction_object(wavelength=wavelength)
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth)
abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
orig_grid,
newcve,
Expand All @@ -220,7 +281,6 @@ def compute_cve(diffraction_data, mud, wavelength):
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)

return abdo


Expand Down
18 changes: 12 additions & 6 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from argparse import ArgumentParser

from diffpy.labpdfproc.functions import apply_corr, compute_cve
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_args(override_cli_inputs=None):
"-o",
"--output-directory",
help="The name of the output directory. If not specified "
"then corrected files will be written to the current directory."
"then corrected files will be written to the current directory. "
"If the specified directory doesn't exist it will be created.",
default=None,
)
Expand All @@ -64,14 +64,20 @@ def get_args(override_cli_inputs=None):
action="store_true",
help="The absorption correction will be output to a file if this "
"flag is set. Default is that it is not output.",
default="tth",
)
p.add_argument(
"-f",
"--force-overwrite",
action="store_true",
help="Outputs will not overwrite existing file unless --force is specified.",
)
p.add_argument(
"-m",
"--method",
help=f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. "
f"Default method is polynomial interpolation if not specified. ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an indent problem here?

default="polynomial_interpolation",
)
p.add_argument(
"-u",
"--user-metadata",
Expand Down Expand Up @@ -109,8 +115,8 @@ def main():
for filepath in args.input_paths:
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")
outfile = args.output_directory / (outfilestem + ".xy")
corrfile = args.output_directory / (corrfilestem + ".xy")

if outfile.exists() and not args.force_overwrite:
sys.exit(
Expand All @@ -134,7 +140,7 @@ def main():
metadata=load_metadata(args, filepath),
)

absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
absorption_correction = compute_cve(input_pattern, args.mud, args.method)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")
Expand Down
48 changes: 0 additions & 48 deletions src/diffpy/labpdfproc/tests/test_fast_cve.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/diffpy/labpdfproc/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_compute_cve(mocker):
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
mocker.patch("numpy.interp", return_value=expected_cve)
input_pattern = _instantiate_test_do(xarray, yarray)
actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54)
actual_abdo = compute_cve(input_pattern, mud=1)
expected_abdo = _instantiate_test_do(
xarray,
expected_cve,
Expand Down
1 change: 1 addition & 0 deletions src/diffpy/labpdfproc/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_load_metadata(mocker, user_filesystem):
"wavelength": 0.71,
"output_directory": str(Path.cwd().resolve()),
"xtype": "tth",
"method": "polynomial_interpolation",
"key": "value",
"username": "cli_username",
"email": "[email protected]",
Expand Down
Loading
Loading