Skip to content

Commit

Permalink
WIP use Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Mar 30, 2024
1 parent 636861a commit 4fa24ce
Show file tree
Hide file tree
Showing 16 changed files with 667 additions and 288 deletions.
8 changes: 5 additions & 3 deletions bioimageio/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from ._resource_tests import test_description as test_description
from ._resource_tests import test_model as test_model
from ._settings import settings as settings
from .utils import VERSION
from .axis import Axis as Axis
from .axis import AxisId as AxisId
from .sample import Sample as Sample
from .tensor import Tensor as Tensor
from .tensor import TensorId as TensorId
from .tile import Tile as Tile
from .sample import Sample as Sample

from .utils import VERSION

__version__ = VERSION

Expand Down
235 changes: 235 additions & 0 deletions bioimageio/core/_magic_tensor_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# this file was modified from the generated
# https://github.com/pydata/xarray/blob/cf3655968b8b12cc0ecd28fb324e63fb94d5e7e2/xarray/core/_typed_ops.py
# TODO: should we generate this ourselves?
# TODO: test these magic methods
import operator
from typing import Any, Callable

from typing_extensions import Self
from xarray.core import nputils, ops


class MagicTensorOpsMixin:
__slots__ = ()
_Compatible = Any

def _binary_op(
self,
other: _Compatible,
f: Callable[[Any, Any], Any],
reflexive: bool = False,
) -> Self:
raise NotImplementedError

def __add__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.add)

def __sub__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.sub)

def __mul__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mul)

def __pow__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.pow)

def __truediv__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.truediv)

def __floordiv__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.floordiv)

def __mod__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mod)

def __and__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.and_)

def __xor__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.xor)

def __or__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.or_)

def __lshift__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.lshift)

def __rshift__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.rshift)

def __lt__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.lt)

def __le__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.le)

def __gt__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.gt)

def __ge__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.ge)

def __eq__(self, other: _Compatible) -> Self: # type: ignore[override]
return self._binary_op(
other, nputils.array_eq # pyright: ignore[reportUnknownArgumentType]
)

def __ne__(self, other: _Compatible) -> Self: # type: ignore[override]
return self._binary_op(
other, nputils.array_ne # pyright: ignore[reportUnknownArgumentType]
)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

def __rsub__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)

def __rmul__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)

def __rpow__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)

def __rtruediv__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)

def __rfloordiv__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)

def __rmod__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)

def __rand__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)

def __rxor__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)

def __ror__(self, other: _Compatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)

def _inplace_binary_op(
self, other: _Compatible, f: Callable[[Any, Any], Any]
) -> Self:
raise NotImplementedError

def __iadd__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.iadd)

def __isub__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.isub)

def __imul__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.imul)

def __ipow__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.ipow)

def __itruediv__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.itruediv)

def __ifloordiv__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.ifloordiv)

def __imod__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.imod)

def __iand__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.iand)

def __ixor__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.ixor)

def __ior__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.ior)

def __ilshift__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.ilshift)

def __irshift__(self, other: _Compatible) -> Self:
return self._inplace_binary_op(other, operator.irshift)

def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError

def __neg__(self) -> Self:
return self._unary_op(operator.neg)

def __pos__(self) -> Self:
return self._unary_op(operator.pos)

def __abs__(self) -> Self:
return self._unary_op(operator.abs)

def __invert__(self) -> Self:
return self._unary_op(operator.invert)

def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(
ops.round_, *args, **kwargs # pyright: ignore[reportUnknownArgumentType]
)

def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(
ops.argsort, *args, **kwargs # pyright: ignore[reportUnknownArgumentType]
)

def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(
ops.conj, *args, **kwargs # pyright: ignore[reportUnknownArgumentType]
)

def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(
ops.conjugate, *args, **kwargs # pyright: ignore[reportUnknownArgumentType]
)

__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
__pow__.__doc__ = operator.pow.__doc__
__truediv__.__doc__ = operator.truediv.__doc__
__floordiv__.__doc__ = operator.floordiv.__doc__
__mod__.__doc__ = operator.mod.__doc__
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
__ge__.__doc__ = operator.ge.__doc__
__eq__.__doc__ = nputils.array_eq.__doc__
__ne__.__doc__ = nputils.array_ne.__doc__
__radd__.__doc__ = operator.add.__doc__
__rsub__.__doc__ = operator.sub.__doc__
__rmul__.__doc__ = operator.mul.__doc__
__rpow__.__doc__ = operator.pow.__doc__
__rtruediv__.__doc__ = operator.truediv.__doc__
__rfloordiv__.__doc__ = operator.floordiv.__doc__
__rmod__.__doc__ = operator.mod.__doc__
__rand__.__doc__ = operator.and_.__doc__
__rxor__.__doc__ = operator.xor.__doc__
__ror__.__doc__ = operator.or_.__doc__
__iadd__.__doc__ = operator.iadd.__doc__
__isub__.__doc__ = operator.isub.__doc__
__imul__.__doc__ = operator.imul.__doc__
__ipow__.__doc__ = operator.ipow.__doc__
__itruediv__.__doc__ = operator.itruediv.__doc__
__ifloordiv__.__doc__ = operator.ifloordiv.__doc__
__imod__.__doc__ = operator.imod.__doc__
__iand__.__doc__ = operator.iand.__doc__
__ixor__.__doc__ = operator.ixor.__doc__
__ior__.__doc__ = operator.ior.__doc__
__ilshift__.__doc__ = operator.ilshift.__doc__
__irshift__.__doc__ = operator.irshift.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
__invert__.__doc__ = operator.invert.__doc__
24 changes: 12 additions & 12 deletions bioimageio/core/_prediction_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import warnings
from types import MappingProxyType
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union

