@@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform):
65
65
def __init__ (
66
66
self ,
67
67
* ,
68
- spike_scale : float | np .ndarray | None = None ,
69
- slab_scale : float | np .ndarray | None = None ,
68
+ spike_scale : np .typing . ArrayLike | None = None ,
69
+ slab_scale : np .typing . ArrayLike | None = None ,
70
70
per_dimension : bool = True ,
71
71
seed : int | None = None ,
72
72
):
@@ -80,14 +80,14 @@ def __init__(
80
80
def _resolve_scale (
81
81
self ,
82
82
name : str ,
83
- passed : float | np .ndarray | None ,
83
+ passed : np .typing . ArrayLike | None ,
84
84
default : float ,
85
85
data : np .ndarray ,
86
86
) -> np .ndarray | float :
87
87
"""
88
88
Determine spike/slab scale:
89
- - If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
90
- - Else: validate & cast passed to the correct shape/type.
89
+ - If ` passed` is None: Automatic determination via default * std(data) (per‐dimension or global).
90
+ - Else: Validate & cast ` passed` to the correct shape/type.
91
91
92
92
Parameters
93
93
----------
@@ -103,8 +103,8 @@ def _resolve_scale(
103
103
104
104
Returns
105
105
-------
106
- float or np.ndarray
107
- The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
106
+ np.ndarray
107
+ The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1]
108
108
(if per_dimension=True).
109
109
"""
110
110
@@ -119,22 +119,22 @@ def _resolve_scale(
119
119
120
120
# If no scale is passed, determine scale automatically given the dimensionwise or global std
121
121
if passed is None :
122
- return default * std
122
+ return np . array ( default * std )
123
123
# If a scale is passed, check if the passed shape matches the expected shape
124
124
else :
125
- if self . per_dimension :
125
+ try :
126
126
arr = np .asarray (passed , dtype = float )
127
- if arr .shape != expected_shape or arr .ndim != 1 :
127
+ except Exception as e :
128
+ raise TypeError (f"{ name } : expected values convertible to float, got { type (passed ).__name__ } " ) from e
129
+
130
+ if self .per_dimension :
131
+ if arr .ndim != 1 or arr .shape != expected_shape :
128
132
raise ValueError (f"{ name } : expected array of shape { expected_shape } , got { arr .shape } " )
129
133
return arr
130
134
else :
131
- try :
132
- scalar = float (passed )
133
- except TypeError :
134
- raise TypeError (f"{ name } : expected a scalar convertible to float, got type { type (passed ).__name__ } " )
135
- except ValueError :
136
- raise ValueError (f"{ name } : expected a scalar convertible to float, got value { passed !r} " )
137
- return scalar
135
+ if arr .ndim != 0 :
136
+ raise ValueError (f"{ name } : expected scalar, got array of shape { arr .shape } " )
137
+ return arr
138
138
139
139
def forward (self , data : np .ndarray , stage : str = "inference" , ** kwargs ) -> np .ndarray :
140
140
"""
@@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
173
173
return data + noise
174
174
175
175
def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
176
- """ Non-invertible transform."""
176
+ # Non-invertible transform
177
177
return data
178
178
179
179
def get_config (self ) -> dict :
0 commit comments