@@ -1,88 +1,84 @@
 import keras
-from keras.saving import register_keras_serializable
-
-from bayesflow.configurators import Configurator
-
-match keras.backend.backend():
-    case "jax":
-        from .jax_approximator import JAXApproximator as BaseApproximator
-    case "numpy":
-        from .numpy_approximator import NumpyApproximator as BaseApproximator
-    case "tensorflow":
-        from .tensorflow_approximator import TensorFlowApproximator as BaseApproximator
-    case "torch":
-        from .torch_approximator import TorchApproximator as BaseApproximator
-    case other:
-        raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.")
-
-
-@register_keras_serializable(package="bayesflow.amortizers")
-class Approximator(BaseApproximator):
-    def __init__(self, **kwargs):
-        """The main workhorse for learning amortized neural approximators for distributions arising
-        in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
-        likelihoods).
-
-        The complete semantics of this class allow for flexible estimation of the following distribution:
-
-        Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions),
-
-        # TODO - math notation
-
-        where all quantities to the right of the "given" symbol | are optional and H refers to the optional
-        summary /embedding network used to compress high-dimensional data into lower-dimensional summary
-        vectors. Some examples are provided below.
-
-        Parameters
-        ----------
-        inference_variables: list[str]
-            A list of variable names indicating the quantities to be inferred / learned by the approximator,
-            e.g., model parameters when approximating the Bayesian posterior or observables when approximating
-            a likelihood density.
-        inference_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables directly, that is, without passing through the summary network.
-        summary_variables: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            distribution over inference variables after passing through the summary network (i.e., undergoing a
-            learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g.,
-            sets or time-series) in posterior inference will typically qualify as summary variables. In addition,
-            these quantities may involve learnable distributions on their own.
-        summary_conditions: list[str]
-            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
-            optional summary network, e.g., when the summary network accepts further conditions that do not
-            conform to the semantics of summary variable (i.e., need not be embedded or their distribution
-            needs not be learned).
-
-        # TODO add citations
-
-        Examples
-        -------
-        # TODO
-        """
-        if "configurator" not in kwargs:
-            # try to set up a default configurator
-            if "inference_variables" not in kwargs:
-                raise ValueError("You must specify either a configurator or arguments for the default configurator.")
-
-            inference_variables = kwargs.pop("inference_variables")
-            inference_conditions = kwargs.pop("inference_conditions", None)
-            summary_variables = kwargs.pop("summary_variables", None)
-
-            kwargs["configurator"] = Configurator(
-                inference_variables,
-                inference_conditions,
-                summary_variables,
-            )
+import multiprocessing as mp
+
+from bayesflow.data_adapters import DataAdapter
+from bayesflow.datasets import OnlineDataset
+from bayesflow.simulators import Simulator
+from bayesflow.utils import find_batch_size, filter_kwargs, logging
+
+from .backend_approximators import BackendApproximator
+
+
+class Approximator(BackendApproximator):
+    def build(self, data_shapes: any) -> None:
+        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
+        self.build_from_data(mock_data)
+
+    @classmethod
+    def build_data_adapter(cls, **kwargs) -> DataAdapter:
+        # implemented by each respective architecture
+        raise NotImplementedError
+
+    def build_from_data(self, data: dict[str, any]) -> None:
+        self.compute_metrics(**data, stage="training")
+        self.built = True
+
+    @classmethod
+    def build_dataset(
+        cls,
+        *,
+        batch_size: int = "auto",
+        num_batches: int,
+        data_adapter: DataAdapter = "auto",
+        memory_budget: str | int = "auto",
+        simulator: Simulator,
+        workers: int = "auto",
+        use_multiprocessing: bool = False,
+        max_queue_size: int = 32,
+        **kwargs,
+    ) -> OnlineDataset:
+        if batch_size == "auto":
+            batch_size = find_batch_size(memory_budget=memory_budget, sample=simulator.sample((1,)))
+            logging.info(f"Using a batch size of {batch_size}.")
+
+        if data_adapter == "auto":
+            data_adapter = cls.build_data_adapter(**filter_kwargs(kwargs, cls.build_data_adapter))
+
+        if workers == "auto":
+            workers = mp.cpu_count()
+            logging.info(f"Using {workers} data loading workers.")
+
+        workers = workers or 1
+
+        return OnlineDataset(
+            simulator=simulator,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            data_adapter=data_adapter,
+            workers=workers,
+            use_multiprocessing=use_multiprocessing,
+            max_queue_size=max_queue_size,
+        )
+
+    def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
+        if dataset is None:
+            if simulator is None:
+                raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")
+
+            logging.info(f"Building dataset from simulator instance of {simulator.__class__.__name__}.")
+            dataset = self.build_dataset(simulator=simulator, **filter_kwargs(kwargs, self.build_dataset))
         else:
-            # the user passed a configurator, so we should not configure a default one
-            # check if the user also passed args for the default configurator
-            keys = ["inference_variables", "inference_conditions", "summary_variables"]
-            if any(key in kwargs for key in keys):
+            if simulator is not None:
                 raise ValueError(
-                    "Received an ambiguous set of arguments: You are passing a configurator explicitly, "
-                    "but also providing arguments for the default configurator."
+                    "Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
                 )
 
-        kwargs.setdefault("summary_network", None)
-        super().__init__(**kwargs)
+        logging.info(f"Fitting on dataset instance of {dataset.__class__.__name__}.")
+
+        if not self.built:
+            logging.info("Building on a test batch.")
+            mock_data = dataset[0]
+            mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
+            self.build_from_data(mock_data)
+
+        return super().fit(dataset=dataset, **kwargs)
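After this rewrite, build_dataset() resolves each of its "auto" defaults at call time: the batch size via find_batch_size() from a memory budget and one simulated sample, the data adapter via the architecture-specific build_data_adapter(), and the worker count via mp.cpu_count(). A minimal usage sketch, assuming a hypothetical concrete subclass MyApproximator (which would implement the abstract build_data_adapter()) and a Simulator instance named simulator; neither is defined in this diff:

    # MyApproximator and simulator are assumptions; only build_dataset() is from this diff.
    dataset = MyApproximator.build_dataset(
        simulator=simulator,  # required keyword-only argument
        num_batches=512,      # required keyword-only argument
        batch_size=64,        # an explicit size skips the memory-budget heuristic
    )
    # data_adapter and workers were left at "auto", so the adapter comes from
    # MyApproximator.build_data_adapter(...) and the worker count from mp.cpu_count().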
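The reworked fit() then accepts either a prebuilt dataset or a raw simulator, but not both; remaining keyword arguments are forwarded to super().fit(). A sketch of both call patterns, under the same assumptions as above:

    approximator = MyApproximator()  # hypothetical concrete subclass

    # Option 1: pass a simulator; fit() assembles an OnlineDataset internally,
    # routing the matching keyword arguments (here num_batches) to build_dataset().
    approximator.fit(simulator=simulator, num_batches=512)

    # Option 2: pass a prebuilt dataset.
    approximator.fit(dataset=dataset)

    # Passing both arguments raises the "conflicting arguments" ValueError, and
    # passing neither raises "Received no data to fit on". If the approximator is
    # not yet built, fit() first builds it from a single test batch (dataset[0]).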