|
| 1 | +import numpy as np |
| 2 | +import argparse |
| 3 | +from pathlib import Path |
| 4 | +from diffpy.snmf.subroutines import lift_data, initialize_components |
| 5 | +from diffpy.snmf.containers import ComponentSignal |
| 6 | +from diffpy.snmf.io import load_input_signals, initialize_variables |
| 7 | + |
| 8 | +ALLOWED_DATA_TYPES = ['powder_diffraction', 'pd', 'pair_distribution_function', 'pdf'] |
| 9 | + |
| 10 | + |
| 11 | +def create_parser(): |
| 12 | + parser = argparse.ArgumentParser( |
| 13 | + prog="stretched_nmf", |
| 14 | + description="Stretched Nonnegative Matrix Factorization" |
| 15 | + ) |
| 16 | + parser.add_argument('-i', '--input-directory', type=str, default=None, |
| 17 | + help="Directory containing experimental data. Defaults to current working directory.") |
| 18 | + parser.add_argument('-o', '--output-directory', type=str, |
| 19 | + help="The directory where the results will be written. Defaults to '<input_directory>/snmf_results'.") |
| 20 | + parser.add_argument('t', '--data-type', type=str, default=None, choices=ALLOWED_DATA_TYPES, |
| 21 | + help="The type of the experimental data.") |
| 22 | + parser.add_argument('-l', '--lift-factor', type=float, default=1, |
| 23 | + help="The lifting factor. Data will be lifted by lifted_data = data + abs(min(data) * lift). Default is 1.") |
| 24 | + parser.add_argument('number-of-components', type=int, |
| 25 | + help="The number of component signals for the NMF decomposition. Must be an integer greater than 0") |
| 26 | + parser.add_argument('-v', '--version', action='version', help='Print the software version number') |
| 27 | + args = parser.parse_args() |
| 28 | + return args |
| 29 | + |
| 30 | + |
| 31 | +def main(): |
| 32 | + args = create_parser() |
| 33 | + if args.input_directory is None: |
| 34 | + args.input_directory = Path.cwd() |
| 35 | + grid, input_data = load_input_signals(args.input_directory) |
| 36 | + lifted_input_data = lift_data(input_data, args.lift_factor) |
| 37 | + variables = initialize_variables(lifted_input_data, args.number_of_components, args.data_type) |
| 38 | + components = initialize_components(variables['number_of_components'], variables['number_of_signals'], grid) |
| 39 | + return components |
0 commit comments