Merge pull request #158 from han-ol/Development
Point estimation update: quantile loss, activation function, simulation-based calibration, polished notebook
stefanradev93 authored Apr 12, 2024
2 parents 6193333 + 6ca3aa1 commit eb4ab84
Showing 4 changed files with 1,054 additions and 142 deletions.
44 changes: 28 additions & 16 deletions bayesflow/amortizers.py
@@ -31,7 +31,7 @@
from bayesflow.default_settings import DEFAULT_KEYS
from bayesflow.exceptions import ConfigurationError, SummaryStatsError
from bayesflow.helper_functions import check_tensor_sanity
from bayesflow.losses import log_loss, mmd_summary_space, norm_diff
from bayesflow.losses import log_loss, mmd_summary_space, norm_diff, quantile_loss
from bayesflow.networks import EvidentialNetwork


@@ -1223,7 +1223,7 @@ class AmortizedPointEstimator(tf.keras.Model):
The American Statistician, 78(1), 1-14.
"""

def __init__(self, inference_net, summary_net=None, norm_ord=2, loss_fun=None):
def __init__(self, inference_net, summary_net=None, norm_ord=2, quantile_levels=None, loss_fun=None):
"""Initializes a composite neural architecture for amortized bayesian model comparison.
Parameters
@@ -1250,7 +1250,7 @@ def __init__(self, inference_net, summary_net=None, norm_ord=2, loss_fun=None):

self.inference_net = inference_net
self.summary_net = summary_net
self.loss_fn = self._determine_loss(loss_fun, norm_ord)
self.loss_fn = self._determine_loss(loss_fun, norm_ord, quantile_levels)

def call(self, input_dict, return_summary=False, **kwargs):
"""Performs a forward pass through the summary and inference network given an input dictionary.
@@ -1342,7 +1342,7 @@ def bootstrap_sample(self, forward_dict, n_bootstrap, simulator, configurator, t
estimates = self.estimate(configurator(forward_dict), **kwargs)

# Prepare for bootstrap simulations based on estimates: Tile estimates
estimates_tiled = np.tile(estimates, (n_bootstrap,1))
estimates_tiled = np.tile(estimates, (n_bootstrap, 1))

# Prepare placeholder dictionary for simulation based on n_bootstrap datasets for every estimate
sim_dict = {
@@ -1352,37 +1352,48 @@
}

# Populate dictionary with batchable context from forward_dict or leave at None
if DEFAULT_KEYS["sim_batchable_context"] in forward_dict.keys() and forward_dict[DEFAULT_KEYS["sim_batchable_context"]] is not None:

if (
DEFAULT_KEYS["sim_batchable_context"] in forward_dict.keys()
and forward_dict[DEFAULT_KEYS["sim_batchable_context"]] is not None
):
sim_batchable_context = tf.constant(forward_dict[DEFAULT_KEYS["sim_batchable_context"]])

# If sim_batchable_context is a 1D tensor, i.e. single element per dataset, add an axis before tiling
if sim_batchable_context.ndim == 1:
sim_batchable_context_tiled = tf.tile(sim_batchable_context[:,None], (n_bootstrap, 1))
sim_batchable_context_tiled = tf.tile(sim_batchable_context[:, None], (n_bootstrap, 1))
else:
sim_batchable_context_tiled = tf.tile(sim_batchable_context, (n_bootstrap, 1))

sim_dict[DEFAULT_KEYS["batchable_context"]] = sim_batchable_context_tiled

# Populate dictionary with non-batchable context from forward_dict or leave at None
if DEFAULT_KEYS["sim_non_batchable_context"] in forward_dict.keys() and forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]] is not None:
if (
DEFAULT_KEYS["sim_non_batchable_context"] in forward_dict.keys()
and forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]] is not None
):
sim_dict[DEFAULT_KEYS["non_batchable_context"]] = forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]]

# Simulate data based on estimates and context, both tiled `n_bootstrap` times
if simulator.is_batched:
sim_dict = simulator._simulate_batched(estimates_tiled, sim_dict, **kwargs.pop("sim_args", {})) # TODO: revisit: *args intentionally omitted relative to simulator.__call__(), because they could cause unintended behavior
sim_dict = simulator._simulate_batched(
estimates_tiled, sim_dict, **kwargs.pop("sim_args", {})
) # TODO: revisit: *args intentionally omitted relative to simulator.__call__(), because they could cause unintended behavior
else:
sim_dict = simulator._simulate_non_batched(estimates_tiled, sim_dict, **kwargs.pop("sim_args", {})) # TODO: test if sim_args are passed successfully
sim_dict = simulator._simulate_non_batched(
estimates_tiled, sim_dict, **kwargs.pop("sim_args", {})
) # TODO: test if sim_args are passed successfully

# To ensure proper configuration prior to estimation, we need to tile prior_batchable_context as well
forward_dict = forward_dict.copy()
if DEFAULT_KEYS["prior_batchable_context"] in forward_dict.keys() and forward_dict[DEFAULT_KEYS["prior_batchable_context"]] is not None:

if (
DEFAULT_KEYS["prior_batchable_context"] in forward_dict.keys()
and forward_dict[DEFAULT_KEYS["prior_batchable_context"]] is not None
):
prior_batchable_context = tf.constant(forward_dict[DEFAULT_KEYS["prior_batchable_context"]])

# If prior_batchable_context is a 1D tensor, i.e. single element per dataset, add an axis before tiling
if prior_batchable_context.ndim == 1:
prior_batchable_context_tiled = tf.tile(prior_batchable_context[:,None], (n_bootstrap, 1))
prior_batchable_context_tiled = tf.tile(prior_batchable_context[:, None], (n_bootstrap, 1))
else:
prior_batchable_context_tiled = tf.tile(prior_batchable_context, (n_bootstrap, 1))

Expand All @@ -1401,8 +1412,7 @@ def bootstrap_sample(self, forward_dict, n_bootstrap, simulator, configurator, t

# Reshape and reorder to (num_data_sets, n_bootstrap, num_params)
bootstrap_estimates = tf.transpose(
tf.reshape(bootstrap_estimates_tiled, (n_bootstrap, estimates.shape[0], estimates.shape[1])),
perm=[1,0,2]
tf.reshape(bootstrap_estimates_tiled, (n_bootstrap, estimates.shape[0], estimates.shape[1])), perm=[1, 0, 2]
)

if to_numpy:
@@ -1453,10 +1463,12 @@ def _compute_summary_condition(self, summary_conditions, direct_conditions, **kw
raise SummaryStatsError("Could not concatenate or determine conditioning inputs...")
return sum_condition, full_cond

def _determine_loss(self, loss_fun, norm_ord):
def _determine_loss(self, loss_fun, norm_ord, quantile_levels):
"""Determines which loss function to use and defaults to the norm_ord=2 as specified by the ``__init__`` method."""

# In case of user-provided loss, override norm order
if loss_fun is not None:
return loss_fun
if quantile_levels is not None:
return partial(quantile_loss, quantile_levels=quantile_levels)
return partial(norm_diff, ord=norm_ord, axis=-1)
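
With this change, passing `quantile_levels` makes the estimator train against the quantile loss instead of the default norm difference, unless a custom `loss_fun` overrides both. A minimal usage sketch (hypothetical, not part of this diff; the small `Sequential` stands in for a real inference network, and its output layer pairs with the new `QuantileActivation` from helper_networks.py below):

import tensorflow as tf
from bayesflow.amortizers import AmortizedPointEstimator
from bayesflow.helper_networks import QuantileActivation

quantile_levels = [0.05, 0.5, 0.95]
n_params = 2

# Toy inference network: outputs n_quantiles * n_params values per dataset, which
# QuantileActivation turns into ordered quantiles of shape (batch, 3, 2)
inference_net = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(len(quantile_levels) * n_params),
        QuantileActivation(quantile_levels),
    ]
)

# Passing quantile_levels switches _determine_loss to quantile_loss
amortizer = AmortizedPointEstimator(inference_net, quantile_levels=quantile_levels)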
53 changes: 53 additions & 0 deletions bayesflow/helper_networks.py
@@ -662,3 +662,56 @@ def call(self, inputs, **kwargs):
if self.residual:
x = x + inputs
return self.act_fn(x)


class QuantileActivation(tf.keras.layers.Layer):
"""Inductive bias activation function for ordered quantiles anchored at a central quantile."""

def __init__(self, quantiles, *args, **kwargs):
"""Creates an activation function that ensures ordered quantiles anchored at a central quantile.
Parameters
----------
quantiles : list
List of quantile levels to be used
*args : list
Additional positional arguments passed to the base class tf.keras.layers.Layer
**kwargs : dict
Additional keyword arguments passed to the base class tf.keras.layers.Layer
"""
super(QuantileActivation, self).__init__(*args, **kwargs)
self.quantiles = quantiles
self.anchor_quantile_index = len(quantiles) // 2

def call(self, inputs):
"""Forward pass through the activation function.
Parameters
----------
inputs : tf.Tensor of shape (batch_size, n_quantiles*num_params)
The tensor containing the pre-activation to be transformed by the activation function.
Returns
-------
outputs : tf.Tensor of shape (batch_size, n_quantiles, num_params)
The transformed output tensor.
"""

# Reshape to separate n_quantiles from n_params
assert inputs.shape[-1] % len(self.quantiles) == 0, "Number of quantiles must evenly divide the last input dimension"
inputs = tf.reshape(inputs, [-1, len(self.quantiles), inputs.shape[-1] // len(self.quantiles)])

# Split into anchor, below, and above
below_inputs = inputs[:, : self.anchor_quantile_index, :]
anchor_input = inputs[:, self.anchor_quantile_index, :][:, None, :]
above_inputs = inputs[:, self.anchor_quantile_index + 1 :, :]

# Apply exponential activation and accumulate the positive increments away from the anchor
# (reversed cumulative sum below the anchor, forward cumulative sum above) to ensure ordered quantiles
below = tf.exp(below_inputs)
above = tf.exp(above_inputs)
below = anchor_input - tf.cumsum(below, axis=1, reverse=True)
above = anchor_input + tf.cumsum(above, axis=1)

# Concatenate along the quantile axis
x = tf.concat([below, anchor_input, above], axis=1)
return x
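
A quick sketch of the intended behavior (assuming the class above, with the exponentiated increments accumulated as shown): raw pre-activations of any sign map to quantiles that are ordered along the quantile axis for every parameter.

import tensorflow as tf

quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
act = QuantileActivation(quantiles)

raw = tf.random.normal((4, len(quantiles) * 3))  # batch of 4 datasets, 3 parameters
out = act(raw)  # shape (4, 5, 3)

# Quantiles are non-decreasing along axis 1 for every dataset and parameter
assert bool(tf.reduce_all(out[:, 1:, :] >= out[:, :-1, :]))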
49 changes: 46 additions & 3 deletions bayesflow/losses.py
@@ -186,9 +186,10 @@ def log_loss(model_indices, preds, evidential=False, label_smoothing=0.01):
return loss


def norm_diff(tensor_a, tensor_b, axis=None, ord='euclidean'):
def norm_diff(tensor_a, tensor_b, axis=None, ord="euclidean"):
"""
Wrapper around tf.norm that computes the norm of the difference between two tensors along the specified axis.
Wrapper around tf.norm that computes the norm of the difference between
two tensors along the specified axis.
Parameters
----------
@@ -197,7 +198,49 @@ def norm_diff(tensor_a, tensor_b, axis=None, ord='euclidean'):
axis : Any or None
Axis along which to compute the norm of the difference. Default is None.
ord : int or str
Order of the norm. Supports 'euclidean' and other norms supported by tf.norm. Default is 'euclidean'.
Order of the norm. Supports 'euclidean' and other norms supported by tf.norm.
Default is 'euclidean'.
"""
difference = tensor_a - tensor_b
return tf.norm(difference, ord=ord, axis=axis)
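
For reference, the default point-estimation loss produced by `_determine_loss` is `partial(norm_diff, ord=norm_ord, axis=-1)`, i.e. a per-dataset norm of the estimation error. A hand-computed toy check:

import tensorflow as tf
from bayesflow.losses import norm_diff

a = tf.constant([[3.0, 0.0]])
b = tf.constant([[0.0, 4.0]])

# Euclidean norm of the difference (3, -4) is 5
print(float(norm_diff(a, b, axis=-1, ord="euclidean")))  # 5.0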


def quantile_loss(y_pred, y_true, quantile_levels=[0.05, 0.95]):
"""Quantile loss as described in [1].
[1] Gneiting, T., & Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction,
and Estimation. Journal of the American Statistical Association, 102(477), 359–378.
Parameters
----------
y_pred : tf.Tensor of shape (n_batch, n_quantiles, n_params)
Neural estimates of each quantile and each parameter.
y_true : tf.Tensor of shape (n_batch, n_params)
True values for each parameter.
quantile_levels : list, optional
Desired quantile levels. Number of quantile levels must match second-to-last dimension of `y_pred`.
Default is [0.05, 0.95].
Returns
-------
loss : tf.Tensor
A single scalar Monte Carlo approximation of the quantile loss, shape ()
"""
n_params = y_true.shape[-1]
tau = tf.constant(quantile_levels, dtype=tf.float32)
n_quantiles = tau.shape[0]

# Sanity checks: network output must match the specified quantile levels and parameters
assert (
y_pred.shape[-2] == n_quantiles
), "Second-to-last dimension of y_pred must match the number of specified quantile levels."

assert (
y_pred.shape[-1] == n_params
), "Last dimension of y_pred must match the number of parameters in y_true."

pointwise_diff = y_pred - y_true[:, None, :] # (n_batch, n_quantiles, n_params)

loss = pointwise_diff * (tf.cast(pointwise_diff > 0, tf.float32) - tau[None, :, None])

return tf.reduce_mean(loss)
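
To make the pinball-loss arithmetic concrete, a hand-computed toy example (assuming the function above with the default quantile levels):

import tensorflow as tf
from bayesflow.losses import quantile_loss

y_true = tf.constant([[0.0]])  # (n_batch=1, n_params=1)
y_pred = tf.constant([[[-1.0], [2.0]]])  # (1, n_quantiles=2, 1)

# tau=0.05, diff=-1: -1 * (0 - 0.05) = 0.05 (underestimation at a low quantile, penalized lightly)
# tau=0.95, diff=+2:  2 * (1 - 0.95) = 0.10 (overestimation at a high quantile, penalized lightly)
loss = quantile_loss(y_pred, y_true)  # mean of {0.05, 0.10}
print(float(loss))  # 0.075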
(Diff of the fourth changed file, the polished notebook, is not rendered.)
