-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Draft implementation of quantile estimation
- Loading branch information
Showing
7 changed files
with
294 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .approximator import Approximator | ||
from .continuous_approximator import ContinuousApproximator | ||
from .continuous_point_approximator import ContinuousPointApproximator | ||
from .model_comparison_approximator import ModelComparisonApproximator |
170 changes: 170 additions & 0 deletions
170
bayesflow/approximators/continuous_point_approximator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
from collections.abc import Sequence | ||
|
||
import keras | ||
import numpy as np | ||
from keras.saving import ( | ||
deserialize_keras_object as deserialize, | ||
register_keras_serializable as serializable, | ||
serialize_keras_object as serialize, | ||
) | ||
|
||
from bayesflow.adapters import Adapter | ||
from bayesflow.networks import PointInferenceNetwork, SummaryNetwork | ||
from bayesflow.types import Tensor | ||
from bayesflow.utils import logging, split_arrays | ||
from .approximator import Approximator | ||
|
||
|
||
@serializable(package="bayesflow.approximators") | ||
class ContinuousPointApproximator(Approximator): | ||
""" | ||
Defines a workflow for performing fast posterior or likelihood inference. | ||
The distribution is approximated by a point with an feed-forward network and an optional summary network. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
adapter: Adapter, | ||
inference_network: PointInferenceNetwork, | ||
summary_network: SummaryNetwork = None, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.adapter = adapter | ||
self.inference_network = inference_network | ||
self.summary_network = summary_network | ||
|
||
@classmethod | ||
def build_adapter( | ||
cls, | ||
inference_variables: Sequence[str], | ||
inference_conditions: Sequence[str] = None, | ||
summary_variables: Sequence[str] = None, | ||
) -> Adapter: | ||
adapter = Adapter.create_default(inference_variables) | ||
|
||
if inference_conditions is not None: | ||
adapter = adapter.concatenate(inference_conditions, into="inference_conditions") | ||
|
||
if summary_variables is not None: | ||
adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables") | ||
|
||
adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize() | ||
|
||
return adapter | ||
|
||
def compile( | ||
self, | ||
*args, | ||
inference_metrics: Sequence[keras.Metric] = None, | ||
summary_metrics: Sequence[keras.Metric] = None, | ||
**kwargs, | ||
): | ||
if inference_metrics: | ||
self.inference_network._metrics = inference_metrics | ||
|
||
if summary_metrics: | ||
if self.summary_network is None: | ||
logging.warning("Ignoring summary metrics because there is no summary network.") | ||
else: | ||
self.summary_network._metrics = summary_metrics | ||
|
||
return super().compile(*args, **kwargs) | ||
|
||
def compute_metrics( | ||
self, | ||
inference_variables: Tensor, | ||
inference_conditions: Tensor = None, | ||
summary_variables: Tensor = None, | ||
stage: str = "training", | ||
) -> dict[str, Tensor]: | ||
if self.summary_network is None: | ||
if summary_variables is not None: | ||
raise ValueError("Cannot compute summary metrics without a summary network.") | ||
|
||
summary_metrics = {} | ||
else: | ||
if summary_variables is None: | ||
raise ValueError("Summary variables are required when a summary network is present.") | ||
|
||
summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage) | ||
summary_outputs = summary_metrics.pop("outputs") | ||
|
||
# append summary outputs to inference conditions | ||
if inference_conditions is None: | ||
inference_conditions = summary_outputs | ||
else: | ||
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1) | ||
|
||
inference_metrics = self.inference_network.compute_metrics( | ||
inference_variables, conditions=inference_conditions, stage=stage | ||
) | ||
|
||
loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(())) | ||
|
||
inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()} | ||
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} | ||
|
||
metrics = {"loss": loss} | inference_metrics | summary_metrics | ||
|
||
return metrics | ||
|
||
def fit(self, *args, **kwargs): | ||
return super().fit(*args, **kwargs, adapter=self.adapter) | ||
|
||
@classmethod | ||
def from_config(cls, config, custom_objects=None): | ||
config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects) | ||
config["inference_network"] = deserialize(config["inference_network"], custom_objects=custom_objects) | ||
config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects) | ||
|
||
return super().from_config(config, custom_objects=custom_objects) | ||
|
||
def get_config(self): | ||
base_config = super().get_config() | ||
config = { | ||
"adapter": serialize(self.adapter), | ||
"inference_network": serialize(self.inference_network), | ||
"summary_network": serialize(self.summary_network), | ||
} | ||
|
||
return base_config | config | ||
|
||
def estimate( | ||
self, | ||
*, | ||
conditions: dict[str, np.ndarray], | ||
split: bool = False, | ||
**kwargs, | ||
) -> dict[str, np.ndarray]: | ||
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) | ||
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions) | ||
conditions = {"inference_variables": self._estimate(**conditions)} | ||
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions) | ||
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) | ||
|
||
if split: | ||
conditions = split_arrays(conditions, axis=-1) | ||
return conditions | ||
|
||
def _estimate( | ||
self, | ||
inference_conditions: Tensor = None, | ||
summary_variables: Tensor = None, | ||
) -> Tensor: | ||
if self.summary_network is None: | ||
if summary_variables is not None: | ||
raise ValueError("Cannot use summary variables without a summary network.") | ||
else: | ||
if summary_variables is None: | ||
raise ValueError("Summary variables are required when a summary network is present.") | ||
|
||
summary_outputs = self.summary_network(summary_variables) | ||
|
||
if inference_conditions is None: | ||
inference_conditions = summary_outputs | ||
else: | ||
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1) | ||
|
||
return self.inference_network.estimate(conditions=inference_conditions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import keras | ||
|
||
from bayesflow.types import Shape, Tensor | ||
|
||
|
||
class PointInferenceNetwork(keras.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: | ||
pass | ||
|
||
def call( | ||
self, | ||
xz: Tensor, | ||
conditions: Tensor = None, | ||
training: bool = False, | ||
**kwargs, | ||
) -> Tensor | tuple[Tensor, Tensor]: | ||
return self._forward(xz, conditions=conditions, training=training, **kwargs) | ||
|
||
def _forward( | ||
self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs | ||
) -> Tensor | tuple[Tensor, Tensor]: | ||
raise NotImplementedError | ||
|
||
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: | ||
if not self.built: | ||
xz_shape = keras.ops.shape(x) | ||
conditions_shape = None if conditions is None else keras.ops.shape(conditions) | ||
self.build(xz_shape, conditions_shape=conditions_shape) | ||
|
||
metrics = {} | ||
|
||
if stage != "training" and any(self.metrics): | ||
# compute sample-based metrics | ||
# samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions) | ||
# | ||
# for metric in self.metrics: | ||
# metrics[metric.name] = metric(samples, x) | ||
pass | ||
# TODO: instead compute estimate based metrics | ||
|
||
return metrics | ||
|
||
def estimate(self, conditions: Tensor = None) -> Tensor: | ||
return self._forward(None, conditions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .quantile_regressor import QuantileRegressor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from collections.abc import Sequence | ||
|
||
import keras | ||
from keras.saving import register_keras_serializable as serializable | ||
|
||
from bayesflow.types import Tensor | ||
from bayesflow.utils import find_network, keras_kwargs | ||
|
||
from ..point_inference_network import PointInferenceNetwork | ||
|
||
|
||
@serializable(package="networks.regressors") | ||
class QuantileRegressor(PointInferenceNetwork): | ||
def __init__( | ||
self, | ||
subnet: str | type = "mlp", | ||
quantile_levels: Sequence[float] = None, | ||
**kwargs, | ||
): | ||
super().__init__(**keras_kwargs(kwargs)) | ||
|
||
if quantile_levels is not None: | ||
self.quantile_levels = quantile_levels | ||
else: | ||
self.quantile_levels = [0.1, 0.9] | ||
self.quantile_levels = keras.ops.convert_to_tensor(self.quantile_levels) | ||
self.num_quantiles = len(self.quantile_levels) # should we have this shorthand? | ||
# TODO: should we initialize self.num_variables here already? The actual value is assined in build() | ||
|
||
self.body = find_network(subnet, **kwargs.get("subnet_kwargs", {})) | ||
self.head = keras.layers.Dense( | ||
units=None, bias_initializer="zeros", kernel_initializer="zeros" | ||
) # TODO: why initialize at zero (taken from consistency_model.py) | ||
|
||
# noinspection PyMethodOverriding | ||
def build( | ||
self, xz_shape, conditions_shape=None | ||
): # TODO: seems like conditions_shape should definetely be supplied, change to positional argument? | ||
super().build(xz_shape) | ||
|
||
self.num_variables = xz_shape[-1] | ||
input_shape = conditions_shape | ||
self.body.build(input_shape=input_shape) | ||
|
||
input_shape = self.body.compute_output_shape(input_shape) | ||
self.head.units = self.num_quantiles * self.num_variables | ||
self.head.build(input_shape=input_shape) | ||
|
||
def _forward( | ||
self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs | ||
) -> Tensor | tuple[Tensor, Tensor]: | ||
head_input = self.body(conditions) | ||
pred_quantiles = self.head(head_input) # (batch_shape, num_quantiles * num_variables) | ||
pred_quantiles = keras.ops.reshape(pred_quantiles, (-1, self.num_quantiles, self.num_variables)) | ||
# (batch_shape, num_quantiles, num_variables) | ||
|
||
return pred_quantiles | ||
|
||
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: | ||
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) | ||
|
||
true_value = x | ||
# TODO: keeping like it used to be, but why is do we not set training=(stage=="training") in self.call() | ||
pred_quantiles = self(x, conditions) | ||
pointwise_differance = pred_quantiles - true_value[:, None, :] | ||
|
||
loss = pointwise_differance * ( | ||
keras.ops.cast(pointwise_differance > 0, float) - self.quantile_levels[None, :, None] | ||
) | ||
loss = keras.ops.mean(loss) | ||
|
||
return base_metrics | {"loss": loss} |