
Commit 7c66f97

Merge pull request #184 from stefanradev93/model-comparison
Model Comparison
2 parents: 5c61054 + 151b9d5

76 files changed, +2516 -1194 lines

(Some file headers below are hidden; GitHub hides part of the content by default for large commits.)

bayesflow/__init__.py

Lines changed: 23 additions & 11 deletions
@@ -1,26 +1,38 @@
 from . import (
     approximators,
-    configurators,
+    benchmarks,
+    data_adapters,
     datasets,
     diagnostics,
     distributions,
     networks,
     simulators,
-    benchmarks,
     utils,
 )
 
-from .approximators import Approximator
+from .approximators import ContinuousApproximator
 from .datasets import OfflineDataset, OnlineDataset
 
-import keras
 
-if keras.backend.backend() == "torch":
-    # turn off gradients by default
-    import torch
+def setup():
+    # perform any necessary setup without polluting the namespace
+    import keras
+    import logging
+
+    # set the basic logging level if the user hasn't already
+    logging.basicConfig(level=logging.INFO)
+
+    # use a separate logger for the bayesflow package
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    if keras.backend.backend() == "torch":
+        # turn off gradients by default
+        import torch
+
+        torch.autograd.set_grad_enabled(False)
 
-    torch.autograd.set_grad_enabled(False)
 
-# clean up namespace
-del keras
-del torch
+# call and clean up namespace
+setup()
+del setup
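
For reference, the effect of the new setup() hook on the torch backend: autograd is disabled globally at import time, so user code that needs gradients must re-enable them locally. A minimal sketch of that behavior (the tensors and the training snippet below are illustrative, not part of this commit):

import bayesflow  # with the torch backend, this import disables autograd globally
import torch

x = torch.randn(4, 3, requires_grad=True)
y = (x ** 2).sum()
print(y.requires_grad)  # False: gradient tracking is off after importing bayesflow

# re-enable gradients locally when they are needed, e.g. in custom training code
with torch.enable_grad():
    y = (x ** 2).sum()
    y.backward()
print(x.grad.shape)  # torch.Size([4, 3])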

bayesflow/approximators/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from .approximator import Approximator
+from .continuous_approximator import ContinuousApproximator
+from .model_comparison_approximator import ModelComparisonApproximator
Lines changed: 79 additions & 83 deletions
@@ -1,88 +1,84 @@
 import keras
-from keras.saving import register_keras_serializable
-
-from bayesflow.configurators import Configurator
-
-match keras.backend.backend():
-    case "jax":
-        from .jax_approximator import JAXApproximator as BaseApproximator
-    case "numpy":
-        from .numpy_approximator import NumpyApproximator as BaseApproximator
-    case "tensorflow":
-        from .tensorflow_approximator import TensorFlowApproximator as BaseApproximator
-    case "torch":
-        from .torch_approximator import TorchApproximator as BaseApproximator
-    case other:
-        raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.")
-
-
-@register_keras_serializable(package="bayesflow.amortizers")
-class Approximator(BaseApproximator):
-    def __init__(self, **kwargs):
-        """The main workhorse for learning amortized neural approximators for distributions arising
-        in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
-        likelihoods).
-
-        The complete semantics of this class allow for flexible estimation of the following distribution:
-
-        Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions),
-
-        # TODO - math notation
-
-        where all quantities to the right of the "given" symbol | are optional and H refers to the optional
-        summary /embedding network used to compress high-dimensional data into lower-dimensional summary
-        vectors. Some examples are provided below.
-
-        Parameters
-        ----------
-        inference_variables: list[str]
-            A list of variable names indicating the quantities to be inferred / learned by the approximator,
-            e.g., model parameters when approximating the Bayesian posterior or observables when approximating
-            a likelihood density.
-        inference_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables directly, that is, without passing through the summary network.
-        summary_variables: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables after passing through the summary network (i.e., undergoing a
-            learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g.,
-            sets or time-series) in posterior inference will typically qualify as summary variables. In addition,
-            these quantities may involve learnable distributions on their own.
-        summary_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            optional summary network, e.g., when the summary network accepts further conditions that do not
-            conform to the semantics of summary variable (i.e., need not be embedded or their distribution
-            needs not be learned).
-
-        # TODO add citations
-
-        Examples
-        -------
-        # TODO
-        """
-        if "configurator" not in kwargs:
-            # try to set up a default configurator
-            if "inference_variables" not in kwargs:
-                raise ValueError("You must specify either a configurator or arguments for the default configurator.")
-
-            inference_variables = kwargs.pop("inference_variables")
-            inference_conditions = kwargs.pop("inference_conditions", None)
-            summary_variables = kwargs.pop("summary_variables", None)
-
-            kwargs["configurator"] = Configurator(
-                inference_variables,
-                inference_conditions,
-                summary_variables,
-            )
+import multiprocessing as mp
+
+from bayesflow.data_adapters import DataAdapter
+from bayesflow.datasets import OnlineDataset
+from bayesflow.simulators import Simulator
+from bayesflow.utils import find_batch_size, filter_kwargs, logging
+
+from .backend_approximators import BackendApproximator
+
+
+class Approximator(BackendApproximator):
+    def build(self, data_shapes: any) -> None:
+        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
+        self.build_from_data(mock_data)
+
+    @classmethod
+    def build_data_adapter(cls, **kwargs) -> DataAdapter:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def build_from_data(self, data: dict[str, any]) -> None:
+        self.compute_metrics(**data, stage="training")
+        self.built = True
+
+    @classmethod
+    def build_dataset(
+        cls,
+        *,
+        batch_size: int = "auto",
+        num_batches: int,
+        data_adapter: DataAdapter = "auto",
+        memory_budget: str | int = "auto",
+        simulator: Simulator,
+        workers: int = "auto",
+        use_multiprocessing: bool = False,
+        max_queue_size: int = 32,
+        **kwargs,
+    ) -> OnlineDataset:
+        if batch_size == "auto":
+            batch_size = find_batch_size(memory_budget=memory_budget, sample=simulator.sample((1,)))
+            logging.info(f"Using a batch size of {batch_size}.")
+
+        if data_adapter == "auto":
+            data_adapter = cls.build_data_adapter(**filter_kwargs(kwargs, cls.build_data_adapter))
+
+        if workers == "auto":
+            workers = mp.cpu_count()
+            logging.info(f"Using {workers} data loading workers.")
+
+        workers = workers or 1
+
+        return OnlineDataset(
+            simulator=simulator,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            data_adapter=data_adapter,
+            workers=workers,
+            use_multiprocessing=use_multiprocessing,
+            max_queue_size=max_queue_size,
+        )
+
+    def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
+        if dataset is None:
+            if simulator is None:
+                raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")
+
+            logging.info(f"Building dataset from simulator instance of {simulator.__class__.__name__}.")
+            dataset = self.build_dataset(simulator=simulator, **filter_kwargs(kwargs, self.build_dataset))
         else:
-            # the user passed a configurator, so we should not configure a default one
-            # check if the user also passed args for the default configurator
-            keys = ["inference_variables", "inference_conditions", "summary_variables"]
-            if any(key in kwargs for key in keys):
+            if simulator is not None:
                 raise ValueError(
-                    "Received an ambiguous set of arguments: You are passing a configurator explicitly, "
-                    "but also providing arguments for the default configurator."
+                    "Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
                 )
 
-        kwargs.setdefault("summary_network", None)
-        super().__init__(**kwargs)
+        logging.info(f"Fitting on dataset instance of {dataset.__class__.__name__}.")
+
+        if not self.built:
+            logging.info("Building on a test batch.")
+            mock_data = dataset[0]
+            mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
+            self.build_from_data(mock_data)
+
+        return super().fit(dataset=dataset, **kwargs)
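
The refactored Approximator replaces the old configurator-based constructor with a dataset/simulator-driven fit workflow: fit() either consumes a ready keras PyDataset or builds an OnlineDataset from a Simulator via build_dataset(). A rough usage sketch under assumed names (the `approximator` and `simulator` instances, the optimizer, and the chosen keyword values are illustrative, not taken from this diff):

# hypothetical usage; keyword names follow the build_dataset()/fit() signatures above
approximator.compile(optimizer="adam")

# option 1: let fit() build an OnlineDataset from the simulator
history = approximator.fit(
    simulator=simulator,
    num_batches=200,   # forwarded to build_dataset() via filter_kwargs
    epochs=10,         # forwarded to keras.Model.fit() by the backend approximator
)

# option 2: build the dataset explicitly and pass it in
dataset = approximator.build_dataset(simulator=simulator, num_batches=200)
history = approximator.fit(dataset=dataset, epochs=10)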
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .backend_approximator import BackendApproximator
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import keras
+
+from bayesflow.utils import filter_kwargs
+
+
+match keras.backend.backend():
+    case "jax":
+        from .jax_approximator import JAXApproximator as BaseBackendApproximator
+    case "numpy":
+        from .numpy_approximator import NumpyApproximator as BaseBackendApproximator
+    case "tensorflow":
+        from .tensorflow_approximator import TensorFlowApproximator as BaseBackendApproximator
+    case "torch":
+        from .torch_approximator import TorchApproximator as BaseBackendApproximator
+    case other:
+        raise ValueError(f"Backend '{other}' is not supported.")
+
+
+class BackendApproximator(BaseBackendApproximator):
+    # noinspection PyMethodOverriding
+    def fit(self, *, dataset: keras.utils.PyDataset, **kwargs):
+        return super().fit(x=dataset, y=None, **filter_kwargs(kwargs, super().fit))
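
Both the backend approximators and Approximator.build_dataset() rely on filter_kwargs from bayesflow.utils to forward only the keyword arguments a callee actually accepts. Its real implementation is not part of this diff; a rough, illustrative equivalent built on inspect.signature might look like this:

import inspect

def filter_kwargs_sketch(kwargs: dict, fn) -> dict:
    """Illustrative stand-in for bayesflow.utils.filter_kwargs (assumed behavior)."""
    parameters = inspect.signature(fn).parameters
    # if the callable accepts **kwargs, everything can be forwarded unchanged
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        return dict(kwargs)
    # otherwise keep only the entries the callable can accept by name
    return {name: value for name, value in kwargs.items() if name in parameters}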

bayesflow/approximators/jax_approximator.py renamed to bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 30 additions & 25 deletions
@@ -1,25 +1,23 @@
 import jax
 import keras
 
-from .base_approximator import BaseApproximator
-from ..types import Tensor
+from bayesflow.utils import filter_kwargs
 
 
-class JAXApproximator(BaseApproximator):
-    def train_step(self, *args, **kwargs):
-        return self.stateless_train_step(*args, **kwargs)
-
-    def test_step(self, *args, **kwargs):
-        return self.stateless_test_step(*args, **kwargs)
+class JAXApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
+        # implemented by each respective architecture
+        raise NotImplementedError
 
     def stateless_compute_metrics(
         self,
         trainable_variables: any,
         non_trainable_variables: any,
         metrics_variables: any,
-        data: dict[str, Tensor],
+        data: dict[str, any],
         stage: str = "training",
-    ) -> (Tensor, tuple):
+    ) -> (jax.Array, tuple):
         """
         Things we do for jax:
         1. Accept trainable variables as the first argument
@@ -40,21 +38,35 @@ def stateless_compute_metrics(
 
         # perform a stateless call to compute_metrics
         with keras.StatelessScope(state_mapping) as scope:
-            metrics = self.compute_metrics(data, stage)
+            kwargs = filter_kwargs(data | {"stage": stage}, self.compute_metrics)
+            metrics = self.compute_metrics(**kwargs)
 
         # update variables
         non_trainable_variables = [scope.get_current_value(v) for v in self.non_trainable_variables]
         metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
 
         return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)
 
-    def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
+    def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        trainable_variables, non_trainable_variables, metrics_variables = state
+
+        loss, aux = self.stateless_compute_metrics(
+            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
+        )
+        metrics, non_trainable_variables, metrics_variables = aux
+
+        metrics_variables = self._update_loss(loss, metrics_variables)
+
+        state = trainable_variables, non_trainable_variables, metrics_variables
+        return metrics, state
+
+    def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
         trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state
 
         grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
 
         (loss, aux), grads = grad_fn(
-            trainable_variables, non_trainable_variables, metrics_variables, data, stage="training"
+            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="training"
         )
         metrics, non_trainable_variables, metrics_variables = aux
 
@@ -67,20 +79,13 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
         state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
         return metrics, state
 
-    def stateless_test_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
-        trainable_variables, non_trainable_variables, metrics_variables = state
-
-        loss, aux = self.stateless_compute_metrics(
-            trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation"
-        )
-        metrics, non_trainable_variables, metrics_variables = aux
-
-        metrics_variables = self._update_loss(loss, metrics_variables)
+    def test_step(self, *args, **kwargs):
+        return self.stateless_test_step(*args, **kwargs)
 
-        state = trainable_variables, non_trainable_variables, metrics_variables
-        return metrics, state
+    def train_step(self, *args, **kwargs):
+        return self.stateless_train_step(*args, **kwargs)
 
-    def _update_loss(self, loss, metrics_variables):
+    def _update_loss(self, loss: jax.Array, metrics_variables: any) -> any:
         # update the loss progress bar, and possibly metrics variables along with it
         state_mapping = list(zip(self.metrics_variables, metrics_variables))
         with keras.StatelessScope(state_mapping) as scope:
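
The stateless train step above hinges on jax.value_and_grad(..., has_aux=True): the wrapped function returns (loss, aux), JAX differentiates the loss with respect to the first positional argument only, and the auxiliary payload is passed through untouched. A self-contained sketch of that pattern in plain JAX (independent of BayesFlow; all names below are illustrative):

import jax
import jax.numpy as jnp

def compute_metrics(params, data):
    # must return (loss, aux) to be compatible with has_aux=True
    predictions = data["x"] @ params["w"]
    loss = jnp.mean((predictions - data["y"]) ** 2)
    return loss, {"loss": loss}

grad_fn = jax.value_and_grad(compute_metrics, has_aux=True)

params = {"w": jnp.zeros((3,))}
data = {"x": jnp.ones((8, 3)), "y": jnp.ones((8,))}

# gradients are taken w.r.t. `params` (the first argument); `data` is treated as constant
(loss, metrics), grads = grad_fn(params, data)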
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import keras
+import numpy as np
+
+from bayesflow.utils import filter_kwargs
+
+
+class NumpyApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, np.ndarray]:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def test_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
+        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
+        return self.compute_metrics(**kwargs)
+
+    def train_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
+        raise NotImplementedError("Numpy backend does not support training.")
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+import keras
+import tensorflow as tf
+
+from bayesflow.utils import filter_kwargs
+
+
+class TensorFlowApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
+        return self.compute_metrics(**kwargs)
+
+    def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        with tf.GradientTape() as tape:
+            kwargs = filter_kwargs(data | {"stage": "training"}, self.compute_metrics)
+            metrics = self.compute_metrics(**kwargs)
+
+        loss = metrics["loss"]
+
+        grads = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
+
+        self._loss_tracker.update_state(loss)
+
+        return metrics
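
Across all backends, concrete approximators implement compute_metrics(...) and must return a dictionary containing at least a "loss" entry, which the backend train/test steps extract (and differentiate where applicable). A hedged sketch of a minimal subclass honoring that contract (the class, layer, and loss below are illustrative, not part of this commit):

import keras

class ToyApproximator(keras.Model):
    """Illustrative only: demonstrates the compute_metrics contract used by the backend steps."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)

    def compute_metrics(self, x, y, stage: str = "training") -> dict:
        predictions = self.dense(x)
        loss = keras.ops.mean((predictions - y) ** 2)
        # the backend train/test steps read metrics["loss"]
        return {"loss": loss}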
