Skip to content

Commit c734b5e

Browse files
Merge branch 'main' into bump-patch
2 parents ccbc5f4 + d1b46c9 commit c734b5e

File tree

3 files changed

+94
-28
lines changed

3 files changed

+94
-28
lines changed

bioimageio/core/__main__.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,19 @@ def test_model(
7070
model_rdf: str = typer.Argument(
7171
..., help="Path or URL to the model resource description file (rdf.yaml) or zipped model."
7272
),
73-
weight_format: Optional[WeightFormatEnum] = typer.Argument(None, help="The weight format to use."),
74-
devices: Optional[List[str]] = typer.Argument(None, help="Devices for running the model."),
75-
decimal: int = typer.Argument(4, help="The test precision."),
73+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
74+
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
75+
decimal: int = typer.Option(4, help="The test precision."),
7676
) -> int:
7777
# this is a weird typer bug: default devices are empty tuple although they should be None
7878
if len(devices) == 0:
7979
devices = None
80-
summary = resource_tests.test_model(model_rdf, weight_format=weight_format, devices=devices, decimal=decimal)
80+
summary = resource_tests.test_model(
81+
model_rdf,
82+
weight_format=None if weight_format is None else weight_format.value,
83+
devices=devices,
84+
decimal=decimal,
85+
)
8186
if summary["error"] is None:
8287
print(f"Model test for {model_rdf} has passed.")
8388
return 0
@@ -95,14 +100,16 @@ def test_resource(
95100
rdf: str = typer.Argument(
96101
..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package."
97102
),
98-
weight_format: Optional[WeightFormatEnum] = typer.Argument(None, help="(for model only) The weight format to use."),
99-
devices: Optional[List[str]] = typer.Argument(None, help="(for model only) Devices for running the model."),
100-
decimal: int = typer.Argument(4, help="(for model only) The test precision."),
103+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="(for model only) The weight format to use."),
104+
devices: Optional[List[str]] = typer.Option(None, help="(for model only) Devices for running the model."),
105+
decimal: int = typer.Option(4, help="(for model only) The test precision."),
101106
) -> int:
102107
# this is a weird typer bug: default devices are empty tuple although they should be None
103108
if len(devices) == 0:
104109
devices = None
105-
summary = resource_tests.test_resource(rdf, weight_format=weight_format, devices=devices, decimal=decimal)
110+
summary = resource_tests.test_resource(
111+
rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal
112+
)
106113
if summary["error"] is None:
107114
print(f"Resource test for {rdf} has passed.")
108115
return 0
@@ -129,10 +136,10 @@ def predict_image(
129136
# tiling: Optional[Union[str, bool]] = typer.Argument(
130137
# None, help="Padding to apply in each dimension passed as json encoded string."
131138
# ),
132-
padding: Optional[bool] = typer.Argument(None, help="Whether to pad the image to a size suited for the model."),
133-
tiling: Optional[bool] = typer.Argument(None, help="Whether to run prediction in tiling mode."),
134-
weight_format: Optional[str] = typer.Argument(None, help="The weight format to use."),
135-
devices: Optional[List[str]] = typer.Argument(None, help="Devices for running the model."),
139+
padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."),
140+
tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."),
141+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
142+
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
136143
) -> int:
137144

138145
if isinstance(padding, str):
@@ -145,7 +152,9 @@ def predict_image(
145152
# this is a weird typer bug: default devices are empty tuple although they should be None
146153
if len(devices) == 0:
147154
devices = None
148-
prediction.predict_image(model_rdf, inputs, outputs, padding, tiling, weight_format, devices)
155+
prediction.predict_image(
156+
model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices
157+
)
149158
return 0
150159

151160

@@ -167,10 +176,10 @@ def predict_images(
167176
# tiling: Optional[Union[str, bool]] = typer.Argument(
168177
# None, help="Padding to apply in each dimension passed as json encoded string."
169178
# ),
170-
padding: Optional[bool] = typer.Argument(None, help="Whether to pad the image to a size suited for the model."),
171-
tiling: Optional[bool] = typer.Argument(None, help="Whether to run prediction in tiling mode."),
172-
weight_format: Optional[str] = typer.Argument(None, help="The weight format to use."),
173-
devices: Optional[List[str]] = typer.Argument(None, help="Devices for running the model."),
179+
padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."),
180+
tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."),
181+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
182+
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
174183
) -> int:
175184
input_files = glob(input_pattern)
176185
input_names = [os.path.split(infile)[1] for infile in input_files]
@@ -194,7 +203,7 @@ def predict_images(
194203
output_files,
195204
padding=padding,
196205
tiling=tiling,
197-
weight_format=weight_format,
206+
weight_format=None if weight_format is None else weight_format.value,
198207
devices=devices,
199208
verbose=True,
200209
)
@@ -213,8 +222,8 @@ def convert_torch_weights_to_onnx(
213222
),
214223
output_path: Path = typer.Argument(..., help="Where to save the onnx weights."),
215224
opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."),
216-
use_tracing: bool = typer.Argument(True, help="Whether to use torch.jit tracing or scripting."),
217-
verbose: bool = typer.Argument(True, help="Verbosity"),
225+
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
226+
verbose: bool = typer.Option(True, help="Verbosity"),
218227
) -> int:
219228
return torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
220229

