12
12
from typing_extensions import Literal , get_args # type: ignore
13
13
14
14
15
+ def _get_fixed (
16
+ fixed : Union [float , Sequence [float ]], tensor : xr .DataArray , axes : Optional [Sequence [str ]]
17
+ ) -> Union [float , xr .DataArray ]:
18
+ if axes is None :
19
+ return fixed
20
+
21
+ fixed_shape = tuple (s for d , s in tensor .sizes .items () if d not in axes )
22
+ fixed_dims = tuple (d for d in tensor .dims if d not in axes )
23
+ fixed = np .array (fixed ).reshape (fixed_shape )
24
+ return xr .DataArray (fixed , dims = fixed_dims )
25
+
26
+
15
27
@dataclass
16
28
class Processing :
17
29
"""base class for all Pre- and Postprocessing transformations"""
@@ -226,8 +238,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
226
238
@dataclass
227
239
class ZeroMeanUnitVariance (Processing ):
228
240
mode : Literal ["fixed" , "per_sample" , "per_dataset" ] = "per_sample"
229
- mean : Optional [float ] = None
230
- std : Optional [float ] = None
241
+ mean : Optional [Union [ float , Sequence [ float ]] ] = None
242
+ std : Optional [Union [ float , Sequence [ float ]] ] = None
231
243
axes : Optional [Sequence [str ]] = None
232
244
eps : float = 1.0e-6
233
245
@@ -247,12 +259,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
247
259
axes = None if self .axes is None else tuple (self .axes )
248
260
if self .mode == "fixed" :
249
261
assert self .mean is not None and self .std is not None
250
- mean , std = self .mean , self .std
262
+ mean = _get_fixed (self .mean , tensor , axes )
263
+ std = _get_fixed (self .std , tensor , axes )
251
264
elif self .mode == "per_sample" :
252
- if axes :
253
- mean , std = Mean (axes ).compute (tensor ), Std (axes ).compute (tensor )
254
- else :
255
- mean , std = tensor .mean (), tensor .std ()
265
+ mean = Mean (axes ).compute (tensor )
266
+ std = Std (axes ).compute (tensor )
256
267
elif self .mode == "per_dataset" :
257
268
mean = self .get_computed_dataset_statistics (self .tensor_name , Mean (axes ))
258
269
std = self .get_computed_dataset_statistics (self .tensor_name , Std (axes ))
0 commit comments