Model Comparison #184

Merged · 75 commits · Aug 14, 2024
Commits (75)
ffffdb2
current wip state
LarsKue Jul 17, 2024
0503954
configurator -> collate_fn
LarsKue Jul 17, 2024
2593f91
reorder dict_utils
LarsKue Jul 17, 2024
4054f2c
post-discussion state
LarsKue Jul 18, 2024
0c32c71
move into approximators
LarsKue Jul 19, 2024
e072dc1
redesign configurators
LarsKue Jul 19, 2024
af90266
improve setup and logging
LarsKue Jul 19, 2024
f8dfd4b
improve simulators
LarsKue Jul 19, 2024
bc80df5
add configurators to datasets
LarsKue Jul 19, 2024
c4cf8d2
slight adjustment to flow matching
LarsKue Jul 19, 2024
9816a2f
preliminary adjustments to tests (WIP)
LarsKue Jul 19, 2024
6c7e8bf
WIP state example
LarsKue Jul 19, 2024
9f19c1a
fix jax logging
LarsKue Jul 22, 2024
4589cfe
switch to using dicts as compute_metrics data input
LarsKue Jul 22, 2024
6ebe8ea
use bayesflow logger instead of root logger
LarsKue Jul 23, 2024
417f48d
unify backend approximators for shared behavior
LarsKue Jul 23, 2024
9adf83d
better type hints
LarsKue Jul 23, 2024
f490ff5
modularize example
LarsKue Jul 23, 2024
0665b97
fix unbuilt inference network compute_metrics
LarsKue Jul 23, 2024
43701e9
fix continuous approximator for summary_network=None
LarsKue Jul 23, 2024
0cc96d0
introduce memory budget for automatic batch sizing
LarsKue Jul 25, 2024
a944141
revert to using an id list instead of set because some tensor types a…
LarsKue Jul 25, 2024
745c5d4
fix docs
LarsKue Jul 25, 2024
e82043e
ask for forgiveness instead of permission
LarsKue Jul 25, 2024
0b7fc67
reduce functionality of BackendApproximator
LarsKue Jul 25, 2024
5cc1f2c
modularize automatic dataset building
LarsKue Jul 25, 2024
79fa952
improve definition of Tensor type
LarsKue Jul 26, 2024
9099413
make bound available internally as BackendTensor
LarsKue Jul 26, 2024
232ac6c
adjust dispatch to be based on BackendTensor
LarsKue Jul 26, 2024
3f110ec
update size_of for any nested structure of tensors
LarsKue Jul 26, 2024
1693243
move finding batch size and memory budget into utils
LarsKue Jul 26, 2024
d9109b1
improve error message for incorrect memory budget format
LarsKue Jul 29, 2024
9810587
improve error catching in Simulator.rejection_sample
LarsKue Jul 29, 2024
7a785d8
state dump
LarsKue Jul 30, 2024
167bfa4
state dump
LarsKue Jul 30, 2024
a976a25
Merge remote-tracking branch 'origin/model-comparison' into model-com…
LarsKue Jul 30, 2024
5afe352
concatenate_dicts -> tree_concatenate, stack_dicts -> tree_stack
LarsKue Jul 30, 2024
fca72a6
configurators -> data_adapters
LarsKue Jul 30, 2024
acf0327
Simulator now only generates numpy arrays by design
LarsKue Jul 30, 2024
66c1897
state dump
LarsKue Aug 7, 2024
2982d95
continue porting simulators to pure numpy
LarsKue Aug 7, 2024
d3e21b1
allow arbitrary keyword-arguments for approximator's compute_metrics
LarsKue Aug 7, 2024
bd0bdf8
mark flaky tests as such
LarsKue Aug 9, 2024
e5e43b8
reintroduce metrics to approximators
LarsKue Aug 9, 2024
f788ff5
fix data adapter serialization
LarsKue Aug 9, 2024
3ad4872
fix sample-based metrics in inference network
LarsKue Aug 9, 2024
1e6e5d7
modularize sample method of lambda simulators
LarsKue Aug 9, 2024
ef28225
reintroduce improved batched_call
LarsKue Aug 9, 2024
b29cae7
final fixes
LarsKue Aug 9, 2024
c5cdbd9
remove old example
LarsKue Aug 9, 2024
9796d01
Merge branch 'refs/heads/streamlined-backend' into model-comparison
LarsKue Aug 9, 2024
5582df8
add numpy version of optimal transport
LarsKue Aug 12, 2024
fda7768
rename batches_per_epoch -> num_batches for consistency
LarsKue Aug 13, 2024
5a0afe4
add numpy optimal transport
LarsKue Aug 13, 2024
362ce5b
fix size_of util for numpy arrays with non-numpy backend
LarsKue Aug 13, 2024
0380e17
add (maybe temporary) split_tensors utility for plotting
LarsKue Aug 13, 2024
30af1ee
make keys optional for configuration / deconfiguration in composite d…
LarsKue Aug 13, 2024
4649daa
rename batches_per_epoch -> num_batches, matching fda7768817ec28ebe9f…
LarsKue Aug 13, 2024
3553589
fix data adapter passing to super fit
LarsKue Aug 13, 2024
40d628a
reimplement sample and log prob
LarsKue Aug 13, 2024
b09d09f
update example (not working yet)
LarsKue Aug 13, 2024
e6bdf83
fix missing rename for batches_per_epoch arg
LarsKue Aug 14, 2024
c3658a0
allow dropout=None in MLP
LarsKue Aug 14, 2024
4aebef6
fix configure order for concatenate keys data adapter
LarsKue Aug 14, 2024
013e7ae
fix flow matching data adapter
LarsKue Aug 14, 2024
bb657b7
improve building of approximators
LarsKue Aug 14, 2024
61447a1
ensure numpy v1
LarsKue Aug 14, 2024
329a830
make optimal transport optional
LarsKue Aug 14, 2024
f672dbf
improve logging message
LarsKue Aug 14, 2024
33c69d3
improve versioning
LarsKue Aug 14, 2024
d69c0ab
update two moons example
LarsKue Aug 14, 2024
0ae5d72
fix tests
LarsKue Aug 14, 2024
4660267
final fixes
LarsKue Aug 14, 2024
90b8757
Merge branch 'streamlined-backend' into model-comparison
LarsKue Aug 14, 2024
151b9d5
fix lint
LarsKue Aug 14, 2024
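
Commit 0cc96d0 above introduces automatic batch sizing from a memory budget; the resulting utility appears below in the diff as find_batch_size(memory_budget=..., sample=...). Its implementation is not part of this excerpt, so the following is only a rough sketch of the assumed idea: parse the budget into bytes, measure one simulated sample, and divide. The budget format and helper name parse_budget are hypothetical.

import numpy as np


def parse_budget(budget: str | int) -> int:
    # hypothetical budget format, e.g. "8GiB", or a raw byte count
    if isinstance(budget, int):
        return budget
    units = {"KiB": 2**10, "MiB": 2**20, "GiB": 2**30}
    for unit, factor in units.items():
        if budget.endswith(unit):
            return int(float(budget[: -len(unit)]) * factor)
    raise ValueError(f"Cannot parse memory budget: {budget!r}")


def find_batch_size(memory_budget: str | int, sample: dict[str, np.ndarray]) -> int:
    # bytes of a single simulated sample, summed over the dict of arrays
    sample_bytes = sum(v.nbytes for v in sample.values())
    return max(1, parse_budget(memory_budget) // sample_bytes)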
bayesflow/__init__.py (34 changes: 23 additions & 11 deletions)
@@ -1,26 +1,38 @@
 from . import (
     approximators,
-    configurators,
+    benchmarks,
+    data_adapters,
     datasets,
     diagnostics,
     distributions,
     networks,
     simulators,
-    benchmarks,
     utils,
 )
 
-from .approximators import Approximator
+from .approximators import ContinuousApproximator
 from .datasets import OfflineDataset, OnlineDataset
 
-import keras
 
-if keras.backend.backend() == "torch":
-    # turn off gradients by default
-    import torch
+def setup():
+    # perform any necessary setup without polluting the namespace
+    import keras
+    import logging
+
+    # set the basic logging level if the user hasn't already
+    logging.basicConfig(level=logging.INFO)
+
+    # use a separate logger for the bayesflow package
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    if keras.backend.backend() == "torch":
+        # turn off gradients by default
+        import torch
+
+        torch.autograd.set_grad_enabled(False)
 
-    torch.autograd.set_grad_enabled(False)
 
-# clean up namespace
-del keras
-del torch
+# call and clean up namespace
+setup()
+del setup
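
The net effect of the new setup() is that importing bayesflow configures a package logger and, on the torch backend, turns off autograd globally. A minimal sketch of what that means for user code (assuming the torch backend; the assertion and the re-enabling pattern are illustrative, not part of this diff):

import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras or bayesflow

import torch
import bayesflow  # runs setup() on import

# inference-only code no longer records gradients by default
assert not torch.is_grad_enabled()

# code that needs gradients must opt back in explicitly
with torch.enable_grad():
    x = torch.ones(3, requires_grad=True)
    (x * 2).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])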
bayesflow/approximators/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -1 +1,3 @@
 from .approximator import Approximator
+from .continuous_approximator import ContinuousApproximator
+from .model_comparison_approximator import ModelComparisonApproximator
bayesflow/approximators/approximator.py (162 changes: 79 additions & 83 deletions)
@@ -1,88 +1,84 @@
 import keras
-from keras.saving import register_keras_serializable
-
-from bayesflow.configurators import Configurator
-
-match keras.backend.backend():
-    case "jax":
-        from .jax_approximator import JAXApproximator as BaseApproximator
-    case "numpy":
-        from .numpy_approximator import NumpyApproximator as BaseApproximator
-    case "tensorflow":
-        from .tensorflow_approximator import TensorFlowApproximator as BaseApproximator
-    case "torch":
-        from .torch_approximator import TorchApproximator as BaseApproximator
-    case other:
-        raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.")
-
-
-@register_keras_serializable(package="bayesflow.amortizers")
-class Approximator(BaseApproximator):
-    def __init__(self, **kwargs):
-        """The main workhorse for learning amortized neural approximators for distributions arising
-        in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
-        likelihoods).
-
-        The complete semantics of this class allow for flexible estimation of the following distribution:
-
-        Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions),
-
-        # TODO - math notation
-
-        where all quantities to the right of the "given" symbol | are optional and H refers to the optional
-        summary /embedding network used to compress high-dimensional data into lower-dimensional summary
-        vectors. Some examples are provided below.
-
-        Parameters
-        ----------
-        inference_variables: list[str]
-            A list of variable names indicating the quantities to be inferred / learned by the approximator,
-            e.g., model parameters when approximating the Bayesian posterior or observables when approximating
-            a likelihood density.
-        inference_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables directly, that is, without passing through the summary network.
-        summary_variables: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables after passing through the summary network (i.e., undergoing a
-            learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g.,
-            sets or time-series) in posterior inference will typically qualify as summary variables. In addition,
-            these quantities may involve learnable distributions on their own.
-        summary_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            optional summary network, e.g., when the summary network accepts further conditions that do not
-            conform to the semantics of summary variable (i.e., need not be embedded or their distribution
-            needs not be learned).
-
-        # TODO add citations
-
-        Examples
-        -------
-        # TODO
-        """
-        if "configurator" not in kwargs:
-            # try to set up a default configurator
-            if "inference_variables" not in kwargs:
-                raise ValueError("You must specify either a configurator or arguments for the default configurator.")
-
-            inference_variables = kwargs.pop("inference_variables")
-            inference_conditions = kwargs.pop("inference_conditions", None)
-            summary_variables = kwargs.pop("summary_variables", None)
-
-            kwargs["configurator"] = Configurator(
-                inference_variables,
-                inference_conditions,
-                summary_variables,
-            )
-        else:
-            # the user passed a configurator, so we should not configure a default one
-            # check if the user also passed args for the default configurator
-            keys = ["inference_variables", "inference_conditions", "summary_variables"]
-            if any(key in kwargs for key in keys):
-                raise ValueError(
-                    "Received an ambiguous set of arguments: You are passing a configurator explicitly, "
-                    "but also providing arguments for the default configurator."
-                )
-
-        kwargs.setdefault("summary_network", None)
-        super().__init__(**kwargs)
+import multiprocessing as mp
+
+from bayesflow.data_adapters import DataAdapter
+from bayesflow.datasets import OnlineDataset
+from bayesflow.simulators import Simulator
+from bayesflow.utils import find_batch_size, filter_kwargs, logging
+
+from .backend_approximators import BackendApproximator
+
+
+class Approximator(BackendApproximator):
+    def build(self, data_shapes: any) -> None:
+        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
+        self.build_from_data(mock_data)
+
+    @classmethod
+    def build_data_adapter(cls, **kwargs) -> DataAdapter:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def build_from_data(self, data: dict[str, any]) -> None:
+        self.compute_metrics(**data, stage="training")
+        self.built = True
+
+    @classmethod
+    def build_dataset(
+        cls,
+        *,
+        batch_size: int = "auto",
+        num_batches: int,
+        data_adapter: DataAdapter = "auto",
+        memory_budget: str | int = "auto",
+        simulator: Simulator,
+        workers: int = "auto",
+        use_multiprocessing: bool = False,
+        max_queue_size: int = 32,
+        **kwargs,
+    ) -> OnlineDataset:
+        if batch_size == "auto":
+            batch_size = find_batch_size(memory_budget=memory_budget, sample=simulator.sample((1,)))
+            logging.info(f"Using a batch size of {batch_size}.")
+
+        if data_adapter == "auto":
+            data_adapter = cls.build_data_adapter(**filter_kwargs(kwargs, cls.build_data_adapter))
+
+        if workers == "auto":
+            workers = mp.cpu_count()
+            logging.info(f"Using {workers} data loading workers.")
+
+        workers = workers or 1
+
+        return OnlineDataset(
+            simulator=simulator,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            data_adapter=data_adapter,
+            workers=workers,
+            use_multiprocessing=use_multiprocessing,
+            max_queue_size=max_queue_size,
+        )
+
+    def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
+        if dataset is None:
+            if simulator is None:
+                raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")
+
+            logging.info(f"Building dataset from simulator instance of {simulator.__class__.__name__}.")
+            dataset = self.build_dataset(simulator=simulator, **filter_kwargs(kwargs, self.build_dataset))
+        else:
+            if simulator is not None:
+                raise ValueError(
+                    "Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
+                )
+
+        logging.info(f"Fitting on dataset instance of {dataset.__class__.__name__}.")
+
+        if not self.built:
+            logging.info("Building on a test batch.")
+            mock_data = dataset[0]
+            mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
+            self.build_from_data(mock_data)
+
+        return super().fit(dataset=dataset, **kwargs)
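
For orientation, a hedged usage sketch of the reworked workflow: fit() can now be called with a simulator instead of a dataset, and build_dataset() resolves batch size, data adapter, and worker count from their "auto" defaults. The concrete simulator and network names below are illustrative assumptions, not part of this diff:

import bayesflow as bf

simulator = bf.simulators.TwoMoons()  # assumed name; any Simulator producing numpy dicts works

approximator = bf.ContinuousApproximator(
    inference_network=bf.networks.FlowMatching(),  # illustrative architecture choice
)

# no dataset given: fit() builds an OnlineDataset internally via build_dataset(),
# choosing batch_size from the memory budget and workers from mp.cpu_count()
approximator.fit(
    simulator=simulator,
    num_batches=128,  # routed to build_dataset through filter_kwargs
    epochs=10,
)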
bayesflow/approximators/backend_approximators/__init__.py (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
+from .backend_approximator import BackendApproximator
bayesflow/approximators/backend_approximators/backend_approximator.py (new file, 22 additions)
@@ -0,0 +1,22 @@
+import keras
+
+from bayesflow.utils import filter_kwargs
+
+
+match keras.backend.backend():
+    case "jax":
+        from .jax_approximator import JAXApproximator as BaseBackendApproximator
+    case "numpy":
+        from .numpy_approximator import NumpyApproximator as BaseBackendApproximator
+    case "tensorflow":
+        from .tensorflow_approximator import TensorFlowApproximator as BaseBackendApproximator
+    case "torch":
+        from .torch_approximator import TorchApproximator as BaseBackendApproximator
+    case other:
+        raise ValueError(f"Backend '{other}' is not supported.")
+
+
+class BackendApproximator(BaseBackendApproximator):
+    # noinspection PyMethodOverriding
+    def fit(self, *, dataset: keras.utils.PyDataset, **kwargs):
+        return super().fit(x=dataset, y=None, **filter_kwargs(kwargs, super().fit))
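
filter_kwargs is used throughout this PR to route one mixed kwargs dict to callables that accept only some of its keys. Its implementation is not shown in this diff; a minimal sketch of the assumed behavior, based on how it is called here:

import inspect


def filter_kwargs(kwargs: dict, f) -> dict:
    # keep only the keyword arguments that f actually accepts
    params = inspect.signature(f).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # f takes **kwargs, so nothing needs filtering
    return {k: v for k, v in kwargs.items() if k in params}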
bayesflow/approximators/backend_approximators/jax_approximator.py
@@ -1,25 +1,23 @@
 import jax
 import keras
 
-from .base_approximator import BaseApproximator
-from ..types import Tensor
+from bayesflow.utils import filter_kwargs
 
 
-class JAXApproximator(BaseApproximator):
-    def train_step(self, *args, **kwargs):
-        return self.stateless_train_step(*args, **kwargs)
-
-    def test_step(self, *args, **kwargs):
-        return self.stateless_test_step(*args, **kwargs)
+class JAXApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
+        # implemented by each respective architecture
+        raise NotImplementedError
 
     def stateless_compute_metrics(
         self,
         trainable_variables: any,
         non_trainable_variables: any,
         metrics_variables: any,
-        data: dict[str, Tensor],
+        data: dict[str, any],
         stage: str = "training",
-    ) -> (Tensor, tuple):
+    ) -> (jax.Array, tuple):
         """
         Things we do for jax:
         1. Accept trainable variables as the first argument
@@ -40,21 +38,35 @@
 
         # perform a stateless call to compute_metrics
         with keras.StatelessScope(state_mapping) as scope:
-            metrics = self.compute_metrics(data, stage)
+            kwargs = filter_kwargs(data | {"stage": stage}, self.compute_metrics)
+            metrics = self.compute_metrics(**kwargs)
 
         # update variables
         non_trainable_variables = [scope.get_current_value(v) for v in self.non_trainable_variables]
         metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
 
         return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)
 
-    def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
+    def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        trainable_variables, non_trainable_variables, metrics_variables = state
+
+        loss, aux = self.stateless_compute_metrics(
+            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
+        )
+        metrics, non_trainable_variables, metrics_variables = aux
+
+        metrics_variables = self._update_loss(loss, metrics_variables)
+
+        state = trainable_variables, non_trainable_variables, metrics_variables
+        return metrics, state
+
+    def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
         trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state
 
         grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
 
         (loss, aux), grads = grad_fn(
-            trainable_variables, non_trainable_variables, metrics_variables, data, stage="training"
+            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="training"
         )
         metrics, non_trainable_variables, metrics_variables = aux
 
@@ -67,20 +79,13 @@
         state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
         return metrics, state
 
-    def stateless_test_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
-        trainable_variables, non_trainable_variables, metrics_variables = state
-
-        loss, aux = self.stateless_compute_metrics(
-            trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation"
-        )
-        metrics, non_trainable_variables, metrics_variables = aux
-
-        metrics_variables = self._update_loss(loss, metrics_variables)
+    def test_step(self, *args, **kwargs):
+        return self.stateless_test_step(*args, **kwargs)
 
-        state = trainable_variables, non_trainable_variables, metrics_variables
-        return metrics, state
+    def train_step(self, *args, **kwargs):
+        return self.stateless_train_step(*args, **kwargs)
 
-    def _update_loss(self, loss, metrics_variables):
+    def _update_loss(self, loss: jax.Array, metrics_variables: any) -> any:
         # update the loss progress bar, and possibly metrics variables along with it
         state_mapping = list(zip(self.metrics_variables, metrics_variables))
         with keras.StatelessScope(state_mapping) as scope:
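
The stateless methods above follow the standard JAX recipe: all model state enters as explicit arguments, jax.value_and_grad(..., has_aux=True) threads metrics alongside the loss, and updated state is returned rather than mutated. A self-contained sketch of the same pattern, independent of Keras (all names illustrative):

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # return (loss, aux) so value_and_grad(has_aux=True) can pass metrics through
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    return loss, {"loss": loss}


@jax.jit
def train_step(params, x, y, lr=1e-2):
    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    # plain SGD stands in for the optimizer's stateless update
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, metrics


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
params, metrics = train_step(params, x, y)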
bayesflow/approximators/backend_approximators/numpy_approximator.py (new file, 18 additions)
@@ -0,0 +1,18 @@
+import keras
+import numpy as np
+
+from bayesflow.utils import filter_kwargs
+
+
+class NumpyApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, np.ndarray]:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def test_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
+        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
+        return self.compute_metrics(**kwargs)
+
+    def train_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
+        raise NotImplementedError("Numpy backend does not support training.")
bayesflow/approximators/backend_approximators/tensorflow_approximator.py (new file, 29 additions)
@@ -0,0 +1,29 @@
+import keras
+import tensorflow as tf
+
+from bayesflow.utils import filter_kwargs
+
+
+class TensorFlowApproximator(keras.Model):
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
+        return self.compute_metrics(**kwargs)
+
+    def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        with tf.GradientTape() as tape:
+            kwargs = filter_kwargs(data | {"stage": "training"}, self.compute_metrics)
+            metrics = self.compute_metrics(**kwargs)
+
+        loss = metrics["loss"]
+
+        grads = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
+
+        self._loss_tracker.update_state(loss)
+
+        return metrics
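
Each backend class leaves compute_metrics abstract, and the shared machinery feeds it the data dict's keys as named arguments (via filter_kwargs over data | {"stage": ...}, or **data in build_from_data). A hedged sketch of what a concrete implementation might look like; the variable names x and y are assumptions about the data dict, not part of this diff:

import keras


class MyApproximator(TensorFlowApproximator):  # or the matching class for your backend
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)

    def compute_metrics(self, x, y, stage: str = "training") -> dict:
        # keys of the data dict become keyword arguments through filter_kwargs
        loss = keras.ops.mean((self.dense(x) - y) ** 2)
        return {"loss": loss}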