@@ -226,7 +235,7 @@ def convert_torch_weights_to_torchscript(
226235
..., help="Path to the model resource description file (rdf.yaml) or zipped model."
227236
),
228237
output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
229-
use_tracing: bool = typer.Argument(True, help="Whether to use torch.jit tracing or scripting."),
238+
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
230239
) -> int:
231240
return torch_converter.convert_weights_to_pytorch_script(model_rdf, output_path, use_tracing)
232241

tests/test_cli.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,46 @@ def test_cli_test_model(unet2d_nuclei_broad_model):
1616
assert ret.returncode == 0
1717

1818

19-
def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path):
20-
spec = load_resource_description(unet2d_nuclei_broad_model)
21-
in_path = spec.test_inputs[0]
22-
out_path = tmp_path.with_suffix(".npy")
19+
def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model):
20+
ret = subprocess.run(
21+
["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"]
22+
)
23+
assert ret.returncode == 0
24+
25+
26+
def test_cli_test_resource(unet2d_nuclei_broad_model):
27+
ret = subprocess.run(["bioimageio", "test-model", unet2d_nuclei_broad_model])
28+
assert ret.returncode == 0
29+
30+
31+
def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model):
2332
ret = subprocess.run(
24-
["bioimageio", "predict-image", unet2d_nuclei_broad_model, "--inputs", str(in_path), "--outputs", str(out_path)]
33+
["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"]
2534
)
2635
assert ret.returncode == 0
36+
37+
38+
def _test_cli_predict_image(model, tmp_path, extra_kwargs=None):
39+
spec = load_resource_description(model)
40+
in_path = spec.test_inputs[0]
41+
out_path = tmp_path.with_suffix(".npy")
42+
cmd = ["bioimageio", "predict-image", model, "--inputs", str(in_path), "--outputs", str(out_path)]
43+
if extra_kwargs is not None:
44+
cmd.extend(extra_kwargs)
45+
ret = subprocess.run(cmd)
46+
assert ret.returncode == 0
2747
assert out_path.exists()
2848

2949

30-
def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path):
50+
def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path):
51+
_test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path)
52+
53+
54+
def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model, tmp_path):
55+
_test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"])
56+
57+
58+
def _test_cli_predict_images(model, tmp_path, extra_kwargs=None):
3159
n_images = 3
3260
shape = (1, 1, 128, 128)
3361
expected_shape = (1, 1, 128, 128)
@@ -45,14 +73,25 @@ def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path):
4573
expected_outputs.append(out_folder / f"im-{i}.npy")
4674

4775
input_pattern = str(in_folder / "*.npy")
48-
ret = subprocess.run(["bioimageio", "predict-images", unet2d_nuclei_broad_model, input_pattern, str(out_folder)])
76+
cmd = ["bioimageio", "predict-images", model, input_pattern, str(out_folder)]
77+
if extra_kwargs is not None:
78+
cmd.extend(extra_kwargs)
79+
ret = subprocess.run(cmd)
4980
assert ret.returncode == 0
5081

5182
for out_path in expected_outputs:
5283
assert out_path.exists()
5384
assert np.load(out_path).shape == expected_shape
5485

5586

87+
def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path):
88+
_test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path)
89+
90+
91+
def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model, tmp_path):
92+
_test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"])
93+
94+
5695
def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path):
5796
out_path = tmp_path.with_suffix(".pt")
5897
ret = subprocess.run(

tests/test_prediction.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@ def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
3939
assert_array_almost_equal(res, exp, decimal=4)
4040

4141

42+
def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir):
43+
from bioimageio.core.prediction import predict_image
44+
45+
spec = load_resource_description(unet2d_fixed_shape_or_not)
46+
assert isinstance(spec, Model)
47+
inputs = spec.test_inputs
48+
49+
outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))]
50+
predict_image(unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict")
51+
for out_path in outputs:
52+
assert out_path.exists()
53+
54+
result = [np.load(str(p)) for p in outputs]
55+
expected = [np.load(str(p)) for p in spec.test_outputs]
56+
for res, exp in zip(result, expected):
57+
assert_array_almost_equal(res, exp, decimal=4)
58+
59+
4260
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
4361
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
4462
from bioimageio.core.prediction import predict_image

0 commit comments

Comments
 (0)