Skip to content

Commit

Permalink
Adding Diagnostics to Ecdf Plots in the spirit of TARP (#261)
Browse files Browse the repository at this point in the history
* ecdf with random points

* single axis

* single axis

* add comments

* clean

* clean

* title fix

* docstring

* posterior 2d

* posterior 2d fix

* posterior 2d fix

* posterior 2d fix

* fix reference

* clean up

* add comment

* add tests

* pass kwargs

* fix title

* make more customizable

* fix conflict

---------

Co-authored-by: Jerry <[email protected]>
  • Loading branch information
arrjon and jerrymhuang authored Dec 2, 2024
1 parent 0537f2a commit 24e70aa
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 39 deletions.
78 changes: 53 additions & 25 deletions bayesflow/diagnostics/plot_posterior_2d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import seaborn as sns
Expand All @@ -8,10 +10,11 @@


def plot_posterior_2d(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
post_samples: np.ndarray,
prior_samples: np.ndarray = None,
prior=None,
param_names: list = None,
variable_names: list = None,
true_params: np.ndarray = None,
height: int = 3,
label_fontsize: int = 14,
legend_fontsize: int = 16,
Expand All @@ -24,15 +27,17 @@ def plot_posterior_2d(
) -> sns.PairGrid:
"""Generates a bivariate pairplot given posterior draws and optional prior or prior draws.
posterior_draws : np.ndarray of shape (n_post_draws, n_params)
post_samples : np.ndarray of shape (n_post_draws, n_params)
The posterior draws obtained for a SINGLE observed data set.
prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
The optional prior object having an input-output signature as given by ayesflow.forward_inference.Prior
prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optonal (default: None)
The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws
prior_samples : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
The optional prior samples obtained from the prior. If both prior and prior_samples are provided, prior_samples
will be used.
param_names : list or None, optional, default: None
prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
The optional prior object having an input-output signature as given by bayesflow.forward_inference.Prior
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
true_params : np.ndarray of shape (n_params,) or None, optional, default: None
The true parameter values to be plotted on the diagonal.
height : float, optional, default: 3
The height of the pairplot
label_fontsize : int, optional, default: 14
Expand All @@ -41,7 +46,7 @@ def plot_posterior_2d(
The font size of the legend text
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
post_color : str, optional, default: '#8f2727'
post_color : str, optional, default: '#132a70'
The color for the posterior histograms and KDEs
priors_color : str, optional, default: gray
The color for the optional prior histograms and KDEs
Expand All @@ -64,7 +69,10 @@ def plot_posterior_2d(
assert (len(post_samples.shape)) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!"

# Plot posterior first
g = plot_samples_2d(post_samples, context="\\theta", param_names=param_names, render=False, height=height, **kwargs)
context = ""
g = plot_samples_2d(
post_samples, context=context, variable_names=variable_names, render=False, height=height, **kwargs
)

# Obtain n_draws and n_params
n_draws, n_params = post_samples.shape
Expand All @@ -73,34 +81,54 @@ def plot_posterior_2d(
if prior is not None and prior_samples is None:
draws = prior(n_draws)
if isinstance(draws, dict):
prior_draws = draws["prior_draws"]
prior_samples = draws["prior_draws"]
else:
prior_draws = draws
prior_samples = draws
elif prior_samples is not None:
# trim to the same number of draws as posterior
prior_samples = prior_samples[:n_draws]

# Attempt to determine parameter names
if param_names is None:
if variable_names is None:
if hasattr(prior, "param_names"):
if prior.param_names is not None:
param_names = prior.param_names
if prior.variable_names is not None:
variable_names = prior.variable_names
else:
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
variable_names = [f"{context} $\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
else:
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
variable_names = [f"{context} $\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
else:
variable_names = [f"{context} {p}" for p in variable_names]

# Add prior, if given
if prior_draws is not None:
prior_draws_df = pd.DataFrame(prior_draws, columns=param_names)
g.data = prior_draws_df
if prior_samples is not None:
prior_samples_df = pd.DataFrame(prior_samples, columns=variable_names)
g.data = prior_samples_df
g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)

# Add true parameters
if true_params is not None:
# Custom function to plot true_params on the diagonal
def plot_true_params(x, **kwargs):
param = x.iloc[0] # Get the single true value for the diagonal
plt.axvline(param, color="black", linestyle="--") # Add vertical line

# only plot on the diagonal a vertical line for the true parameter
g.data = pd.DataFrame(true_params[np.newaxis], columns=variable_names)
g.map_diag(plot_true_params)

# Add legend, if prior also given
if prior_draws is not None or prior is not None:
if prior_samples is not None or prior is not None:
handles = [
Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha),
Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha),
]
g.legend(handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right")
handles_names = ["Posterior", "Prior"]
if true_params is not None:
handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--"))
handles_names.append("True Parameter")
plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right")

n_row, n_col = g.axes.shape

Expand All @@ -115,9 +143,9 @@ def plot_posterior_2d(
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

# Add nice labels
for i, param_name in enumerate(param_names):
for i, param_name in enumerate(variable_names):
g.axes[i, 0].set_ylabel(param_name, fontsize=label_fontsize)
g.axes[len(param_names) - 1, i].set_xlabel(param_name, fontsize=label_fontsize)
g.axes[len(variable_names) - 1, i].set_xlabel(param_name, fontsize=label_fontsize)

# Add grids
for i in range(n_params):
Expand Down
46 changes: 41 additions & 5 deletions bayesflow/diagnostics/plot_sbc_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Sequence
from ..utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
from ..utils.ecdf import simultaneous_ecdf_bands
from ..utils.ecdf.ranks import fractional_ranks, distance_ranks


def plot_sbc_ecdf(
Expand All @@ -13,6 +14,7 @@ def plot_sbc_ecdf(
variable_names: Sequence[str] = None,
difference: bool = False,
stacked: bool = False,
rank_type: str | np.ndarray = "fractional",
figsize: Sequence[float] = None,
label_fontsize: int = 16,
legend_fontsize: int = 14,
Expand All @@ -33,11 +35,20 @@ def plot_sbc_ecdf(
For models with many parameters, use `stacked=True` to obtain an idea
of the overall calibration of a posterior approximator.
To compute ranks based on the Euclidean distance to the origin or a reference, use `rank_type='distance'` (and
pass a reference array, respectively). This can be used to check the joint calibration of the posterior approximator
and might show potential biases in the posterior approximation which are not detected by the fractional ranks (e.g.,
when the prior equals the posterior). This is motivated by [2].
[1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test
for discrete uniformity and its applications in goodness-of-fit evaluation
and multiple sample comparison. Statistics and Computing, 32(2), 1-21.
https://arxiv.org/abs/2103.10522
[2] Lemos, Pablo, et al. "Sampling-based accuracy testing of posterior estimators
for general inference." International Conference on Machine Learning. PMLR, 2023.
https://proceedings.mlr.press/v202/lemos23a.html
Parameters
----------
post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
Expand All @@ -51,6 +62,11 @@ def plot_sbc_ecdf(
If `True`, all ECDFs will be plotted on the same plot.
If `False`, each ECDF will have its own subplot,
similar to the behavior of `plot_sbc_histograms`.
rank_type : str, optional, default: 'fractional'
If `fractional` (default), the ranks are computed as the fraction of posterior samples that are smaller than
the prior. If `distance`, the ranks are computed as the fraction of posterior samples that are closer to
a reference point (default here is the origin). You can pass a reference array in the same shape as the
`prior_samples` array by setting `references` in the ``ranks_kwargs``. This is motivated by [2].
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
Expand Down Expand Up @@ -79,7 +95,9 @@ def plot_sbc_ecdf(
**kwargs : dict, optional, default: {}
Keyword arguments can be passed to control the behavior of
ECDF simultaneous band computation through the ``ecdf_bands_kwargs``
dictionary. See `simultaneous_ecdf_bands` for keyword arguments
dictionary. See `simultaneous_ecdf_bands` for keyword arguments.
Moreover, additional keyword arguments can be passed to control the behavior of
the rank computation through the ``ranks_kwargs`` dictionary.
Returns
-------
Expand All @@ -90,6 +108,8 @@ def plot_sbc_ecdf(
ShapeError
If there is a deviation from the expected shapes of `post_samples`
and `prior_samples`.
ValueError
If an unknown `rank_type` is passed.
"""

# Preprocessing
Expand All @@ -99,8 +119,16 @@ def plot_sbc_ecdf(
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

# Compute fractional ranks (using broadcasting)
ranks = np.mean(plot_data["post_samples"] < plot_data["prior_samples"][:, np.newaxis, :], axis=1)
if rank_type == "fractional":
# Compute fractional ranks
ranks = fractional_ranks(plot_data["post_samples"], plot_data["prior_samples"])
elif rank_type == "distance":
# Compute ranks based on distance to the origin
ranks = distance_ranks(
plot_data["post_samples"], plot_data["prior_samples"], stacked=stacked, **kwargs.pop("ranks_kwargs", {})
)
else:
raise ValueError(f"Unknown rank type: {rank_type}. Use 'fractional' or 'distance'.")

# Plot individual ecdf of parameters
for j in range(ranks.shape[-1]):
Expand All @@ -114,6 +142,8 @@ def plot_sbc_ecdf(

if stacked:
if j == 0:
if not isinstance(plot_data["axes"], np.ndarray):
plot_data["axes"] = np.array([plot_data["axes"]]) # in case of single axis
plot_data["axes"][0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
else:
plot_data["axes"][0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95)
Expand All @@ -132,7 +162,13 @@ def plot_sbc_ecdf(
ylab = "ECDF"

# Add simultaneous bounds
titles = plot_data["variable_names"] if not stacked else ["Stacked ECDFs"]
if not stacked:
titles = plot_data["variable_names"]
elif rank_type in ["distance", "random"]:
titles = ["Joint ECDFs"]
else:
titles = ["Stacked ECDFs"]

for ax, title in zip(plot_data["axes"].flat, titles):
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize)
Expand All @@ -145,7 +181,7 @@ def plot_sbc_ecdf(
plot_data["axes"],
plot_data["num_row"],
plot_data["num_col"],
xlabel="Fractional rank statistic",
xlabel=f"{rank_type.capitalize()} rank statistic",
ylabel=ylab,
label_fontsize=label_fontsize,
)
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
split_tensors,
)
from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
from .ecdf import simultaneous_ecdf_bands
from .ecdf import simultaneous_ecdf_bands, ranks
from .functional import batched_call
from .git import (
issue_url,
Expand Down
1 change: 1 addition & 0 deletions bayesflow/utils/ecdf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .simultaneous_ecdf_bands import simultaneous_ecdf_bands
from .ranks import fractional_ranks, distance_ranks
102 changes: 102 additions & 0 deletions bayesflow/utils/ecdf/ranks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np


def fractional_ranks(post_samples: np.ndarray, prior_samples: np.ndarray) -> np.ndarray:
    """Compute fractional rank statistics of prior draws among posterior draws.

    For every data set and every parameter, the rank is the fraction of
    posterior draws that are strictly smaller than the corresponding prior
    (ground-truth) draw.

    Parameters
    ----------
    post_samples : array-like of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws, one set of draws per data set.
    prior_samples : array-like of shape (n_data_sets, n_params)
        The prior draws that generated the data sets.

    Returns
    -------
    ranks : np.ndarray of shape (n_data_sets, n_params)
        The fractional ranks, each in the interval [0, 1].
    """
    # Coerce array-likes (e.g. nested lists) so the fancy indexing below works.
    post_samples = np.asarray(post_samples)
    prior_samples = np.asarray(prior_samples)

    # Insert a draw axis into the prior so it broadcasts against every posterior draw.
    return np.mean(post_samples < prior_samples[:, np.newaxis, :], axis=1)


def _helper_distance_ranks(
    post_samples: np.ndarray,
    prior_samples: np.ndarray,
    stacked: bool,
    references: np.ndarray,
    distance: callable,
    p_norm: int,
) -> np.ndarray:
    """
    Helper function to compute ranks of true parameter wrt posterior samples
    based on distances (defined on the p_norm) between samples and given references.

    Parameters
    ----------
    post_samples  : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior samples.
    prior_samples : np.ndarray of shape (n_data_sets, n_params)
        The prior samples.
    stacked       : bool
        If True, one joint rank per data set is computed; otherwise one rank per parameter.
    references    : np.ndarray of shape (n_data_sets, n_params)
        The reference points distances are measured to.
    distance      : callable or None
        Custom distance function called as ``distance(samples, reference)``.
        If None, absolute differences (combined via ``p_norm`` when stacked) are used.
    p_norm        : int
        Order of the norm used when ``distance`` is None and ``stacked`` is True.

    Returns
    -------
    ranks : np.ndarray
        Shape (n_data_sets, 1) if ``stacked`` else (n_data_sets, n_params).
    """
    if distance is None:
        # Element-wise absolute deviations from the references.
        dist_post = np.abs(references[:, np.newaxis, :] - post_samples)
        dist_prior = np.abs(references - prior_samples)

        if stacked:
            # Collapse the parameter axis with the p-norm to rank all parameters jointly.
            samples_distances = np.sum(dist_post**p_norm, axis=-1) ** (1 / p_norm)
            theta_distances = np.sum(dist_prior**p_norm, axis=-1) ** (1 / p_norm)
            return np.mean(samples_distances < theta_distances[:, np.newaxis], axis=1)[:, np.newaxis]

        # Marginal ranks: compare per parameter.
        return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)

    num_data_sets, num_params = references.shape

    if stacked:
        # The custom distance receives full parameter vectors (all draws of one data set at once).
        dist_post = np.array([distance(post_samples[i], references[i]) for i in range(num_data_sets)])
        dist_prior = np.array([distance(prior_samples[i], references[i]) for i in range(num_data_sets)])
        return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)[:, np.newaxis]

    # The custom distance receives one parameter at a time.
    # Allocate float buffers explicitly: inheriting the dtype of integer-valued
    # samples would silently truncate fractional distances and corrupt the ranks.
    dist_post = np.zeros(post_samples.shape, dtype=float)
    dist_prior = np.zeros(prior_samples.shape, dtype=float)
    for i in range(num_data_sets):
        for j in range(num_params):
            dist_post[i, :, j] = distance(post_samples[i, :, j], references[i, j])
            dist_prior[i, j] = distance(prior_samples[i, j], references[i, j])

    return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)


def distance_ranks(
    post_samples: np.ndarray,
    prior_samples: np.ndarray,
    stacked: bool,
    references: np.ndarray = None,
    distance: callable = None,
    p_norm: int = 2,
) -> np.ndarray:
    """
    Compute ranks of true parameter wrt posterior samples based on distances between samples and optional references.

    The rank of a data set is the fraction of posterior samples that lie strictly
    closer to the reference than the corresponding prior sample does.

    Parameters
    ----------
    post_samples  : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior samples.
    prior_samples : np.ndarray of shape (n_data_sets, n_params)
        The prior samples.
    stacked       : bool
        If True, compute ranks for all parameters jointly. Otherwise, compute marginal ranks.
    references    : np.ndarray of shape (n_data_sets, n_params) or (n_params,), optional
        The references to compute the ranks. A single reference of shape (n_params,)
        is shared across all data sets. If None, the origin is used.
    distance      : callable, optional
        The distance function to compute the ranks. If None, the distance defined by the p_norm is used. Must be
        a function that takes two arrays (if stacked, it gets the full parameter vectors, if not only the single
        parameters) and returns an array with the distances. This could be based on the log-posterior, for example.
    p_norm        : int, optional
        The norm to compute the distance if no distance is passed. Default is L2-norm.

    Returns
    -------
    ranks : np.ndarray
        Shape (n_data_sets, 1) if ``stacked`` else (n_data_sets, n_params).

    Raises
    ------
    ValueError
        If ``references`` does not match the shape of ``prior_samples``, or if
        ``p_norm`` is not positive when it is needed for the joint p-norm distance.
    """
    # Accept array-likes (e.g. nested lists) as well as arrays.
    post_samples = np.asarray(post_samples)
    prior_samples = np.asarray(prior_samples)

    if references is None:
        # Reference is the origin
        references = np.zeros((prior_samples.shape[0], prior_samples.shape[1]))
    else:
        references = np.asarray(references)

        # A single reference point is shared by all data sets.
        if references.ndim == 1 and references.shape[0] == prior_samples.shape[1]:
            references = np.broadcast_to(references, prior_samples.shape)

        # Validate reference
        if references.ndim != 2:
            raise ValueError("References must be a 2-dimensional array of shape (n_data_sets, n_params).")
        if references.shape[0] != prior_samples.shape[0]:
            raise ValueError("The number of references must match the number of prior samples.")
        if references.shape[1] != prior_samples.shape[1]:
            raise ValueError("The dimension of references must match the dimension of the parameters.")

    # The exponent 1 / p_norm is only evaluated on the joint p-norm path.
    if distance is None and stacked and p_norm <= 0:
        raise ValueError("`p_norm` must be positive.")

    return _helper_distance_ranks(
        post_samples=post_samples,
        prior_samples=prior_samples,
        stacked=stacked,
        references=references,
        distance=distance,
        p_norm=p_norm,
    )
Loading

0 comments on commit 24e70aa

Please sign in to comment.