Skip to content

Commit 1d4f646

Browse files
authored
Implement changes to NNPE adapter for #510 (#514)
* Move docstring to comment * Always cast to _resolve_scale * Fix typo
1 parent 990df1e commit 1d4f646

File tree

1 file changed

+18
-18
lines changed
  • bayesflow/adapters/transforms

1 file changed

+18
-18
lines changed

bayesflow/adapters/transforms/nnpe.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform):
6565
def __init__(
6666
self,
6767
*,
68-
spike_scale: float | np.ndarray | None = None,
69-
slab_scale: float | np.ndarray | None = None,
68+
spike_scale: np.typing.ArrayLike | None = None,
69+
slab_scale: np.typing.ArrayLike | None = None,
7070
per_dimension: bool = True,
7171
seed: int | None = None,
7272
):
@@ -80,14 +80,14 @@ def __init__(
8080
def _resolve_scale(
8181
self,
8282
name: str,
83-
passed: float | np.ndarray | None,
83+
passed: np.typing.ArrayLike | None,
8484
default: float,
8585
data: np.ndarray,
8686
) -> np.ndarray | float:
8787
"""
8888
Determine spike/slab scale:
89-
- If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
90-
- Else: validate & cast passed to the correct shape/type.
89+
- If `passed` is None: Automatic determination via default * std(data) (per‐dimension or global).
90+
- Else: Validate & cast `passed` to the correct shape/type.
9191
9292
Parameters
9393
----------
@@ -103,8 +103,8 @@ def _resolve_scale(
103103
104104
Returns
105105
-------
106-
float or np.ndarray
107-
The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
106+
np.ndarray
107+
The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1]
108108
(if per_dimension=True).
109109
"""
110110

@@ -119,22 +119,22 @@ def _resolve_scale(
119119

120120
# If no scale is passed, determine scale automatically given the dimensionwise or global std
121121
if passed is None:
122-
return default * std
122+
return np.array(default * std)
123123
# If a scale is passed, check if the passed shape matches the expected shape
124124
else:
125-
if self.per_dimension:
125+
try:
126126
arr = np.asarray(passed, dtype=float)
127-
if arr.shape != expected_shape or arr.ndim != 1:
127+
except Exception as e:
128+
raise TypeError(f"{name}: expected values convertible to float, got {type(passed).__name__}") from e
129+
130+
if self.per_dimension:
131+
if arr.ndim != 1 or arr.shape != expected_shape:
128132
raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
129133
return arr
130134
else:
131-
try:
132-
scalar = float(passed)
133-
except TypeError:
134-
raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}")
135-
except ValueError:
136-
raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}")
137-
return scalar
135+
if arr.ndim != 0:
136+
raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}")
137+
return arr
138138

139139
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
140140
"""
@@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
173173
return data + noise
174174

175175
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
176-
"""Non-invertible transform."""
176+
# Non-invertible transform
177177
return data
178178

179179
def get_config(self) -> dict:

0 commit comments

Comments
 (0)