Skip to content

Commit

Permalink
fix: allow processing multiple files at a time
Browse files Browse the repository at this point in the history
  • Loading branch information
satra committed Jun 25, 2021
1 parent ebdc566 commit 27ec9ef
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions kwyk/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#! /usr/bin/env python
from pathlib import Path
import subprocess
import tempfile
import json
import os

import click
import nibabel as nib
Expand All @@ -21,16 +23,19 @@
'bvwn_multi_prior': _here.parent / 'saved_models' / 'all_50_bvwn_multi_prior' / '1556816070',
}


@click.command()
@click.argument('infile')
@click.argument('infiles', nargs=-1)
@click.argument('outprefix')
@click.option('-m', '--model', type=click.Choice(_models.keys()), default="bwn_multi", required=True, help='Model to use for prediction.')
@click.option('-n', '--n-samples', type=int, default=1, help='Number of samples to predict.')
@click.option('-b', '--batch-size', type=int, default=8, help='Batch size during prediction.')
@click.option('--save-variance', is_flag=True, help='Save volume with variance across `n-samples` predictions.')
@click.option('--save-entropy', is_flag=True, help='Save volume of entropy values.')
@click.option('--overwrite', type=click.Choice(['yes', 'skip'], case_sensitive=False), help='Overwrite existing output or skip')
@click.option('--atlocation', is_flag=True, help='Save output in the same location as input')
@click.version_option(version=__version__)
def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, save_entropy):
def predict(*, infiles, outprefix, model, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation):
"""Predict labels from features using a trained model.
The predictions are saved to OUTPREFIX_* with the same extension as the input file.
Expand All @@ -46,6 +51,13 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s
print("Your version: {0} Latest version: {1}".format(__version__,
latest["version"]))

savedmodel_path = _models[model]
predictor = _get_predictor(savedmodel_path)
for infile in infiles:
_predict(infile, outprefix, predictor, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation)


def _predict(infile, outprefix, predictor, n_samples, batch_size, save_variance, save_entropy, overwrite, atlocation):
_orig_infile = infile

# Are there other neuroimaging file extensions with multiple periods?
Expand All @@ -55,14 +67,22 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s
outfile_ext = Path(infile).suffix
outfile_stem = outprefix

if atlocation:
outfile_stem = Path(infile).parent / outfile_stem

outfile_means = "{}_means{}".format(outfile_stem, outfile_ext)
outfile_variance = "{}_variance{}".format(outfile_stem, outfile_ext)
outfile_entropy = "{}_entropy{}".format(outfile_stem, outfile_ext)
outfile_uncertainty = "{}_uncertainty{}".format(outfile_stem, '.json')

for ff in [outfile_means, outfile_variance, outfile_entropy, outfile_uncertainty]:
if Path(ff).exists():
raise FileExistsError("file exists: {}".format(ff))
if overwrite == "skip":
return
elif overwrite == "yes":
pass
else:
raise FileExistsError("file exists: {}".format(ff))

required_shape = (256, 256, 256)
block_shape = (32, 32, 32)
Expand All @@ -79,10 +99,7 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s
else:
tmp = None

savedmodel_path = _models[model]

print("++ Running forward pass of model.")
predictor = _get_predictor(savedmodel_path)
outputs = predict_from_filepath(
infile,
predictor=predictor,
Expand Down Expand Up @@ -125,12 +142,12 @@ def predict(*, infile, outprefix, model, n_samples, batch_size, save_variance, s
json.dump(average_uncertainty, fp, indent=4)



def _conform(input, output):
"""Conform volume using FreeSurfer."""
subprocess.run(['mri_convert', '--conform', input, output], check=True)
return output


def _reslice(input, output, reference, labels=False):
"""Conform volume using FreeSurfer."""
if labels:
Expand All @@ -140,3 +157,8 @@ def _reslice(input, output, reference, labels=False):
else:
subprocess.run(['mri_convert', '-rl', reference, input, output], check=True)
return output


if __name__ == '__main__':
predict()

0 comments on commit 27ec9ef

Please sign in to comment.