Skip to content

Commit 34d5478

Browse files
committed
improve increase_available_weight_formats
1 parent e35735d commit 34d5478

File tree

1 file changed

+78
-23
lines changed

1 file changed

+78
-23
lines changed
Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,92 @@
1-
from copy import deepcopy
2-
from pathlib import Path
3-
from typing import List, Optional, Sequence, Union
1+
from typing import Optional, Sequence
42

5-
from bioimageio.spec.model import v0_4, v0_5
3+
from loguru import logger
4+
from pydantic import DirectoryPath
5+
6+
from bioimageio.core._resource_tests import test_model
7+
from bioimageio.spec import load_model_description, save_bioimageio_package_as_folder
8+
from bioimageio.spec._internal.types import AbsoluteTolerance, RelativeTolerance
9+
from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
610

711

812
def increase_available_weight_formats(
9-
model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
13+
model_descr: ModelDescr,
1014
*,
11-
source_format: Optional[v0_5.WeightsFormat] = None,
12-
target_format: Optional[v0_5.WeightsFormat] = None,
13-
output_path: Path,
14-
devices: Optional[Sequence[str]] = None,
15-
) -> Union[v0_4.ModelDescr, v0_5.ModelDescr]:
16-
"""Convert neural network weights to other formats and add them to the model description"""
17-
if not isinstance(model_descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
18-
raise TypeError(
19-
f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_descr)}"
20-
)
15+
output_path: DirectoryPath,
16+
source_format: Optional[WeightsFormat] = None,
17+
target_format: Optional[WeightsFormat] = None,
18+
devices: Sequence[str] = ("cpu",),
19+
) -> ModelDescr:
20+
"""Convert model weights to other formats and add them to the model description
21+
22+
Args:
23+
output_path: Path to save updated model package to.
24+
source_format: convert from a specific weights format.
25+
Default: choose automatically from any available.
26+
target_format: convert to a specific weights format.
27+
Default: attempt to convert to any missing format.
28+
devices: Devices that may be used during conversion.
29+
"""
30+
if not isinstance(model_descr, ModelDescr):
31+
raise TypeError(type(model_descr))
32+
33+
# save model to local folder
34+
output_path = save_bioimageio_package_as_folder(
35+
model_descr, output_path=output_path
36+
)
37+
# reload from local folder to make sure we do not edit the given model
38+
_model_descr = load_model_description(output_path)
39+
assert isinstance(_model_descr, ModelDescr)
40+
model_descr = _model_descr
41+
del _model_descr
2142

2243
if source_format is None:
23-
available = [wf for wf, w in model_descr.weights if w is not None]
24-
missing = [wf for wf, w in model_descr.weights if w is None]
44+
available = set(model_descr.weights.available_formats)
45+
else:
46+
available = {source_format}
47+
48+
if target_format is None:
49+
missing = set(model_descr.weights.missing_formats)
2550
else:
26-
available = [source_format]
27-
missing = [target_format]
51+
missing = {target_format}
2852

2953
if "pytorch_state_dict" in available and "onnx" in missing:
3054
from .pytorch_to_onnx import convert
3155

32-
onnx = convert(model_descr)
56+
try:
57+
model_descr.weights.onnx = convert(
58+
model_descr,
59+
output_path=output_path,
60+
use_tracing=False,
61+
)
62+
except Exception as e:
63+
logger.error(e)
64+
else:
65+
available.add("onnx")
66+
missing.discard("onnx")
3367

34-
else:
35-
raise NotImplementedError(
36-
f"Converting from '{source_format}' to '{target_format}' is not yet implemented. Please create an issue at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
68+
if "pytorch_state_dict" in available and "torchscript" in missing:
69+
from .pytorch_to_torchscript import convert
70+
71+
try:
72+
model_descr.weights.torchscript = convert(
73+
model_descr,
74+
output_path=output_path,
75+
use_tracing=False,
76+
)
77+
except Exception as e:
78+
logger.error(e)
79+
else:
80+
available.add("torchscript")
81+
missing.discard("torchscript")
82+
83+
if missing:
84+
logger.warning(
85+
f"Converting from any of the available weights formats {available} to any"
86+
+ f" of {missing} is not yet implemented. Please create an issue at"
87+
+ " https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
88+
+ " if you would like bioimageio.core to support a particular conversion."
3789
)
90+
91+
test_model(model_descr).display()
92+
return model_descr

0 commit comments

Comments
 (0)