diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 2692cc70..c3cb1db6 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -23,11 +23,13 @@ from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model from ._settings import settings as settings -from .utils import VERSION +from .axis import Axis as Axis +from .axis import AxisId as AxisId +from .sample import Sample as Sample from .tensor import Tensor as Tensor +from .tensor import TensorId as TensorId from .tile import Tile as Tile -from .sample import Sample as Sample - +from .utils import VERSION __version__ = VERSION diff --git a/bioimageio/core/_magic_tensor_ops.py b/bioimageio/core/_magic_tensor_ops.py new file mode 100644 index 00000000..c1526fef --- /dev/null +++ b/bioimageio/core/_magic_tensor_ops.py @@ -0,0 +1,235 @@ +# this file was modified from the generated +# https://github.com/pydata/xarray/blob/cf3655968b8b12cc0ecd28fb324e63fb94d5e7e2/xarray/core/_typed_ops.py +# TODO: should we generate this ourselves? +# TODO: test these magic methods +import operator +from typing import Any, Callable + +from typing_extensions import Self +from xarray.core import nputils, ops + + +class MagicTensorOpsMixin: + __slots__ = () + _Compatible = Any + + def _binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + reflexive: bool = False, + ) -> Self: + raise NotImplementedError + + def __add__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.add) + + def __sub__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.sub) + + def __mul__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mul) + + def __pow__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.pow) + + def __truediv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mod) + + def __and__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.and_) + + def __xor__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.xor) + + def __or__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.or_) + + def __lshift__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.lt) + + def __le__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.le) + + def __gt__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.gt) + + def __ge__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.ge) + + def __eq__(self, other: _Compatible) -> Self: # type: ignore[override] + return self._binary_op( + other, nputils.array_eq # pyright: ignore[reportUnknownArgumentType] + ) + + def __ne__(self, other: _Compatible) -> Self: # type: ignore[override] + return self._binary_op( + other, nputils.array_ne # pyright: ignore[reportUnknownArgumentType] + ) + + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.or_, reflexive=True) + + def _inplace_binary_op( + self, other: _Compatible, f: Callable[[Any, Any], Any] + ) -> Self: + raise NotImplementedError + + def __iadd__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.iadd) + + def __isub__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.isub) + + def __imul__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.imul) + + def __ipow__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.ipow) + + def __itruediv__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.itruediv) + + def __ifloordiv__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.ifloordiv) + + def __imod__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.imod) + + def __iand__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.iand) + + def __ixor__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.ixor) + + def __ior__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.ior) + + def __ilshift__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other: _Compatible) -> Self: + return self._inplace_binary_op(other, operator.irshift) + + def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError + + def __neg__(self) -> Self: + return self._unary_op(operator.neg) + + def __pos__(self) -> Self: + return self._unary_op(operator.pos) + + def __abs__(self) -> Self: + return self._unary_op(operator.abs) + + def __invert__(self) -> Self: + return self._unary_op(operator.invert) + + def round(self, *args: Any, **kwargs: Any) -> Self: + return self._unary_op( + ops.round_, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ) + + def argsort(self, *args: Any, **kwargs: Any) -> Self: + return self._unary_op( + ops.argsort, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ) + + def conj(self, *args: Any, **kwargs: Any) -> Self: + return self._unary_op( + ops.conj, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ) + + def conjugate(self, *args: Any, **kwargs: Any) -> Self: + return self._unary_op( + ops.conjugate, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + __iadd__.__doc__ = operator.iadd.__doc__ + __isub__.__doc__ = operator.isub.__doc__ + __imul__.__doc__ = operator.imul.__doc__ + __ipow__.__doc__ = operator.ipow.__doc__ + __itruediv__.__doc__ = operator.itruediv.__doc__ + __ifloordiv__.__doc__ = operator.ifloordiv.__doc__ + __imod__.__doc__ = operator.imod.__doc__ + __iand__.__doc__ = operator.iand.__doc__ + __ixor__.__doc__ = operator.ixor.__doc__ + __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ + __neg__.__doc__ = operator.neg.__doc__ + __pos__.__doc__ = operator.pos.__doc__ + __abs__.__doc__ = operator.abs.__doc__ + __invert__.__doc__ = operator.invert.__doc__ diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 33c83eb3..56a75407 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,17 +1,18 @@ import warnings from types import MappingProxyType -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union - -from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter -from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats -from bioimageio.core.proc_ops import Processing -from bioimageio.core.proc_setup import setup_pre_and_postprocessing -from bioimageio.core.sample import Sample -from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue -from bioimageio.core.Tensor import Tensor, TensorId +from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union + from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat +from .model_adapters import ModelAdapter, create_model_adapter +from .model_adapters import get_weight_formats as get_weight_formats +from .proc_ops import Processing +from .proc_setup import setup_pre_and_postprocessing +from .sample import Sample +from .stat_measures import DatasetMeasure, MeasureValue +from .tensor import PerTensor, Tensor, TensorId + class PredictionPipeline: """ @@ -64,8 +65,7 @@ def predict( ) -> List[Optional[Tensor]]: """Predict input_tensor with the model without applying pre/postprocessing.""" named_tensors = [ - named_input_tensors.get(str(k)) - for k in self.input_ids[len(input_tensors) :] + named_input_tensors.get(k) for k in self.input_ids[len(input_tensors) :] ] return self._adapter.forward(*input_tensors, *named_tensors) @@ -99,7 +99,7 @@ def forward_sample(self, input_sample: Sample) -> Sample: def forward_tensors( self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> Dict[TensorId, Tensor]: + ) -> PerTensor[Tensor]: """Apply preprocessing, run prediction and apply postprocessing.""" assert all(TensorId(k) in self.input_ids for k in named_input_tensors) input_sample = Sample( diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 82bd316b..d14e836e 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -1,13 +1,12 @@ import traceback import warnings -from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union import numpy as np from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.axis import AxisId, BatchSize from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs -from bioimageio.core.utils.tiling import resize_to from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -135,7 +134,9 @@ def _test_model_inference( error = "Output tensors for test case may not be None" break try: - np.testing.assert_array_almost_equal(res, exp, decimal=decimal) + np.testing.assert_array_almost_equal( + res.data, exp.data, decimal=decimal + ) except AssertionError as e: error = f"Output and expected output disagree:\n {e}" break @@ -217,8 +218,7 @@ def get_ns(n: int): tested.add(hashable_target_size) resized_test_inputs = [ - resize_to( - t, + t.resize_to( { aid: s for (tid, aid), s in input_target_sizes.items() diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 1c94c77f..5542e897 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -19,16 +19,16 @@ ] -LeftRight_T = TypeVar("LeftRight_T", bound="LeftRight") -LeftRightLike = Union[int, Tuple[int, int], LeftRight_T] +_LeftRight_T = TypeVar("_LeftRight_T", bound="_LeftRight") +_LeftRightLike = Union[int, Tuple[int, int], _LeftRight_T] -class LeftRight(NamedTuple): +class _LeftRight(NamedTuple): left: int right: int @classmethod - def create(cls, like: LeftRightLike[Self]) -> Self: + def create(cls, like: _LeftRightLike[Self]) -> Self: if isinstance(like, cls): return like elif isinstance(like, tuple): @@ -39,20 +39,35 @@ def create(cls, like: LeftRightLike[Self]) -> Self: assert_never(like) -class Halo(LeftRight): +_Where = Literal["left", "right", "left_and_right"] + + +class CropWidth(_LeftRight): + pass + + +CropWidthLike = _LeftRightLike[CropWidth] +CropWhere = _Where + + +class Halo(_LeftRight): pass -HaloLike = LeftRightLike[Halo] +HaloLike = _LeftRightLike[Halo] + + +class OverlapWidth(_LeftRight): + pass -class PadWidth(LeftRight): +class PadWidth(_LeftRight): pass -PadWidthLike = LeftRightLike[PadWidth] +PadWidthLike = _LeftRightLike[PadWidth] PadMode = Literal["edge", "reflect", "symmetric"] -PadWhere = Literal["before", "center", "after"] +PadWhere = _Where class SliceInfo(NamedTuple): diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 557e61bb..9bcab722 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,12 +1,9 @@ from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Optional, Sequence import imageio from bioimageio.core.axis import Axis, AxisLike -from bioimageio.spec.model import v0_5 -from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 -from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 from bioimageio.spec.utils import load_array from .tensor import Tensor, TensorId @@ -27,4 +24,6 @@ def load_tensor( ) array = imageio.volread(path) if is_volume else imageio.imread(path) - return Tensor.from_numpy(array, axes, id=TensorId(path.stem) if id is None else id) + return Tensor.from_numpy( + array, dims=axes, id=TensorId(path.stem) if id is None else id + ) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 3560d61d..633ee342 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from typing import List, Optional, Sequence, Tuple, Union, final -from bioimageio.core.Tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 +from ..tensor import Tensor + WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] # Known weight formats in order of priority diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 8523f991..984d53e8 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -14,16 +14,16 @@ import numpy as np import xarray as xr -from numpy.typing import DTypeLike from typing_extensions import Self, assert_never -from bioimageio.core._op_base import Operator -from bioimageio.core.axis import ( - AxisId, -) -from bioimageio.core.sample import Sample -from bioimageio.core.stat_calculators import StatsCalculator -from bioimageio.core.stat_measures import ( +from bioimageio.core.common import DTypeStr +from bioimageio.spec.model import v0_4, v0_5 + +from ._op_base import Operator +from .axis import AxisId +from .sample import Sample +from .stat_calculators import StatsCalculator +from .stat_measures import ( DatasetMean, DatasetMeasure, DatasetPercentile, @@ -37,8 +37,7 @@ Stat, StdMeasure, ) -from bioimageio.core.Tensor import Tensor, TensorId -from bioimageio.spec.model import v0_4, v0_5 +from .tensor import Tensor, TensorId def convert_axis_ids( @@ -169,7 +168,7 @@ class Binarize(_SimpleOperator): threshold: Union[float, Sequence[float]] axis: Optional[AxisId] = None - def _apply(self, input: Tensor, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input > self.threshold @classmethod @@ -221,18 +220,14 @@ def from_proc_descr( @dataclass class EnsureDtype(_SimpleOperator): - dtype: DTypeLike + dtype: DTypeStr @classmethod def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId): return cls(input=tensor_id, output=tensor_id, dtype=descr.kwargs.dtype) def get_descr(self): - return v0_5.EnsureDtypeDescr( - kwargs=v0_5.EnsureDtypeKwargs( - dtype=str(self.dtype) # pyright: ignore[reportArgumentType] - ) - ) + return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype)) def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.astype(self.dtype) @@ -377,17 +372,17 @@ def __post_init__( ): if lower_percentile is None: tid = self.input if upper_percentile is None else upper_percentile.tensor_id - self.lower = DatasetPercentile(n=0, tensor_id=tid) + self.lower = DatasetPercentile(q=0.0, tensor_id=tid) else: self.lower = lower_percentile if upper_percentile is None: - self.upper = DatasetPercentile(n=100, tensor_id=self.lower.tensor_id) + self.upper = DatasetPercentile(q=1.0, tensor_id=self.lower.tensor_id) else: self.upper = upper_percentile assert self.lower.tensor_id == self.upper.tensor_id - assert self.lower.n < self.upper.n + assert self.lower.q < self.upper.q assert self.lower.axes == self.upper.axes @property @@ -416,14 +411,14 @@ def from_proc_descr( input=tensor_id, output=tensor_id, lower_percentile=Percentile( - n=kwargs.min_percentile, axes=axes, tensor_id=ref_tensor + q=kwargs.min_percentile / 100, axes=axes, tensor_id=ref_tensor ), upper_percentile=Percentile( - n=kwargs.max_percentile, axes=axes, tensor_id=ref_tensor + q=kwargs.max_percentile / 100, axes=axes, tensor_id=ref_tensor ), ) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: lower = stat[self.lower] upper = stat[self.upper] return (input - lower) / (upper - lower + self.eps) @@ -435,8 +430,8 @@ def get_descr(self): return v0_5.ScaleRangeDescr( kwargs=v0_5.ScaleRangeKwargs( axes=self.lower.axes, - min_percentile=self.lower.n, - max_percentile=self.upper.n, + min_percentile=self.lower.q * 100, + max_percentile=self.upper.q * 100, eps=self.eps, reference_tensor=self.lower.tensor_id, ) @@ -503,7 +498,7 @@ def from_proc_descr( std=Std(axes=axes, tensor_id=tensor_id), ) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: mean = stat[self.mean] std = stat[self.std] return (input - mean) / (std + self.eps) @@ -565,7 +560,7 @@ def get_descr(self): return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return (input - self.mean) / (self.std + self.eps) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 82f2fc75..aed8b633 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,15 +1,16 @@ from dataclasses import dataclass, field -from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, cast +from pprint import pformat +from typing import Dict, Iterable, Iterator, Optional, Tuple, cast import numpy +import xarray as xr from typing_extensions import Self -from xarray.core.utils import Frozen from .axis import AxisId, PerAxis -from .common import Halo, HaloLike, PadMode, PadWidth, SliceInfo, TileNumber +from .common import Halo, HaloLike, PadMode, SliceInfo, TileNumber from .stat_measures import Stat from .tensor import PerTensor, Tensor, TensorId -from .tile import Tile, tile_tensor +from .tile import Tile TiledSample = Iterable[Tile] """A dataset sample split into tiles""" @@ -19,7 +20,7 @@ class Sample: """A dataset sample""" - data: PerTensor[Tensor] + data: Dict[TensorId, Tensor] """the sample's tensors""" stat: Stat = field(default_factory=dict) @@ -32,79 +33,90 @@ def sizes(self) -> PerTensor[PerAxis[int]]: def tile( self, tile_sizes: PerTensor[PerAxis[int]], - minimum_halo: PerTensor[PerAxis[HaloLike]], + halo: PerTensor[PerAxis[HaloLike]], + pad_mode: PadMode, ) -> TiledSample: assert not ( missing := [t for t in tile_sizes if t not in self.data] ), f"`tile_sizes` specified for missing tensors: {missing}" assert not ( - missing := [t for t in minimum_halo if t not in tile_sizes] - ), f"`minimum_halo` specified for tensors without `tile_sizes`: {missing}" - - tensor_ids = list(tile_sizes) + missing := [t for t in halo if t not in tile_sizes] + ), f"`halo` specified for tensors without `tile_sizes`: {missing}" + + # any axis not given in `tile_sizes` is treated + # as tile size equal to the tensor axis' size + explicit_tile_sizes = { + t: {a: tile_sizes.get(t, {}).get(a, s) for a, s in tdata.sizes.items()} + for t, tdata in self.data.items() + } + + tensor_ids = tuple(self.data) + broadcasted_tensors = { + t: Tensor.from_xarray(d) + for t, d in zip( + tensor_ids, xr.broadcast(*(self.data[tt].data for tt in tensor_ids)) + ) + } - tensor_tile_generators: Dict[ - TensorId, Iterable[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] + tile_iterators: Dict[ + TensorId, Iterator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] ] = {} - n_tiles: Dict[TensorId, int] = {} + + n_tiles_common = 1 + last_non_trivial: Optional[TensorId] = None for t in tensor_ids: - n_tiles[t], tensor_tile_generators[t] = tile_tensor( - self.data[t], - tile_sizes=tile_sizes.get(t, self.data[t].sizes), - minimum_halo=minimum_halo.get(t, {a: 0 for a in self.data[t].dims}), + n_tiles, generator = broadcasted_tensors[t].tile( + tile_size=explicit_tile_sizes[t], + halo=halo.get(t, {}), pad_mode=pad_mode, ) - - n_tiles_common: Optional[int] = None - single_tile_tensors: Dict[TensorId, Tuple[TensorTilePos, Tensor]] = {} - tile_iterators: Dict[TensorId, Iterator[Tuple[int, TensorTilePos, Tensor]]] = {} - for t, n in n_tiles.items(): - tile_iterator = iter(tensor_tile_generators[t]) - if n == 1: - t0, pos, tensor_tile = next(tile_iterator) - assert t0 == 0 - single_tile_tensors[t] = (pos, tensor_tile) - continue - - if n_tiles_common is None: - n_tiles_common = n - elif n != n_tiles_common: + tile_iterators[t] = iter(generator) + if n_tiles in (1, n_tiles_common): + pass + elif n_tiles_common == 1: + last_non_trivial = t + n_tiles_common = n_tiles + else: + assert last_non_trivial is not None + mismatch = { + last_non_trivial: { + "original sizes": self.data[last_non_trivial].sizes, + "broadcasted sizes": broadcasted_tensors[ + last_non_trivial + ].sizes, + "n_tiles": n_tiles_common, + }, + t: { + "original sizes": self.data[t].sizes, + "broadcasted sizes": broadcasted_tensors[t].sizes, + "n_tiles": n_tiles, + }, + } raise ValueError( - f"{self} tiled by {tile_sizes} yields different numbers of tiles: {n_tiles}" + f"broadcasted tensors {last_non_trivial, t} do not tile to the same" + + f" number of tiles {n_tiles_common, n_tiles}. Details\n" + + pformat(mismatch) ) - tile_iterators[t] = tile_iterator - - if n_tiles_common is None: - assert not tile_iterators - n_tiles_common = 1 - - for t in range(n_tiles_common): + for i in range(n_tiles_common): data: Dict[TensorId, Tensor] = {} - tile_pos: TilePos = {} - inner_slice: TileSlice = {} - outer_slice: TileSlice = {} - for t, (tensor_tile, tensor_pos) in single_tile_tensors.items(): - data[t] = tensor_tile - tile_pos[t] = tensor_pos - inner_slice[t] = inner_tensor_slice - outer_slice[t] = outer_tensor_slice - - for t, tile_iterator in tile_iterators.items(): - assert t not in data - assert t not in tile_pos - _t, tensor_pos, tensor_tile = next(tile_iterator) - assert _t == t, (_t, t) + inner_slice: Dict[TensorId, PerAxis[SliceInfo]] = {} + for t, iterator in tile_iterators.items(): + tn, tensor_tile, tensor_slice = next(iterator) + assert tn == i, f"expected tile number {i}, but got {tn}" data[t] = tensor_tile - tile_pos[t] = tensor_pos + inner_slice[t] = tensor_slice yield Tile( data=data, - pos=tile_pos, inner_slice=inner_slice, - outer_slice=outer_slice, - tile_number=t, - tiles_in_self=n_tiles_common, + halo={ + t: {a: Halo.create(h) for a, h in th.items()} + for t, th in halo.items() + }, + sample_sizes=self.sizes, + tile_number=i, + tiles_in_sample=n_tiles_common, stat=self.stat, ) @@ -113,7 +125,7 @@ def from_tiles( cls, tiles: Iterable[Tile], *, fill_value: float = float("nan") ) -> Self: # TODO: add `mode: Literal['in-memory', 'to-disk']` or similar to save out of mem samples - data: TileData = {} + data: PerTensor[Tensor] = {} stat: Stat = {} for tile in tiles: for t, tile_data in tile.inner_data.items(): diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 851dfba6..380c41b1 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -26,11 +26,9 @@ from numpy.typing import NDArray from typing_extensions import assert_never -from bioimageio.core.axis import ( - AxisId, -) -from bioimageio.core.sample import Sample -from bioimageio.core.stat_measures import ( +from .axis import AxisId +from .sample import Sample +from .stat_measures import ( DatasetMean, DatasetMeasure, DatasetMeasureBase, @@ -45,7 +43,7 @@ SampleStd, SampleVar, ) -from bioimageio.core.Tensor import TensorId +from .tensor import Tensor, TensorId try: import crick @@ -70,7 +68,7 @@ class MeanCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): super().__init__() self._n: int = 0 - self._mean: Optional[xr.DataArray] = None + self._mean: Optional[Tensor] = None self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id self._sample_mean = SampleMean(tensor_id=self._tensor_id, axes=self._axes) @@ -79,8 +77,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: return {self._sample_mean: self._compute_impl(sample)} - def _compute_impl(self, sample: Sample) -> xr.DataArray: - tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) + def _compute_impl(self, sample: Sample) -> Tensor: + tensor = sample.data[self._tensor_id].astype("float64", copy=False) return tensor.mean(dim=self._axes) def update(self, sample: Sample) -> None: @@ -92,8 +90,8 @@ def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: self._update_impl(sample.data[self._tensor_id], mean) return {self._sample_mean: mean} - def _update_impl(self, tensor: xr.DataArray, tensor_mean: xr.DataArray): - assert tensor_mean.dtype == np.float64 + def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): + assert tensor_mean.dtype == "float64" # reduced voxel count n_b = int(np.prod(tensor.shape) / np.prod(tensor_mean.shape)) @@ -132,7 +130,7 @@ def compute( ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: tensor = sample.data[self._tensor_id] mean = tensor.mean(dim=self._axes) - c = tensor - mean + c = (tensor - mean).data if self._axes is None: n = tensor.size else: @@ -144,12 +142,16 @@ def compute( assert isinstance(std, xr.DataArray) return { SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, - SampleVar(axes=self._axes, tensor_id=self._tensor_id): var, - SampleStd(axes=self._axes, tensor_id=self._tensor_id): std, + SampleVar(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + var + ), + SampleStd(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + std + ), } def update(self, sample: Sample): - tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) + tensor = sample.data[self._tensor_id].astype("float64", copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == np.float64 # reduced voxel count diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 83775fc9..fa928eae 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,15 +2,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union -import xarray as xr +from .axis import AxisId +from .tensor import PerTensor, Tensor, TensorId -from bioimageio.core.axis import AxisId -from bioimageio.core.sample import Sample -from bioimageio.core.Tensor import TensorId +MeasureValue = Union[float, Tensor] -MeasureValue = Union[float, xr.DataArray] + +# using Sample Protocol really only to avoid circular imports +class SampleLike(Protocol): + @property + def data(self) -> PerTensor[Tensor]: ... @dataclass(frozen=True) @@ -21,7 +24,7 @@ class MeasureBase: @dataclass(frozen=True) class SampleMeasureBase(MeasureBase, ABC): @abstractmethod - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: """compute the measure""" ... @@ -41,7 +44,7 @@ class _Mean: class SampleMean(_Mean, SampleMeasureBase): """The mean value of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.mean(dim=self.axes) @@ -67,7 +70,7 @@ class _Std: class SampleStd(_Std, SampleMeasureBase): """The standard deviation of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.std(dim=self.axes) @@ -93,7 +96,7 @@ class _Var: class SampleVar(_Var, SampleMeasureBase): """The variance of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.var(dim=self.axes) @@ -111,22 +114,22 @@ def __post_init__(self): @dataclass(frozen=True) class _Percentile: - n: float + q: float axes: Optional[Tuple[AxisId, ...]] = None """`axes` to reduce""" def __post_init__(self): - assert self.n >= 0 - assert self.n <= 100 + assert self.q >= 0.0 + assert self.q <= 1.0 @dataclass(frozen=True) class SamplePercentile(_Percentile, SampleMeasureBase): """The `n`th percentile of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] - return tensor.quantile(self.n / 100.0, dim=self.axes) + return tensor.quantile(self.q, dim=self.axes) def __post_init__(self): super().__post_init__() diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 26384f0e..e63380ea 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -3,7 +3,9 @@ import itertools from math import prod from typing import ( + TYPE_CHECKING, Any, + Callable, Dict, Generator, List, @@ -25,31 +27,48 @@ from bioimageio.core.axis import PerAxis from bioimageio.core.common import PadMode, PadWhere -from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model import v0_5 +from ._magic_tensor_ops import MagicTensorOpsMixin from .axis import Axis, AxisId, AxisInfo, AxisLike from .common import ( + CropWhere, DTypeStr, Halo, HaloLike, PadWidth, + PadWidthLike, SliceInfo, TileNumber, TotalNumberOfTiles, ) +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray TensorId = v0_5.TensorId T = TypeVar("T") + PerTensor = Mapping[TensorId, T] -class Tensor: +_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" + + +# TODO: make Tensor a numpy compatible array type, to use e.g. \ +# with `np.testing.assert_array_almost_equal`. +# TODO: complete docstrings +class Tensor(MagicTensorOpsMixin): + """A wrapper around an xr.DataArray for better integration with bioimageio.spec + and improved type annotations.""" + + _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] + def __init__( self, array: NDArray[Any], dims: Union[AxisId, Sequence[AxisId]], - id: TensorId, + id: Optional[TensorId] = None, ) -> None: super().__init__() self._data = xr.DataArray(array, dims=dims, name=id) @@ -63,61 +82,139 @@ def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> N key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} self._data[key] = value._data + def _binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + reflexive: bool = False, + ) -> Self: + data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] + (other._data if isinstance(other, Tensor) else other), + f, + reflexive, + ) + return self.__class__.from_xarray(data) + + def _inplace_binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + ) -> Self: + _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] + ( + other_d + if (other_d := getattr(other, "data")) is not None + and isinstance( + other_d, + xr.DataArray, + ) + else other + ), + f, + ) + return self + + def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: + data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] + f, *args, **kwargs + ) + return self.__class__.from_xarray(data) + @classmethod def from_xarray(cls, data_array: xr.DataArray) -> Self: - if data_array.name is None: - raise ValueError( - "Expected a named `data_array` to use `data_array.name` as tensor id" - ) + """create a `Tensor` from an xarray data array + note for internal use: this factory method is round-trip save + for any `Tensor`'s `data` property (an xarray.DataArray). + """ return cls( array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims), - id=TensorId(data_array.name), + id=None if data_array.name is None else TensorId(data_array.name), ) @classmethod def from_numpy( - cls, array: NDArray[Any], axes: Optional[Sequence[AxisLike]], id: TensorId + cls, + array: NDArray[Any], + *, + dims: Optional[Union[AxisLike, Sequence[AxisLike]]], + id: TensorId, ) -> Tensor: - if axes is None: + """create a `Tensor` from a numpy array + + Args: + array: the nd numpy array + axes: A description of the array's axes, + if None axes are guessed (which might fail and raise a ValueError.) + id: the created tensor's identifier + + Raises: + ValueError: if `axes` is None and axes guessing fails. + """ + + if dims is None: return cls._interprete_array_wo_known_axes(array, id=id) + elif isinstance(dims, (str, Axis, v0_5.AxisBase)): + dims = [dims] + axis_infos = [AxisInfo.create(a) for a in dims] original_shape = tuple(array.shape) - if len(array.shape) > len(axes): + if len(array.shape) > len(dims): # remove singletons for i, s in enumerate(array.shape): if s == 1: array = np.take(array, 0, axis=i) - if len(array.shape) == len(axes): + if len(array.shape) == len(dims): break # add singletons if nececsary - for a in axes: - a = AxisInfo.create(a) - if len(array.shape) >= len(axes): + for a in axis_infos: + + if len(array.shape) >= len(dims): break if a.maybe_singleton: array = array[None] - if len(array.shape) != len(axes): + if len(array.shape) != len(dims): raise ValueError( - f"Array shape {original_shape} does not map to axes {axes}" + f"Array shape {original_shape} does not map to axes {dims}" ) - normalized_axes = normalize_axes(axes) - assert len(normalized_axes) == len(axes) - return Tensor(array, dims=tuple(a.id for a in normalized_axes)) + return Tensor(array, dims=tuple(a.id for a in axis_infos), id=id) @property def data(self): return self._data @property - def dims(self): + def dims(self): # TODO: rename to `axes`? + """Tuple of dimension names associated with this tensor.""" return cast(Tuple[AxisId, ...], self._data.dims) + @property + def shape(self): + """Tuple of tensor dimension lenghts""" + return self._data.shape + + @property + def size(self): + """Number of elements in the tensor. + + Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. + """ + return self._data.size + + def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + """Reduce this Tensor's data by applying sum along some dimension(s).""" + return self.__class__.from_xarray(self._data.sum(dim=dim)) + + @property + def ndim(self): + """Number of tensor dimensions.""" + return self._data.ndim + @property def dtype(self) -> DTypeStr: dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] @@ -126,38 +223,50 @@ def dtype(self) -> DTypeStr: @property def id(self): + """the tensor's identifier""" return self._id @property def sizes(self): + """Ordered, immutable mapping from axis ids to lengths.""" return cast(Mapping[AxisId, int], self.data.sizes) + def astype(self, dtype: DTypeStr, *, copy: bool = False): + """Return tensor cast to `dtype` + + note: if dtype is already satisfied copy if `copy`""" + return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) + + def clip(self, min: Optional[float] = None, max: Optional[float] = None): + """Return a tensor whose values are limited to [min, max]. + At least one of max or min must be given.""" + return self.__class__.from_xarray(self._data.clip(min, max)) + def crop_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], + self, + sizes: PerAxis[int], crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - ): - """crop `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if crop_where in ("before", "center", "after"): - crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: crop_where for a in axes - } + CropWhere, + PerAxis[CropWhere], + ] = "left_and_right", + ) -> Self: + """crop to match `sizes`""" + if isinstance(crop_where, str): + crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} else: crop_axis_where = crop_where - slices: Dict[AxisId, slice] = {} + slices: Dict[AxisId, SliceInfo] = {} - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) + for a, s_is in self.sizes.items(): if a not in sizes or sizes[a] == s_is: pass elif sizes[a] > s_is: - warnings.warn( - f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" + logger.warning( + "Cannot crop axis {} of size {} to larger size {}", + a, + s_is, + sizes[a], ) elif a not in crop_axis_where: raise ValueError( @@ -165,31 +274,37 @@ def crop_to( ) else: crop_this_axis_where = crop_axis_where[a] - if crop_this_axis_where == "before": - slices[a] = slice(s_is - sizes[a], s_is) - elif crop_this_axis_where == "after": - slices[a] = slice(0, sizes[a]) - elif crop_this_axis_where == "center": - slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start) + if crop_this_axis_where == "left": + slices[a] = SliceInfo(s_is - sizes[a], s_is) + elif crop_this_axis_where == "right": + slices[a] = SliceInfo(0, sizes[a]) + elif crop_this_axis_where == "left_and_right": + slices[a] = SliceInfo( + start := (s_is - sizes[a]) // 2, sizes[a] + start + ) else: assert_never(crop_this_axis_where) - return tensor.isel({str(a): s for a, s in slices.items()}) + return self[slices] - def mean(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: + return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) + + def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.mean(dims=dim)) - def std(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.std(dims=dim)) - def var(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.var(dims=dim)) def pad( self, - pad_width: PerAxis[PadWidth], + pad_width: PerAxis[PadWidthLike], mode: PadMode = "symmetric", ) -> Self: + pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} return self.__class__.from_xarray( self._data.pad(pad_width=pad_width, mode=mode) ) @@ -197,7 +312,7 @@ def pad( def pad_to( self, sizes: PerAxis[int], - pad_where: Union[PadWhere, PerAxis[PadWhere]] = "center", + pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", mode: PadMode = "symmetric", ) -> Self: """pad `tensor` to match `sizes`""" @@ -224,37 +339,44 @@ def pad_to( ) else: pad_this_axis_where = pad_axis_where[a] - p = sizes[a] - s_is - if pad_this_axis_where == "before": - pad_width[a] = PadWidth(p, 0) - elif pad_this_axis_where == "after": - pad_width[a] = PadWidth(0, p) - elif pad_this_axis_where == "center": - pad_width[a] = PadWidth(left := p // 2, p - left) + d = sizes[a] - s_is + if pad_this_axis_where == "left": + pad_width[a] = PadWidth(d, 0) + elif pad_this_axis_where == "right": + pad_width[a] = PadWidth(0, d) + elif pad_this_axis_where == "left_and_right": + pad_width[a] = PadWidth(left := d // 2, d - left) else: assert_never(pad_this_axis_where) return self.pad(pad_width, mode) + def quantile( + self, q: float, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None + ) -> Self: + assert q >= 0.0 + assert q <= 1.0 + return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) + def resize_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], + self, + sizes: PerAxis[int], *, pad_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", + PadWhere, + PerAxis[PadWhere], + ] = "left_and_right", crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", + CropWhere, + PerAxis[CropWhere], + ] = "left_and_right", pad_mode: PadMode = "symmetric", ): - """crop and pad `tensor` to match `sizes`""" + """return cropped/padded tensor with `sizes`""" crop_to_sizes: Dict[AxisId, int] = {} pad_to_sizes: Dict[AxisId, int] = {} new_axes = dict(sizes) - for a, s_is in tensor.sizes.items(): + for a, s_is in self.sizes.items(): a = AxisId(str(a)) _ = new_axes.pop(a, None) if a not in sizes or sizes[a] == s_is: @@ -264,14 +386,15 @@ def resize_to( else: pad_to_sizes[a] = sizes[a] + tensor = self if crop_to_sizes: - tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) + tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) if pad_to_sizes: - tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) + tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) if new_axes: - tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) + tensor = tensor.expand_dims(new_axes) return tensor @@ -282,7 +405,7 @@ def tile( pad_mode: PadMode, ) -> Tuple[ TotalNumberOfTiles, - Generator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]], Any, None], + Generator[Tuple[TileNumber, Tensor, PedrAxis[SliceInfo]], Any, None], ]: """tile this tensor into `tile_size` tiles that overlap by `halo`. At the tensor's edge the `halo` is padded with `pad_mode`. @@ -298,8 +421,11 @@ def tile( assert all(a in self.dims for a in tile_size), (self.dims, set(tile_size)) assert all(a in self.dims for a in halo), (self.dims, set(halo)) + # fill in default halo (0) and tile_size (tensor size) + halo = {a: Halo.create(halo.get(a, 0)) for a in self.dims} + tile_size = {a: tile_size.get(a, s) for a, s in self.sizes.items()} + inner_1d_tiles: List[List[SliceInfo]] = [] - halo = {a: Halo.create(h) for a, h in halo.items()} for a, s in self.sizes.items(): stride = tile_size[a] - sum(halo[a]) tiles_1d = [SliceInfo(p, min(s, p + stride)) for p in range(0, s, stride)] @@ -320,14 +446,14 @@ def transpose( Args: axes: the desired tensor axes """ - # expand the missing image axes - current_axes = tuple( - d if isinstance(d, AxisId) else AxisId(d) for d in tensor.dims - ) - missing_axes = tuple(a for a in axes if a not in current_axes) - tensor = tensor.expand_dims(missing_axes) + # expand missing tensor axes + missing_axes = tuple(a for a in axes if a not in self.dims) + array = self._data + if missing_axes: + array = array.expand_dims(missing_axes) + # transpose to the correct axis order - return tensor.transpose(*map(str, axes)) + return self.__class__.from_xarray(array.transpose(*axes)) @classmethod def _interprete_array_wo_known_axes(cls, array: NDArray[Any], id: TensorId): diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py index 03703ce3..d8180af4 100644 --- a/bioimageio/core/tile.py +++ b/bioimageio/core/tile.py @@ -3,7 +3,7 @@ from bioimageio.core.common import TileNumber, TotalNumberOfTiles from .axis import PerAxis -from .common import Halo, LeftRight, PadWidth, SliceInfo +from .common import Halo, OverlapWidth, PadWidth, SliceInfo from .stat_measures import Stat from .tensor import PerTensor, Tensor @@ -36,7 +36,7 @@ class AbstractTile: local_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) """slice to extract the inner tile from the outer tile""" - overlap: PerTensor[PerAxis[LeftRight]] = field(init=False) + overlap: PerTensor[PerAxis[OverlapWidth]] = field(init=False) """overlap 'into a neighboring tile'""" padding: PerTensor[PerAxis[PadWidth]] = field(init=False) @@ -68,7 +68,7 @@ def __post_init__(self): } self.overlap = { t: { - a: LeftRight( + a: OverlapWidth( self.inner_slice[t][a].start - self.outer_slice[t][a].start, self.outer_slice[t][a].stop - self.inner_slice[t][a].stop, ) diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index d88ea113..3fe41c02 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -27,7 +27,7 @@ def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: tensor_ids = [ipt.id for ipt in model.inputs] return [ - Tensor.from_numpy(arr, ax, t) + Tensor.from_numpy(arr, dims=ax, id=t) for arr, ax, t in zip(arrays, core_axes, tensor_ids) ] @@ -52,6 +52,6 @@ def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: tensor_ids = [ipt.id for ipt in model.inputs] return [ - Tensor.from_numpy(arr, ax, t) + Tensor.from_numpy(arr, dims=ax, id=t) for arr, ax, t in zip(arrays, core_axes, tensor_ids) ] diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 00000000..076e0961 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +import xarray as xr +from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] + +from bioimageio.core import AxisId, Tensor, TensorId + + +@pytest.mark.parametrize( + "axes", + ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], +) +def test_transpose_tensor_2d(axes: str): + + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +@pytest.mark.parametrize( + "axes", + ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], +) +def test_transpose_tensor_3d(axes: str): + tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None, id=TensorId("id")) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +def test_crop_and_pad(): + tensor = Tensor.from_xarray( + xr.DataArray(np.random.rand(10, 20), dims=("x", "y"), name="id") + ) + padded = tensor.pad({AxisId("x"): 7, AxisId("y"): (3, 3)}) + cropped = padded.crop_to(tensor.sizes) + assert_equal(tensor, cropped) + + +def test_some_magic_ops(): + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + assert tensor + 2 == 2 + tensor diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py deleted file mode 100644 index 96176f88..00000000 --- a/tests/utils/test_image_helper.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] - -from bioimageio.core.axis import AxisId -from bioimageio.core.io import ( - interprete_array, - transpose_tensor, -) -from bioimageio.core.utils.tiling import crop_to, pad - - -@pytest.mark.parametrize( - "axes", - ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], -) -def test_transpose_tensor_2d(axes: str): - - tensor = interprete_array(np.random.rand(256, 256), None) - transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) - assert transposed.ndim == len(axes) - - -@pytest.mark.parametrize( - "axes", - ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], -) -def test_transpose_tensor_3d(axes: str): - tensor = interprete_array(np.random.rand(64, 64, 64), None) - transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) - assert transposed.ndim == len(axes) - - -def test_crop_and_pad(): - tensor = xr.DataArray(np.random.rand(10, 20), dims=("x", "y")) - sizes = {AxisId(str(k)): v for k, v in tensor.sizes.items()} - padded = pad(tensor, {AxisId("x"): 7, AxisId("y"): (3, 3)}) - cropped = crop_to(padded, sizes) - assert_equal(tensor, cropped) - - -# def test_transform_output_tensor(): -# from bioimageio.core.utils.image_helper import transform_output_tensor - -# tensor = np.random.rand(1, 3, 64, 64, 64) -# tensor_axes = "bczyx" - -# out_ax_list = ["bczyx", "cyx", "xyc", "byxc", "zyx", "xyz"] -# for out_axes in out_ax_list: -# out = transform_output_tensor(tensor, tensor_axes, out_axes) -# assert out.ndim == len(out_axes)