Skip to content

Commit

Permalink
fix _output_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Apr 23, 2024
1 parent 73063bd commit 2be9913
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
14 changes: 11 additions & 3 deletions bioimageio/core/model_adapters/_keras_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from loguru import logger
from numpy.typing import NDArray

from bioimageio.core.tensor import Tensor
from bioimageio.spec._internal.io_utils import download
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import Version

from .._settings import settings
from ..digest_spec import get_axes_infos
from ..tensor import Tensor
from ._model_adapter import ModelAdapter

os.environ["KERAS_BACKEND"] = settings.keras_backend
Expand Down Expand Up @@ -74,7 +75,10 @@ def __init__(
weight_path = download(model_description.weights.keras_hdf5.source).path

self._network = keras.models.load_model(weight_path)
self._output_axes = [tuple(out.axes) for out in model_description.outputs]
self._output_axes = [
tuple(a.id for a in get_axes_infos(out))
for out in model_description.outputs
]

def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
_result: Union[Sequence[NDArray[Any]], NDArray[Any]]
Expand All @@ -87,7 +91,11 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
result = [_result] # type: ignore

assert len(result) == len(self._output_axes)
return [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)]
ret: List[Optional[Tensor]] = []
ret.extend(
[Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)]
)
return ret

def unload(self) -> None:
logger.warning(
Expand Down
7 changes: 2 additions & 5 deletions bioimageio/core/model_adapters/_pytorch_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bioimageio.spec.utils import download

from ..axis import AxisId
from ..digest_spec import import_callable
from ..digest_spec import get_axes_infos, import_callable
from ..tensor import Tensor
from ._model_adapter import ModelAdapter

Expand All @@ -31,10 +31,7 @@ def __init__(
if torch is None:
raise ImportError("torch")
super().__init__()
self.output_dims = [
tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes)
for out in outputs
]
self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs]
self._network = self.get_network(weights)
self._devices = self.get_devices(devices)
self._network = self._network.to(self._devices[0])
Expand Down

0 comments on commit 2be9913

Please sign in to comment.