Commit 00f0a89

address remaining comments

1 parent 676b0cd

File tree

8 files changed: +64, -81 lines

bayesflow/approximators/backend_approximators/backend_approximator.py

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 match keras.backend.backend():
     case "jax":
         from .jax_approximator import JAXApproximator as BaseBackendApproximator
-    case "numpy":
-        from .numpy_approximator import NumpyApproximator as BaseBackendApproximator
     case "tensorflow":
         from .tensorflow_approximator import TensorFlowApproximator as BaseBackendApproximator
     case "torch":

bayesflow/approximators/backend_approximators/numpy_approximator.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

bayesflow/approximators/continuous_approximator.py

Lines changed: 14 additions & 6 deletions
@@ -448,7 +448,7 @@ def sample(
         conditions = self._prepare_data(conditions, **kwargs)

         # Remove any superfluous keys, just retain actual conditions
-        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}

         # Sample and undo optional standardization
         samples = self._sample(num_samples=num_samples, **conditions, **kwargs)
@@ -485,7 +485,7 @@ def _prepare_data(
         ldj_inference = None

         # Standardize conditions
-        for key in ContinuousApproximator.CONDITION_KEYS:
+        for key in self.CONDITION_KEYS:
             if key in self.standardize and key in data:
                 data[key] = self.standardize_layers[key](data[key])

@@ -514,8 +514,12 @@ def _sample(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if (self.summary_network is None) != (summary_variables is None):
-            raise ValueError("Summary variables and summary network must be used together.")
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")

         if self.summary_network is not None:
             summary_outputs = self.summary_network(
@@ -606,8 +610,12 @@ def _log_prob(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if (self.summary_network is None) != (summary_variables is None):
-            raise ValueError("Summary variables and summary network must be used together.")
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")

         if self.summary_network is not None:
             summary_outputs = self.summary_network(
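
The two hunks above split the previous XOR test into explicit branches so each failure mode gets a targeted message. A standalone sketch of the logic, behaviorally equivalent to the diff; the helper name `check_summary_usage` is hypothetical:

    # Standalone reproduction of the validation logic from this commit.
    def check_summary_usage(summary_network, summary_variables):
        if summary_network is None:
            if summary_variables is not None:
                raise ValueError("Cannot use summary variables without a summary network.")
        else:
            if summary_variables is None:
                raise ValueError("Summary variables are required when a summary network is present.")

    check_summary_usage(None, None)       # valid: neither present
    check_summary_usage("net", "vars")    # valid: both present
    # check_summary_usage(None, "vars")   -> ValueError: Cannot use summary variables ...
    # check_summary_usage("net", None)    -> ValueError: Summary variables are required ...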

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 9 additions & 6 deletions
@@ -92,7 +92,7 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:

         # Set up standardization layers if requested
         if self.standardize == "all":
-            self.standardize = [var for var in ModelComparisonApproximator.CONDITION_KEYS if var in data_shapes]
+            self.standardize = [var for var in self.CONDITION_KEYS if var in data_shapes]
         self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}

         # Build all standardization layers
@@ -242,7 +242,7 @@ def compute_metrics(
     def fit(
         self,
         *,
-        adapter: Adapter | str = "auto",
+        adapter: Adapter = "auto",
         dataset: keras.utils.PyDataset = None,
         simulator: ModelComparisonSimulator = None,
         simulators: Sequence[Simulator] = None,
@@ -256,7 +256,7 @@ def fit(

         Parameters
         ----------
-        adapter : Adapter or str, optional
+        adapter : Adapter or 'auto', optional
             The data adapter that will make the simulated / real outputs neural-network friendly.
         dataset : keras.utils.PyDataset, optional
             A dataset containing simulations for training. If provided, `simulator` must be None.
@@ -392,17 +392,20 @@ def predict(
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)

         # Ensure only keys relevant for sampling are present in the conditions dictionary
-        conditions = {k: v for k, v in conditions.items() if k in ModelComparisonApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

         # Optionally standardize conditions
-        for key in ModelComparisonApproximator.CONDITION_KEYS:
+        for key in self.CONDITION_KEYS:
             if key in conditions and key in self.standardize:
                 conditions[key] = self.standardize_layers[key](conditions[key])

         output = self._predict(**conditions, **kwargs)

-        return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output)
+        if probs:
+            output = keras.ops.softmax(output)
+
+        return keras.ops.convert_to_numpy(output)

     def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """

bayesflow/approximators/point_approximator.py

Lines changed: 5 additions & 5 deletions
@@ -58,7 +58,7 @@ def estimate(
         # Adapt, optionally standardize and convert conditions to tensor.
         conditions = self._prepare_data(conditions, **kwargs)
         # Remove any superfluous keys, just retain actual conditions. # TODO: is this necessary?
-        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}

         estimates = self._estimate(**conditions, **kwargs)

@@ -77,9 +77,9 @@ def estimate(
         estimates = split_arrays(estimates, axis=-1)

         # Reorder the nested dictionary so that original variable names are at the top.
-        estimates = PointApproximator._reorder_estimates(estimates)
+        estimates = self._reorder_estimates(estimates)
         # Remove unnecessary nesting.
-        estimates = PointApproximator._squeeze_estimates(estimates)
+        estimates = self._squeeze_estimates(estimates)

         return estimates

@@ -124,7 +124,7 @@ def sample(
         # Adapt, optionally standardize and convert conditions to tensor.
         conditions = self._prepare_data(conditions, **kwargs)
         # Remove any superfluous keys, just retain actual conditions. # TODO: is this necessary?
-        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}

         # Sample and undo optional standardization
         samples = self._sample(num_samples, **conditions, **kwargs)
@@ -183,7 +183,7 @@ def log_prob(
         if log_det_jac is not None:
             log_prob = keras.tree.map_structure(lambda x: x + log_det_jac, log_prob)

-        log_prob = PointApproximator._squeeze_parametric_score_major_dict(log_prob)
+        log_prob = self._squeeze_parametric_score_major_dict(log_prob)

         return log_prob
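
The common thread in this file (and in the `continuous_approximator.py` hunks above) is replacing hard-coded class references with `self`, so that attribute and method overrides in subclasses take effect. A minimal sketch of why this matters; the subclass is hypothetical:

    # Hypothetical subclass. Under the old hard-coded lookup
    # `ContinuousApproximator.CONDITION_KEYS`, this override would be ignored
    # when filtering condition keys; `self.CONDITION_KEYS` resolves it
    # through normal attribute lookup.
    class CustomPointApproximator(PointApproximator):
        CONDITION_KEYS = [*PointApproximator.CONDITION_KEYS, "extra_conditions"]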

bayesflow/datasets/disk_dataset.py

Lines changed: 12 additions & 8 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Mapping, Callable
+from collections.abc import Callable, Mapping, Sequence

 import os
 import pathlib as pl
@@ -36,7 +36,7 @@ def __init__(
         load_fn: Callable = None,
         adapter: Adapter | None,
         stage: str = "training",
-        augmentations: Mapping[str, Callable] | Callable = None,
+        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] = None,
         shuffle: bool = True,
         **kwargs,
     ):
@@ -58,13 +58,14 @@ def __init__(
             Optional adapter to transform the loaded batch.
         stage : str, default="training"
             Current stage (e.g., "training", "validation", etc.) used by the adapter.
-        augmentations : dict of str to Callable or Callable, optional
-            Dictionary of augmentation functions to apply to each corresponding key in the batch
-            or a function to apply to the entire batch (possibly adding new keys).
+        augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
+            A single augmentation function, dictionary of augmentation functions, or sequence of augmentation
+            functions to apply to the batch.

             If you provide a dictionary of functions, each function should accept one element
-            of your output batch and return the corresponding transformed element. Otherwise,
-            your function should accept the entire dictionary output and return a dictionary.
+            of your output batch and return the corresponding transformed element.
+
+            Otherwise, your function should accept the entire dictionary output and return a dictionary.

             Note - augmentations are applied before the adapter is called and are generally
             transforms that you only want to apply during training.
@@ -81,7 +82,7 @@ def __init__(
         self.files = list(map(str, self.root.glob(pattern)))
         self.stage = stage

-        self.augmentations = augmentations
+        self.augmentations = augmentations or []
         self._shuffle = shuffle
         if self._shuffle:
             self.shuffle()
@@ -101,6 +102,9 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
         elif isinstance(self.augmentations, Mapping):
             for key, fn in self.augmentations.items():
                 batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Sequence):
+            for fn in self.augmentations:
+                batch = fn(batch)
         elif isinstance(self.augmentations, Callable):
             batch = self.augmentations(batch)
         else:
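
The new `Sequence` branch (added identically to `offline_dataset.py` and `online_dataset.py` below) applies a list of batch-level functions in order, before the adapter runs. A runnable sketch of the equivalent behavior; the augmentation functions and batch keys are hypothetical:

    import numpy as np

    # Hypothetical batch-level augmentations: each receives the full batch
    # dict and returns a (possibly modified) batch dict.
    def jitter(batch):
        batch["x"] = batch["x"] + np.random.normal(scale=0.05, size=batch["x"].shape)
        return batch

    def cast_targets(batch):
        batch["y"] = batch["y"].astype("float32")
        return batch

    # Passing augmentations=[jitter, cast_targets] to a dataset makes
    # __getitem__ run the equivalent of:
    batch = {"x": np.zeros((32, 4)), "y": np.zeros(32)}
    for fn in [jitter, cast_targets]:
        batch = fn(batch)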

bayesflow/datasets/offline_dataset.py

Lines changed: 12 additions & 8 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Mapping, Callable
+from collections.abc import Callable, Mapping, Sequence

 import numpy as np

@@ -23,7 +23,7 @@ def __init__(
         num_samples: int = None,
         *,
         stage: str = "training",
-        augmentations: Mapping[str, Callable] | Callable = None,
+        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] = None,
         shuffle: bool = True,
         **kwargs,
     ):
@@ -42,13 +42,14 @@ def __init__(
             Number of samples in the dataset. If None, it will be inferred from the data.
         stage : str, default="training"
             Current stage (e.g., "training", "validation", etc.) used by the adapter.
-        augmentations : dict of str to Callable or Callable, optional
-            Dictionary of augmentation functions to apply to each corresponding key in the batch
-            or a function to apply to the entire batch (possibly adding new keys).
+        augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
+            A single augmentation function, dictionary of augmentation functions, or sequence of augmentation
+            functions to apply to the batch.

             If you provide a dictionary of functions, each function should accept one element
-            of your output batch and return the corresponding transformed element. Otherwise,
-            your function should accept the entire dictionary output and return a dictionary.
+            of your output batch and return the corresponding transformed element.
+
+            Otherwise, your function should accept the entire dictionary output and return a dictionary.

             Note - augmentations are applied before the adapter is called and are generally
             transforms that you only want to apply during training.
@@ -71,7 +72,7 @@ def __init__(

         self.indices = np.arange(self.num_samples, dtype="int64")

-        self.augmentations = augmentations
+        self.augmentations = augmentations or []
         self._shuffle = shuffle
         if self._shuffle:
             self.shuffle()
@@ -111,6 +112,9 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
         elif isinstance(self.augmentations, Mapping):
             for key, fn in self.augmentations.items():
                 batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Sequence):
+            for fn in self.augmentations:
+                batch = fn(batch)
         elif isinstance(self.augmentations, Callable):
             batch = self.augmentations(batch)
         else:

bayesflow/datasets/online_dataset.py

Lines changed: 12 additions & 8 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Mapping, Callable
+from collections.abc import Callable, Mapping, Sequence

 import keras
 import numpy as np
@@ -20,7 +20,7 @@ def __init__(
         adapter: Adapter | None,
         *,
         stage: str = "training",
-        augmentations: Mapping[str, Callable] | Callable = None,
+        augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] = None,
         **kwargs,
     ):
         """
@@ -38,13 +38,14 @@ def __init__(
             Optional adapter to transform the simulated batch.
         stage : str, default="training"
             Current stage (e.g., "training", "validation", etc.) used by the adapter.
-        augmentations : dict of str to Callable or Callable, optional
-            Dictionary of augmentation functions to apply to each corresponding key in the batch
-            or a function to apply to the entire batch (possibly adding new keys).
+        augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
+            A single augmentation function, dictionary of augmentation functions, or sequence of augmentation
+            functions to apply to the batch.

             If you provide a dictionary of functions, each function should accept one element
-            of your output batch and return the corresponding transformed element. Otherwise,
-            your function should accept the entire dictionary output and return a dictionary.
+            of your output batch and return the corresponding transformed element.
+
+            Otherwise, your function should accept the entire dictionary output and return a dictionary.

             Note - augmentations are applied before the adapter is called and are generally
             transforms that you only want to apply during training.
@@ -58,7 +59,7 @@ def __init__(
         self.adapter = adapter
         self.simulator = simulator
         self.stage = stage
-        self.augmentations = augmentations
+        self.augmentations = augmentations or []

     def __getitem__(self, item: int) -> dict[str, np.ndarray]:
         """
@@ -81,6 +82,9 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
         elif isinstance(self.augmentations, Mapping):
             for key, fn in self.augmentations.items():
                 batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Sequence):
+            for fn in self.augmentations:
+                batch = fn(batch)
         elif isinstance(self.augmentations, Callable):
             batch = self.augmentations(batch)
         else:
