Skip to content

Commit 3ec3c4d

Browse files
authored
Merge pull request #40 from aajayi-21/mainfunction
stretchednmfapp.py
2 parents c0be760 + 3a91d7e commit 3ec3c4d

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

diffpy/snmf/stretchednmfapp.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
import argparse
3+
from pathlib import Path
4+
from diffpy.snmf.subroutines import lift_data, initialize_components
5+
from diffpy.snmf.containers import ComponentSignal
6+
from diffpy.snmf.io import load_input_signals, initialize_variables
7+
8+
ALLOWED_DATA_TYPES = ['powder_diffraction', 'pd', 'pair_distribution_function', 'pdf']
9+
10+
11+
def create_parser():
12+
parser = argparse.ArgumentParser(
13+
prog="stretched_nmf",
14+
description="Stretched Nonnegative Matrix Factorization"
15+
)
16+
parser.add_argument('-i', '--input-directory', type=str, default=None,
17+
help="Directory containing experimental data. Defaults to current working directory.")
18+
parser.add_argument('-o', '--output-directory', type=str,
19+
help="The directory where the results will be written. Defaults to '<input_directory>/snmf_results'.")
20+
parser.add_argument('t', '--data-type', type=str, default=None, choices=ALLOWED_DATA_TYPES,
21+
help="The type of the experimental data.")
22+
parser.add_argument('-l', '--lift-factor', type=float, default=1,
23+
help="The lifting factor. Data will be lifted by lifted_data = data + abs(min(data) * lift). Default is 1.")
24+
parser.add_argument('number-of-components', type=int,
25+
help="The number of component signals for the NMF decomposition. Must be an integer greater than 0")
26+
parser.add_argument('-v', '--version', action='version', help='Print the software version number')
27+
args = parser.parse_args()
28+
return args
29+
30+
31+
def main():
32+
args = create_parser()
33+
if args.input_directory is None:
34+
args.input_directory = Path.cwd()
35+
grid, input_data = load_input_signals(args.input_directory)
36+
lifted_input_data = lift_data(input_data, args.lift_factor)
37+
variables = initialize_variables(lifted_input_data, args.number_of_components, args.data_type)
38+
components = initialize_components(variables['number_of_components'], variables['number_of_signals'], grid)
39+
return components

0 commit comments

Comments
 (0)