Skip to content

Commit ad46e97

Browse files
merged updates from wavelength
2 parents 0d5492b + 0e2ab8c commit ad46e97

File tree

3 files changed

+74
-6
lines changed

3 files changed

+74
-6
lines changed

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
from diffpy.labpdfproc.functions import apply_corr, compute_cve
6-
from diffpy.labpdfproc.tools import set_output_directory
6+
from diffpy.labpdfproc.tools import set_output_directory, set_wavelength
77
from diffpy.utils.parsers.loaddata import loadData
88
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
99

@@ -64,9 +64,10 @@ def get_args():
6464

6565
def main():
6666
args = get_args()
67-
wavelength = WAVELENGTHS[args.anode_type]
68-
filepath = Path(args.input_file)
6967
args.output_directory = set_output_directory(args)
68+
args.wavelength = set_wavelength(args)
69+
70+
filepath = Path(args.input_file)
7071
outfilestem = filepath.stem + "_corrected"
7172
corrfilestem = filepath.stem + "_cve"
7273
outfile = args.output_directory / (outfilestem + ".chi")
@@ -83,7 +84,7 @@ def main():
8384
f"exists. Please rerun specifying -f if you want to overwrite it"
8485
)
8586

86-
input_pattern = Diffraction_object(wavelength=wavelength)
87+
input_pattern = Diffraction_object(wavelength=args.wavelength)
8788
xarray, yarray = loadData(args.input_file, unpack=True)
8889
input_pattern.insert_scattering_quantity(
8990
xarray,
@@ -94,7 +95,7 @@ def main():
9495
metadata={"muD": args.mud, "anode_type": args.anode_type},
9596
)
9697

97-
absorption_correction = compute_cve(input_pattern, args.mud, wavelength)
98+
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
9899
corrected_data = apply_corr(input_pattern, absorption_correction)
99100
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
100101
corrected_data.dump(f"{outfile}", xtype="tth")

src/diffpy/labpdfproc/tests/test_tools.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from diffpy.labpdfproc.tools import set_output_directory
6+
from diffpy.labpdfproc.tools import set_output_directory, set_wavelength
77

88
params1 = [
99
([None], [Path.cwd().resolve()]),
@@ -35,3 +35,33 @@ def test_set_output_directory_bad():
3535
actual_args.output_directory = set_output_directory(actual_args)
3636
assert Path(actual_args.output_directory).exists()
3737
assert not Path(actual_args.output_directory).is_dir()
38+
39+
40+
params2 = [
41+
([None, None], [0.71]),
42+
([None, "Ag"], [0.59]),
43+
([0.25, "Ag"], [0.25]),
44+
([0.25, None], [0.25]),
45+
]
46+
47+
48+
@pytest.mark.parametrize("inputs, expected", params2)
49+
def test_set_wavelength(inputs, expected):
50+
expected_wavelength = expected[0]
51+
actual_args = argparse.Namespace(wavelength=inputs[0], anode_type=inputs[1])
52+
actual_wavelength = set_wavelength(actual_args)
53+
assert actual_wavelength == expected_wavelength
54+
55+
56+
params3 = [
57+
([None, "invalid"]),
58+
([0, None]),
59+
([-1, "Mo"]),
60+
]
61+
62+
63+
@pytest.mark.parametrize("inputs", params3)
64+
def test_set_wavelength_bad(inputs):
65+
with pytest.raises(ValueError):
66+
actual_args = argparse.Namespace(wavelength=inputs[0], anode_type=inputs[1])
67+
actual_args.wavelength = set_wavelength(actual_args)

src/diffpy/labpdfproc/tools.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from pathlib import Path
22

3+
WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
4+
known_sources = [key for key in WAVELENGTHS.keys()]
5+
36

47
def set_output_directory(args):
58
"""
@@ -23,3 +26,37 @@ def set_output_directory(args):
2326
output_dir = Path(args.output_directory).resolve() if args.output_directory else Path.cwd().resolve()
2427
output_dir.mkdir(parents=True, exist_ok=True)
2528
return output_dir
29+
30+
31+
def set_wavelength(args):
32+
"""
33+
Set the wavelength based on the given input arguments
34+
35+
Parameters
36+
----------
37+
args argparse.Namespace
38+
the arguments from the parser
39+
40+
Returns
41+
-------
42+
float: the wavelength value
43+
44+
we raise an ValueError if the input wavelength is non-positive
45+
or if the input anode_type is not one of the known sources
46+
47+
"""
48+
if args.wavelength is not None and args.wavelength <= 0:
49+
raise ValueError("Please rerun the program specifying a positive float number.")
50+
if not args.wavelength and args.anode_type and args.anode_type not in WAVELENGTHS:
51+
raise ValueError(
52+
f"Invalid anode type {args.anode_type}. "
53+
f"Please rerun the program to either specify a wavelength as a positive float number "
54+
f"or specify anode_type as one of {known_sources}."
55+
)
56+
57+
if args.wavelength:
58+
return args.wavelength
59+
elif args.anode_type:
60+
return WAVELENGTHS[args.anode_type]
61+
else:
62+
return WAVELENGTHS["Mo"]

0 commit comments

Comments
 (0)