Skip to content

Implement changes to NNPE adapter for #510 #514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions bayesflow/adapters/transforms/nnpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
def __init__(
self,
*,
spike_scale: float | np.ndarray | None = None,
slab_scale: float | np.ndarray | None = None,
spike_scale: np.typing.ArrayLike | None = None,
slab_scale: np.typing.ArrayLike | None = None,
per_dimension: bool = True,
seed: int | None = None,
):
Expand All @@ -80,14 +80,14 @@
def _resolve_scale(
self,
name: str,
passed: float | np.ndarray | None,
passed: np.typing.ArrayLike | None,
default: float,
data: np.ndarray,
) -> np.ndarray | float:
"""
Determine spike/slab scale:
- If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
- Else: validate & cast passed to the correct shape/type.
- If `passed` is None: Automatic determination via default * std(data) (per‐dimension or global).
- Else: Validate & cast `passed` to the correct shape/type.

Parameters
----------
Expand All @@ -103,8 +103,8 @@

Returns
-------
float or np.ndarray
The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
np.ndarray
The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1]
(if per_dimension=True).
"""

Expand All @@ -119,22 +119,22 @@

# If no scale is passed, determine scale automatically given the dimensionwise or global std
if passed is None:
return default * std
return np.array(default * std)
# If a scale is passed, check if the passed shape matches the expected shape
else:
if self.per_dimension:
try:
arr = np.asarray(passed, dtype=float)
if arr.shape != expected_shape or arr.ndim != 1:
except Exception as e:
raise TypeError(f"{name}: expected values convertible to float, got {type(passed).__name__}") from e

Check warning on line 128 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L127-L128

Added lines #L127 - L128 were not covered by tests

if self.per_dimension:
if arr.ndim != 1 or arr.shape != expected_shape:
raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
return arr
else:
try:
scalar = float(passed)
except TypeError:
raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}")
except ValueError:
raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}")
return scalar
if arr.ndim != 0:
raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}")

Check warning on line 136 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L136

Added line #L136 was not covered by tests
return arr

def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
"""
Expand Down Expand Up @@ -173,7 +173,7 @@
return data + noise

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Non-invertible transform."""
# Non-invertible transform
return data

def get_config(self) -> dict:
Expand Down