from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter
from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats
from bioimageio.core.proc_ops import Processing
from bioimageio.core.proc_setup import setup_pre_and_postprocessing
from bioimageio.core.sample import Sample
from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue
from bioimageio.core.Tensor import Tensor, TensorId
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union

from bioimageio.spec.model import AnyModelDescr, v0_4
from bioimageio.spec.model.v0_5 import WeightsFormat

from .model_adapters import ModelAdapter, create_model_adapter
from .model_adapters import get_weight_formats as get_weight_formats
from .proc_ops import Processing
from .proc_setup import setup_pre_and_postprocessing
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import PerTensor, Tensor, TensorId


class PredictionPipeline:
"""
Expand Down Expand Up @@ -64,8 +65,7 @@ def predict(
) -> List[Optional[Tensor]]:
"""Predict input_tensor with the model without applying pre/postprocessing."""
named_tensors = [
named_input_tensors.get(str(k))
for k in self.input_ids[len(input_tensors) :]
named_input_tensors.get(k) for k in self.input_ids[len(input_tensors) :]
]
return self._adapter.forward(*input_tensors, *named_tensors)

Expand Down Expand Up @@ -99,7 +99,7 @@ def forward_sample(self, input_sample: Sample) -> Sample:

def forward_tensors(
self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor]
) -> Dict[TensorId, Tensor]:
) -> PerTensor[Tensor]:
"""Apply preprocessing, run prediction and apply postprocessing."""
assert all(TensorId(k) in self.input_ids for k in named_input_tensors)
input_sample = Sample(
Expand Down
10 changes: 5 additions & 5 deletions bioimageio/core/_resource_tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import traceback
import warnings
from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union
from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union

import numpy as np

from bioimageio.core._prediction_pipeline import create_prediction_pipeline
from bioimageio.core.axis import AxisId, BatchSize
from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs
from bioimageio.core.utils.tiling import resize_to
from bioimageio.spec import (
InvalidDescr,
ResourceDescr,
Expand Down Expand Up @@ -135,7 +134,9 @@ def _test_model_inference(
error = "Output tensors for test case may not be None"
break
try:
np.testing.assert_array_almost_equal(res, exp, decimal=decimal)
np.testing.assert_array_almost_equal(
res.data, exp.data, decimal=decimal
)
except AssertionError as e:
error = f"Output and expected output disagree:\n {e}"
break
Expand Down Expand Up @@ -217,8 +218,7 @@ def get_ns(n: int):
tested.add(hashable_target_size)

resized_test_inputs = [
resize_to(
t,
t.resize_to(
{
aid: s
for (tid, aid), s in input_target_sizes.items()
Expand Down
33 changes: 24 additions & 9 deletions bioimageio/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
]


LeftRight_T = TypeVar("LeftRight_T", bound="LeftRight")
LeftRightLike = Union[int, Tuple[int, int], LeftRight_T]
_LeftRight_T = TypeVar("_LeftRight_T", bound="_LeftRight")
_LeftRightLike = Union[int, Tuple[int, int], _LeftRight_T]


class LeftRight(NamedTuple):
class _LeftRight(NamedTuple):
left: int
right: int

@classmethod
def create(cls, like: LeftRightLike[Self]) -> Self:
def create(cls, like: _LeftRightLike[Self]) -> Self:
if isinstance(like, cls):
return like
elif isinstance(like, tuple):
Expand All @@ -39,20 +39,35 @@ def create(cls, like: LeftRightLike[Self]) -> Self:
assert_never(like)


class Halo(LeftRight):
_Where = Literal["left", "right", "left_and_right"]


class CropWidth(_LeftRight):
pass


CropWidthLike = _LeftRightLike[CropWidth]
CropWhere = _Where


class Halo(_LeftRight):
pass


HaloLike = LeftRightLike[Halo]
HaloLike = _LeftRightLike[Halo]


class OverlapWidth(_LeftRight):
pass


class PadWidth(LeftRight):
class PadWidth(_LeftRight):
pass


PadWidthLike = LeftRightLike[PadWidth]
PadWidthLike = _LeftRightLike[PadWidth]
PadMode = Literal["edge", "reflect", "symmetric"]
PadWhere = Literal["before", "center", "after"]
PadWhere = _Where


class SliceInfo(NamedTuple):
Expand Down
Loading

0 comments on commit 4fa24ce

Please sign in to comment.