Skip to content

Commit acbd194

Browse files
committed
move compute measure to Measure classes
1 parent dc783e6 commit acbd194

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,7 @@ def compute_sample_statistics(
158158
def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
159159
ret = {}
160160
for measure in measures:
161-
if isinstance(measure, Mean):
162-
v = tensor.mean(dim=measure.axes)
163-
elif isinstance(measure, Std):
164-
v = tensor.std(dim=measure.axes)
165-
elif isinstance(measure, Percentile):
166-
v = tensor.quantile(measure.n / 100.0, dim=measure.axes)
167-
else:
168-
raise NotImplementedError(measure)
169-
170-
ret[measure] = v
161+
ret[measure] = measure.compute(tensor)
171162

172163
return ret
173164

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
250250
mean, std = self.mean, self.std
251251
elif self.mode == "per_sample":
252252
if axes:
253-
mean, std = tensor.mean(axes), tensor.std(axes)
253+
mean, std = Mean(axes).compute(tensor), Std(axes).compute(tensor)
254254
else:
255255
mean, std = tensor.mean(), tensor.std()
256256
elif self.mode == "per_dataset":

bioimageio/core/statistical_measures.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
from dataclasses import dataclass
22
from typing import Optional, Tuple
33

4+
import xarray as xr
5+
46

57
@dataclass(frozen=True)
68
class Measure:
7-
pass
9+
def compute(self, tensor: xr.DataArray):
10+
raise NotImplementedError(self.__class__.__name__)
811

912

1013
@dataclass(frozen=True)
1114
class Mean(Measure):
1215
axes: Optional[Tuple[str]] = None
1316

17+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
18+
return tensor.mean(dim=self.axes)
19+
1420

1521
@dataclass(frozen=True)
1622
class Std(Measure):
1723
axes: Optional[Tuple[str]] = None
1824

25+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
26+
return tensor.std(dim=self.axes)
27+
1928

2029
@dataclass(frozen=True)
2130
class Percentile(Measure):
@@ -25,3 +34,6 @@ class Percentile(Measure):
2534
def __post_init__(self):
2635
assert self.n >= 0
2736
assert self.n <= 100
37+
38+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
39+
return tensor.quantile(self.n / 100.0, dim=self.axes)

0 commit comments

Comments
 (0)