Skip to content

Commit

Permalink
Merge branch 'dev' into feat-bypass-dispatch-instance
Browse files Browse the repository at this point in the history
  • Loading branch information
vpratz committed Dec 13, 2024
2 parents f9355ad + fc86d4d commit 603f3e9
Show file tree
Hide file tree
Showing 35 changed files with 1,435 additions and 435 deletions.
2 changes: 2 additions & 0 deletions bayesflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
distributions,
networks,
simulators,
workflows,
utils,
)

from .workflows import BasicWorkflow
from .approximators import ContinuousApproximator
from .adapters import Adapter
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
Expand Down
9 changes: 9 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .transforms import (
AsSet,
AsTimeSeries,
Broadcast,
Concatenate,
Constrain,
Expand Down Expand Up @@ -112,6 +113,14 @@ def as_set(self, keys: str | Sequence[str]):
self.transforms.append(transform)
return self

def as_time_series(self, keys: str | Sequence[str]):
    """Append an `AsTimeSeries` transform for each of the given keys.

    Parameters
    ----------
    keys : str or Sequence[str]
        A single key or several keys. A lone string is treated as a
        one-element list.

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    # A bare string would otherwise be iterated character by character.
    names = [keys] if isinstance(keys, str) else keys

    self.transforms.append(MapTransform({name: AsTimeSeries() for name in names}))
    return self

def broadcast(
self, keys: str | Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1
):
Expand Down
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .as_set import AsSet
from .as_time_series import AsTimeSeries
from .broadcast import Broadcast
from .concatenate import Concatenate
from .constrain import Constrain
Expand Down
6 changes: 6 additions & 0 deletions bayesflow/adapters/transforms/as_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ class AsSet(ElementwiseTransform):
This is useful, for example, in a linear regression context where we can index
the observations in arbitrary order and always get the same regression line.
Currently, all this transform does is to ensure that the variable
arrays are at least 3D. The 2nd dimension is treated as the
set dimension and the 3rd dimension as the data dimension.
In the future, the transform will have more advanced behavior
to better ensure the correct treatment of sets.
Usage:
adapter = (
Expand Down
32 changes: 32 additions & 0 deletions bayesflow/adapters/transforms/as_time_series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np

from .elementwise_transform import ElementwiseTransform


class AsTimeSeries(ElementwiseTransform):
    """
    The `.as_time_series` transform can be used to indicate that
    variables shall be treated as time series.

    Currently, all this transformation does is to ensure that the variable
    arrays are at least 3D. The 2nd dimension is treated as the
    time series dimension and the 3rd dimension as the data dimension.
    In the future, the transform will have more advanced behavior
    to better ensure the correct treatment of time series data.

    Usage:

    adapter = (
        bf.Adapter()
        .as_time_series(["x", "y"])
    )
    """

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return `data` viewed with at least three dimensions (via `np.atleast_3d`)."""
        return np.atleast_3d(data)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Drop the third axis again if it has length one; otherwise return `data` unchanged."""
        has_singleton_data_axis = data.shape[2] == 1
        return np.squeeze(data, axis=2) if has_singleton_data_axis else data
5 changes: 5 additions & 0 deletions bayesflow/adapters/transforms/filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool:

@serializable(package="bayesflow.adapters")
class FilterTransform(Transform):
"""
Implements a transform that applies a different transform on a subset of the data. Used by other transforms and
base adapter class.
"""

def __init__(
self,
*,
Expand Down
18 changes: 18 additions & 0 deletions bayesflow/adapters/transforms/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@

@serializable(package="bayesflow.adapters")
class Rename(Transform):
"""
Transform to rename keys in data dictionary. Useful to rename variables to match those required by
approximator. This transform can only rename one variable at a time.
Parameters:
- from_key: str of variable name that should be renamed
- to_key: str representing new name
Example:
adapter = (
bf.adapters.Adapter()
# rename the variables to match the required approximator inputs
.rename("theta", "inference_variables")
.rename("x", "inference_conditions")
)
"""

def __init__(self, from_key: str, to_key: str):
super().__init__()
self.from_key = from_key
Expand Down
25 changes: 14 additions & 11 deletions bayesflow/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from .plots import calibration_ecdf
from .plots import calibration_histogram
from .plots import loss
from .plots import mc_calibration
from .plots import mc_confusion_matrix
from .plots import mmd_hypothesis_test
from .plots import pairs_posterior
from .plots import pairs_prior
from .plots import pairs_samples
from .plots import recovery
from .plots import z_score_contraction
from .metrics import root_mean_squared_error, calibration_error, posterior_contraction

from .plots import (
calibration_ecdf,
calibration_histogram,
loss,
mc_calibration,
mc_confusion_matrix,
mmd_hypothesis_test,
pairs_posterior,
pairs_samples,
recovery,
z_score_contraction,
)
3 changes: 3 additions & 0 deletions bayesflow/diagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .calibration_error import calibration_error
from .posterior_contraction import posterior_contraction
from .root_mean_squared_error import root_mean_squared_error
82 changes: 82 additions & 0 deletions bayesflow/diagnostics/metrics/calibration_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def calibration_error(
    targets: Mapping[str, np.ndarray] | np.ndarray,
    references: Mapping[str, np.ndarray] | np.ndarray,
    resolution: int = 20,
    aggregation: Callable | None = np.median,
    min_quantile: float = 0.005,
    max_quantile: float = 0.995,
    variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
    """Computes an aggregate score for the marginal calibration error over an ensemble of approximate
    posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
    between an alpha-CI and the relative number of inliers from ``references`` over multiple alphas in
    (0, 1).

    Parameters
    ----------
    targets : np.ndarray of shape (num_datasets, num_draws, num_variables)
        The random draws from the approximate posteriors over ``num_datasets``
    references : np.ndarray of shape (num_datasets, num_variables)
        The corresponding ground-truth values sampled from the prior
    resolution : int, optional, default: 20
        The number of credibility intervals (CIs) to consider
    aggregation : callable or None, optional, default: np.median
        The function used to aggregate the marginal calibration errors. It is called as
        ``aggregation(errors, axis=0)``.
        If ``None`` provided, the per-alpha calibration errors will be returned.
    min_quantile : float in (0, 1), optional, default: 0.005
        The minimum posterior quantile to consider.
    max_quantile : float in (0, 1), optional, default: 0.995
        The maximum posterior quantile to consider.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to select from the available variables.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The aggregated calibration error per variable, or the per-alpha errors
            if ``aggregation`` is ``None``.
        - "metric_name" : str
            The name of the metric ("Calibration Error").
        - "variable_names" : str
            The (inferred) variable names.
    """

    samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

    # Define alpha values and the corresponding quantile bounds of a central alpha-CI
    alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
    regions = 1 - alphas
    lowers = regions / 2
    uppers = 1 - lowers

    # Compute quantiles for each alpha, for each dataset and parameter
    # Shape: (2, resolution, num_datasets, num_params)
    quantiles = np.quantile(samples["targets"], [lowers, uppers], axis=1)
    lower_bounds, upper_bounds = quantiles[0], quantiles[1]

    # A reference is an inlier for a given alpha if it lies within [lower, upper]
    expanded_references = samples["references"][None, ...]
    inlier_id = np.logical_and(lower_bounds <= expanded_references, upper_bounds >= expanded_references)

    # Relative number of inliers for each alpha (mean over the dataset axis)
    alpha_pred = np.mean(inlier_id, axis=1)

    # Absolute error between the observed inlier rate and the nominal alpha
    absolute_errors = np.abs(alpha_pred - alphas[:, None])

    # Aggregate errors across alpha; ``None`` returns the raw per-alpha errors as documented
    if aggregation is None:
        error = absolute_errors
    else:
        error = aggregation(absolute_errors, axis=0)

    # Report the names resolved by ``dicts_to_arrays`` (as the sibling metrics do),
    # not the raw argument, which may be ``None``.
    return {"values": error, "metric_name": "Calibration Error", "variable_names": samples["variable_names"]}
52 changes: 52 additions & 0 deletions bayesflow/diagnostics/metrics/posterior_contraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def posterior_contraction(
    targets: Mapping[str, np.ndarray] | np.ndarray,
    references: Mapping[str, np.ndarray] | np.ndarray,
    aggregation: Callable = np.median,
    variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
    """Computes the posterior contraction (PC) from prior to posterior for the given samples.

    Parameters
    ----------
    targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
        Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
        for each data set from `num_datasets`.
    references : np.ndarray of shape (num_datasets, num_variables)
        Prior samples, comprising `num_datasets` ground truths.
    aggregation : callable, optional (default = np.median)
        Function to aggregate the per-dataset PC values; it is called with ``axis=0``.
        Typically `np.mean` or `np.median`.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to select from the available variables.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The aggregated posterior contraction per variable
        - "metric_name" : str
            The name of the metric ("Posterior Contraction").
        - "variable_names" : str
            The (inferred) variable names.

    Notes
    -----
    Posterior contraction measures the reduction in uncertainty from the prior to the posterior.
    Values close to 1 indicate strong contraction (high reduction in uncertainty), while values close to 0
    indicate low contraction.
    """

    arrays = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)

    # Posterior variance over the draws of each dataset; prior variance over the ground truths.
    posterior_variance = arrays["targets"].var(axis=1, ddof=1)
    prior_variance = arrays["references"].var(axis=0, keepdims=True, ddof=1)

    # PC = 1 - Var(posterior) / Var(prior), one value per dataset and variable.
    per_dataset_contraction = 1 - (posterior_variance / prior_variance)

    return {
        "values": aggregation(per_dataset_contraction, axis=0),
        "metric_name": "Posterior Contraction",
        "variable_names": arrays["variable_names"],
    }
59 changes: 59 additions & 0 deletions bayesflow/diagnostics/metrics/root_mean_squared_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Sequence, Any, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays


def root_mean_squared_error(
    targets: Mapping[str, np.ndarray] | np.ndarray,
    references: Mapping[str, np.ndarray] | np.ndarray,
    normalize: bool = True,
    aggregation: Callable = np.median,
    variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
    """Computes the (Normalized) Root Mean Squared Error (RMSE/NRMSE) for the given posterior and prior samples.

    Parameters
    ----------
    targets : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
        Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
        for each data set from `num_datasets`.
    references : np.ndarray of shape (num_datasets, num_variables)
        Prior samples, comprising `num_datasets` ground truths.
    normalize : bool, optional (default = True)
        Whether to normalize the RMSE using the range of the prior samples.
    aggregation : callable, optional (default = np.median)
        Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to select from the available variables.

    Notes
    -----
    Aggregation is performed after computing the RMSE for each posterior draw, instead of first aggregating
    the posterior draws and then computing the RMSE between aggregates and ground truths.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : np.ndarray
            The aggregated (N)RMSE for each variable.
        - "metric_name" : str
            The name of the metric ("RMSE" or "NRMSE").
        - "variable_names" : str
            The (inferred) variable names.
    """

    arrays = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
    draws = arrays["targets"]
    truths = arrays["references"]

    # Broadcast the ground truths over the draw axis, then average the squared error over datasets.
    squared_error = (draws - truths[:, None, :]) ** 2
    rmse = np.sqrt(np.mean(squared_error, axis=0))

    metric_name = "RMSE"
    if normalize:
        # Scale by the observed range of the ground truths for each variable.
        prior_range = truths.max(axis=0) - truths.min(axis=0)
        rmse /= prior_range[None, :]
        metric_name = "NRMSE"

    return {
        "values": aggregation(rmse, axis=0),
        "metric_name": metric_name,
        "variable_names": arrays["variable_names"],
    }
1 change: 0 additions & 1 deletion bayesflow/diagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .mc_confusion_matrix import mc_confusion_matrix
from .mmd_hypothesis_test import mmd_hypothesis_test
from .pairs_posterior import pairs_posterior
from .pairs_prior import pairs_prior
from .pairs_samples import pairs_samples
from .recovery import recovery
from .z_score_contraction import z_score_contraction
Loading

0 comments on commit 603f3e9

Please sign in to comment.