From 27ec9ef356a2558c533559e8499e971e07081b3d Mon Sep 17 00:00:00 2001 From: Satrajit Ghosh Date: Thu, 24 Jun 2021 21:51:51 -0400 Subject: [PATCH] fix: allow processing multiple files at a time --- kwyk/cli.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/kwyk/cli.py b/kwyk/cli.py index 2c4a512..4bf808d 100644 --- a/kwyk/cli.py +++ b/kwyk/cli.py @@ -1,7 +1,9 @@ +#! /usr/bin/env python from pathlib import Path import subprocess import tempfile import json +import os import click import nibabel as nib @@ -21,16 +23,19 @@ 'bvwn_multi_prior': _here.parent / 'saved_models' / 'all_50_bvwn_multi_prior' / '1556816070', } + @click.command() -@click.argument('infile') +@click.argument('infiles', nargs=-1) @click.argument('outprefix') @click.option('-m', '--model', type=click.Choice(_models.keys()), default="bwn_multi", required=True, help='Model to use for prediction.') @click.option('-n', '--n-samples', type=int, default=1, help='Number of samples to predict.') @click.option('-b', '--batch-size', type=int, default=8, help='Batch size during prediction.') @click.option('--save-variance', is_flag=True, help='Save volume with variance across `n-samples` predictions.') @click.option('--save-entropy', is_flag=True, help='Save volume of entropy values.') +@click.option('--overwrite', type=click.Choice(['yes', 'skip'], case_sensitive=False), help='Overwrite existing output or skip') +@click.option('--atlocation', is_flag=True, help='Save output in the same location as input') @click.version_option(version=__version__) -def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, save_entropy): +def predict(*, infiles, outprefix, model, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation): """Predict labels from features using a trained model. The predictions are saved to OUTPREFIX_* with the same extension as the input file. @@ -46,6 +51,13 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s print("Your version: {0} Latest version: {1}".format(__version__, latest["version"])) + savedmodel_path = _models[model] + predictor = _get_predictor(savedmodel_path) + for infile in infiles: + _predict(infile, outprefix, predictor, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation) + + +def _predict(infile, outprefix, predictor, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation): _orig_infile = infile # Are there other neuroimaging file extensions with multiple periods? @@ -55,6 +67,9 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s outfile_ext = Path(infile).suffix outfile_stem = outprefix + if atlocation: + outfile_stem = Path(infile).parent / outfile_stem + outfile_means = "{}_means{}".format(outfile_stem, outfile_ext) outfile_variance = "{}_variance{}".format(outfile_stem, outfile_ext) outfile_entropy = "{}_entropy{}".format(outfile_stem, outfile_ext) @@ -62,7 +77,12 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s for ff in [outfile_means, outfile_variance, outfile_entropy, outfile_uncertainty]: if Path(ff).exists(): - raise FileExistsError("file exists: {}".format(ff)) + if overwrite == "skip": + return + elif overwrite == "yes": + pass + else: + raise FileExistsError("file exists: {}".format(ff)) required_shape = (256, 256, 256) block_shape = (32, 32, 32) @@ -79,10 +99,7 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s else: tmp = None - savedmodel_path = _models[model] - print("++ Running forward pass of model.") - predictor = _get_predictor(savedmodel_path) outputs = predict_from_filepath( infile, predictor=predictor, @@ -125,12 +142,12 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s json.dump(average_uncertainty, fp, indent=4) - def _conform(input, output): """Conform volume using FreeSurfer.""" subprocess.run(['mri_convert', '--conform', input, output], check=True) return output + def _reslice(input, output, reference, labels=False): """Conform volume using FreeSurfer.""" if labels: @@ -140,3 +157,8 @@ def _reslice(input, output, reference, labels=False): else: subprocess.run(['mri_convert', '-rl', reference, input, output], check=True) return output + + +if __name__ == '__main__': + predict() +