Skip to content

Commit d5c26e8

Browse files
committed
fix fixed
1 parent db6cdd0 commit d5c26e8

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
from typing_extensions import Literal, get_args # type: ignore
1313

1414

15+
def _get_fixed(
16+
fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: Optional[Sequence[str]]
17+
) -> Union[float, xr.DataArray]:
18+
if axes is None:
19+
return fixed
20+
21+
fixed_shape = tuple(s for d, s in tensor.sizes.items() if d not in axes)
22+
fixed_dims = tuple(d for d in tensor.dims if d not in axes)
23+
fixed = np.array(fixed).reshape(fixed_shape)
24+
return xr.DataArray(fixed, dims=fixed_dims)
25+
26+
1527
@dataclass
1628
class Processing:
1729
"""base class for all Pre- and Postprocessing transformations"""
@@ -226,8 +238,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
226238
@dataclass
227239
class ZeroMeanUnitVariance(Processing):
228240
mode: Literal["fixed", "per_sample", "per_dataset"] = "per_sample"
229-
mean: Optional[float] = None
230-
std: Optional[float] = None
241+
mean: Optional[Union[float, Sequence[float]]] = None
242+
std: Optional[Union[float, Sequence[float]]] = None
231243
axes: Optional[Sequence[str]] = None
232244
eps: float = 1.0e-6
233245

@@ -247,12 +259,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
247259
axes = None if self.axes is None else tuple(self.axes)
248260
if self.mode == "fixed":
249261
assert self.mean is not None and self.std is not None
250-
mean, std = self.mean, self.std
262+
mean = _get_fixed(self.mean, tensor, axes)
263+
std = _get_fixed(self.std, tensor, axes)
251264
elif self.mode == "per_sample":
252-
if axes:
253-
mean, std = Mean(axes).compute(tensor), Std(axes).compute(tensor)
254-
else:
255-
mean, std = tensor.mean(), tensor.std()
265+
mean = Mean(axes).compute(tensor)
266+
std = Std(axes).compute(tensor)
256267
elif self.mode == "per_dataset":
257268
mean = self.get_computed_dataset_statistics(self.tensor_name, Mean(axes))
258269
std = self.get_computed_dataset_statistics(self.tensor_name, Std(axes))

0 commit comments

Comments
 (0)