diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py
index d0d8d389b..2d752fa51 100644
--- a/bayesflow/__init__.py
+++ b/bayesflow/__init__.py
@@ -12,7 +12,7 @@
 )
 from .workflows import BasicWorkflow
-from .approximators import ContinuousApproximator
+from .approximators import ContinuousApproximator, ContinuousPointApproximator
 from .adapters import Adapter
 from .datasets import OfflineDataset, OnlineDataset, DiskDataset
 from .simulators import make_simulator
diff --git a/bayesflow/approximators/__init__.py b/bayesflow/approximators/__init__.py
index b2bd76f6b..1d98bd819 100644
--- a/bayesflow/approximators/__init__.py
+++ b/bayesflow/approximators/__init__.py
@@ -1,3 +1,4 @@
 from .approximator import Approximator
 from .continuous_approximator import ContinuousApproximator
+from .continuous_point_approximator import ContinuousPointApproximator
 from .model_comparison_approximator import ModelComparisonApproximator
diff --git a/bayesflow/approximators/continuous_point_approximator.py b/bayesflow/approximators/continuous_point_approximator.py
new file mode 100644
index 000000000..71af53402
--- /dev/null
+++ b/bayesflow/approximators/continuous_point_approximator.py
@@ -0,0 +1,170 @@
+from collections.abc import Sequence
+
+import keras
+import numpy as np
+from keras.saving import (
+    deserialize_keras_object as deserialize,
+    register_keras_serializable as serializable,
+    serialize_keras_object as serialize,
+)
+
+from bayesflow.adapters import Adapter
+from bayesflow.networks import PointInferenceNetwork, SummaryNetwork
+from bayesflow.types import Tensor
+from bayesflow.utils import logging, split_arrays
+from .approximator import Approximator
+
+
+@serializable(package="bayesflow.approximators")
+class ContinuousPointApproximator(Approximator):
+    """
+    Defines a workflow for performing fast posterior or likelihood inference.
+    The distribution is approximated by point estimates computed with a
+    feed-forward network and an optional summary network.
+    """
+
+    def __init__(
+        self,
+        *,
+        adapter: Adapter,
+        inference_network: PointInferenceNetwork,
+        summary_network: SummaryNetwork = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.adapter = adapter
+        self.inference_network = inference_network
+        self.summary_network = summary_network
+
+    @classmethod
+    def build_adapter(
+        cls,
+        inference_variables: Sequence[str],
+        inference_conditions: Sequence[str] = None,
+        summary_variables: Sequence[str] = None,
+    ) -> Adapter:
+        adapter = Adapter.create_default(inference_variables)
+
+        if inference_conditions is not None:
+            adapter = adapter.concatenate(inference_conditions, into="inference_conditions")
+
+        if summary_variables is not None:
+            adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")
+
+        adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
+
+        return adapter
+
+    def compile(
+        self,
+        *args,
+        inference_metrics: Sequence[keras.Metric] = None,
+        summary_metrics: Sequence[keras.Metric] = None,
+        **kwargs,
+    ):
+        if inference_metrics:
+            self.inference_network._metrics = inference_metrics
+
+        if summary_metrics:
+            if self.summary_network is None:
+                logging.warning("Ignoring summary metrics because there is no summary network.")
+            else:
+                self.summary_network._metrics = summary_metrics
+
+        return super().compile(*args, **kwargs)
+
+    def compute_metrics(
+        self,
+        inference_variables: Tensor,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+        stage: str = "training",
+    ) -> dict[str, Tensor]:
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot compute summary metrics without a summary network.")
+
+            summary_metrics = {}
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+            summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
+            summary_outputs = summary_metrics.pop("outputs")
+
+            # append summary outputs to inference conditions
+            if inference_conditions is None:
+                inference_conditions = summary_outputs
+            else:
+                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
+
+        inference_metrics = self.inference_network.compute_metrics(
+            inference_variables, conditions=inference_conditions, stage=stage
+        )
+
+        loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
+
+        inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
+        summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
+
+        metrics = {"loss": loss} | inference_metrics | summary_metrics
+
+        return metrics
+
+    def fit(self, *args, **kwargs):
+        return super().fit(*args, **kwargs, adapter=self.adapter)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects)
+        config["inference_network"] = deserialize(config["inference_network"], custom_objects=custom_objects)
+        config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects)
+
+        return super().from_config(config, custom_objects=custom_objects)
+
+    def get_config(self):
+        base_config = super().get_config()
+        config = {
+            "adapter": serialize(self.adapter),
+            "inference_network": serialize(self.inference_network),
+            "summary_network": serialize(self.summary_network),
+        }
+
+        return base_config | config
+
+    def estimate(
+        self,
+        *,
+        conditions: dict[str, np.ndarray],
+        split: bool = False,
+        **kwargs,
+    ) -> dict[str, np.ndarray]:
+        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
+        conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
+        conditions = {"inference_variables": self._estimate(**conditions)}
+        conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
+        conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
+
+        if split:
+            conditions = split_arrays(conditions, axis=-1)
+
+        return conditions
+
+    def _estimate(
+        self,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+    ) -> Tensor:
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+            summary_outputs = self.summary_network(summary_variables)
+
+            # append summary outputs to inference conditions
+            if inference_conditions is None:
+                inference_conditions = summary_outputs
+            else:
+                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
+
+        return self.inference_network.estimate(conditions=inference_conditions)
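Reviewer note: a minimal end-to-end usage sketch of the new approximator, for orientation only (not part of this diff). The toy data and variable names are hypothetical, and the OfflineDataset/fit signatures are assumed to mirror the existing ContinuousApproximator workflow:

import numpy as np
import bayesflow as bf

# hypothetical toy problem: theta ~ N(0, 1), x ~ N(theta, 1)
rng = np.random.default_rng(1)
theta = rng.normal(size=(512, 1))
x = rng.normal(loc=theta, scale=1.0, size=(512, 1))

adapter = bf.ContinuousPointApproximator.build_adapter(
    inference_variables=["theta"],
    inference_conditions=["x"],
)

approximator = bf.ContinuousPointApproximator(
    adapter=adapter,
    inference_network=bf.networks.QuantileRegressor(quantile_levels=[0.05, 0.5, 0.95]),
)
approximator.compile(optimizer="adam")

# assumed to follow the usual offline-training path
dataset = bf.OfflineDataset(data={"theta": theta, "x": x}, batch_size=64, adapter=adapter)
approximator.fit(dataset=dataset, epochs=10)

# point estimates per quantile level, mapped back to the original variable names
estimates = approximator.estimate(conditions={"x": x[:5]})
print(estimates["theta"].shape)  # expected: (5, 3, 1)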
+ """ + + def __init__( + self, + *, + adapter: Adapter, + inference_network: PointInferenceNetwork, + summary_network: SummaryNetwork = None, + **kwargs, + ): + super().__init__(**kwargs) + self.adapter = adapter + self.inference_network = inference_network + self.summary_network = summary_network + + @classmethod + def build_adapter( + cls, + inference_variables: Sequence[str], + inference_conditions: Sequence[str] = None, + summary_variables: Sequence[str] = None, + ) -> Adapter: + adapter = Adapter.create_default(inference_variables) + + if inference_conditions is not None: + adapter = adapter.concatenate(inference_conditions, into="inference_conditions") + + if summary_variables is not None: + adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables") + + adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize() + + return adapter + + def compile( + self, + *args, + inference_metrics: Sequence[keras.Metric] = None, + summary_metrics: Sequence[keras.Metric] = None, + **kwargs, + ): + if inference_metrics: + self.inference_network._metrics = inference_metrics + + if summary_metrics: + if self.summary_network is None: + logging.warning("Ignoring summary metrics because there is no summary network.") + else: + self.summary_network._metrics = summary_metrics + + return super().compile(*args, **kwargs) + + def compute_metrics( + self, + inference_variables: Tensor, + inference_conditions: Tensor = None, + summary_variables: Tensor = None, + stage: str = "training", + ) -> dict[str, Tensor]: + if self.summary_network is None: + if summary_variables is not None: + raise ValueError("Cannot compute summary metrics without a summary network.") + + summary_metrics = {} + else: + if summary_variables is None: + raise ValueError("Summary variables are required when a summary network is present.") + + summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage) + summary_outputs = summary_metrics.pop("outputs") + + # append summary outputs to inference conditions + if inference_conditions is None: + inference_conditions = summary_outputs + else: + inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1) + + inference_metrics = self.inference_network.compute_metrics( + inference_variables, conditions=inference_conditions, stage=stage + ) + + loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(())) + + inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()} + summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} + + metrics = {"loss": loss} | inference_metrics | summary_metrics + + return metrics + + def fit(self, *args, **kwargs): + return super().fit(*args, **kwargs, adapter=self.adapter) + + @classmethod + def from_config(cls, config, custom_objects=None): + config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects) + config["inference_network"] = deserialize(config["inference_network"], custom_objects=custom_objects) + config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects) + + return super().from_config(config, custom_objects=custom_objects) + + def get_config(self): + base_config = super().get_config() + config = { + "adapter": serialize(self.adapter), + "inference_network": serialize(self.inference_network), + "summary_network": 
diff --git a/bayesflow/networks/regressors/__init__.py b/bayesflow/networks/regressors/__init__.py
new file mode 100644
index 000000000..745a5b87d
--- /dev/null
+++ b/bayesflow/networks/regressors/__init__.py
@@ -0,0 +1 @@
+from .quantile_regressor import QuantileRegressor
diff --git a/bayesflow/networks/regressors/quantile_regressor.py b/bayesflow/networks/regressors/quantile_regressor.py
new file mode 100644
index 000000000..0f528d05f
--- /dev/null
+++ b/bayesflow/networks/regressors/quantile_regressor.py
@@ -0,0 +1,72 @@
+from collections.abc import Sequence
+
+import keras
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import find_network, keras_kwargs
+
+from ..point_inference_network import PointInferenceNetwork
+
+
+@serializable(package="networks.regressors")
+class QuantileRegressor(PointInferenceNetwork):
+    def __init__(
+        self,
+        subnet: str | type = "mlp",
+        quantile_levels: Sequence[float] = None,
+        **kwargs,
+    ):
+        super().__init__(**keras_kwargs(kwargs))
+
+        if quantile_levels is not None:
+            self.quantile_levels = quantile_levels
+        else:
+            self.quantile_levels = [0.1, 0.9]
+        self.num_quantiles = len(self.quantile_levels)  # should we have this shorthand?
+        self.quantile_levels = keras.ops.convert_to_tensor(self.quantile_levels)
+        # TODO: should we initialize self.num_variables here already? The actual value is assigned in build()
+
+        self.body = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+        self.head = keras.layers.Dense(
+            units=None, bias_initializer="zeros", kernel_initializer="zeros"
+        )  # units is set in build(); TODO: why initialize at zero? (taken from consistency_model.py)
+
+    # noinspection PyMethodOverriding
+    def build(
+        self, xz_shape, conditions_shape=None
+    ):  # TODO: it seems conditions_shape should definitely be supplied; change it to a positional argument?
+        super().build(xz_shape)
+
+        self.num_variables = xz_shape[-1]
+        input_shape = conditions_shape
+        self.body.build(input_shape=input_shape)
+
+        input_shape = self.body.compute_output_shape(input_shape)
+        self.head.units = self.num_quantiles * self.num_variables
+        self.head.build(input_shape=input_shape)
+
+    def _forward(
+        self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        head_input = self.body(conditions)
+        pred_quantiles = self.head(head_input)  # (batch_size, num_quantiles * num_variables)
+        pred_quantiles = keras.ops.reshape(pred_quantiles, (-1, self.num_quantiles, self.num_variables))
+        # (batch_size, num_quantiles, num_variables)
+
+        return pred_quantiles
+
+    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
+
+        true_value = x
+        # TODO: kept as before, but why do we not pass training=(stage == "training") to self.call()?
+        pred_quantiles = self(x, conditions)
+        pointwise_difference = pred_quantiles - true_value[:, None, :]
+
+        # pinball loss: positive differences are weighted by (1 - tau), negative ones by tau
+        loss = pointwise_difference * (
+            keras.ops.cast(pointwise_difference > 0, float) - self.quantile_levels[None, :, None]
+        )
+        loss = keras.ops.mean(loss)
+
+        return base_metrics | {"loss": loss}
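Reviewer note: a standalone sanity check of the loss above (illustrative, not part of the diff). With d = prediction - truth, the term d * (1[d > 0] - tau) is the standard pinball loss, whose expected value is minimized at the tau-quantile of the data; the snippet below verifies this numerically:

import numpy as np

def pinball(pred, y, tau):
    d = pred - y  # same sign convention as pointwise_difference above
    return np.mean(d * ((d > 0).astype(float) - tau))

rng = np.random.default_rng(0)
y = rng.normal(size=10_000)

for tau in (0.1, 0.9):
    candidates = np.linspace(-3.0, 3.0, 601)
    best = candidates[np.argmin([pinball(c, y, tau) for c in candidates])]
    # best should be close to the empirical tau-quantile
    print(tau, best, np.quantile(y, tau))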