Skip to content

Commit 057f3fd

Browse files
authored
Merge pull request #516 from bayesflow-org/rename-summaries
Rename approximator.summaries to summarize with deprecation
2 parents 00f0a89 + 329ebe7 commit 057f3fd

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import keras
6+
import warnings
67

78
from bayesflow.adapters import Adapter
89
from bayesflow.networks import InferenceNetwork, SummaryNetwork
@@ -543,7 +544,7 @@ def _sample(
543544
batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)
544545
)
545546

546-
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
547+
def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
547548
"""
548549
Computes the learned summary statistics of given summary variables.
549550
@@ -574,6 +575,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
574575

575576
return summaries
576577

578+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
579+
"""
580+
.. deprecated:: 2.0.4
581+
`summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
582+
"""
583+
warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
584+
return self.summarize(data=data, **kwargs)
585+
577586
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
578587
"""
579588
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import keras
44
import numpy as np
5+
import warnings
56

67
from bayesflow.adapters import Adapter
78
from bayesflow.datasets import OnlineDataset
@@ -407,7 +408,7 @@ def predict(
407408

408409
return keras.ops.convert_to_numpy(output)
409410

410-
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
411+
def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
411412
"""
412413
Computes the learned summary statistics of given summary variables.
413414
@@ -438,6 +439,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
438439

439440
return summaries
440441

442+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
443+
"""
444+
.. deprecated:: 2.0.4
445+
`summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
446+
"""
447+
warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
448+
return self.summarize(data=data, **kwargs)
449+
441450
def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
442451
"""Helper to compute projected logits from the classifier network."""
443452
logits = self.classifier_network(classifier_conditions)

bayesflow/diagnostics/metrics/model_misspecification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def summary_space_comparison(
142142
"statistics, or want to compare raw data and not summary statistics, please use the "
143143
f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
144144
)
145-
observed_summaries = convert_to_numpy(approximator.summaries(observed_data))
146-
reference_summaries = convert_to_numpy(approximator.summaries(reference_data))
145+
observed_summaries = convert_to_numpy(approximator.summarize(observed_data))
146+
reference_summaries = convert_to_numpy(approximator.summarize(reference_data))
147147

148148
distance_observed, distance_null = bootstrap_comparison(
149149
observed_samples=observed_summaries,

tests/test_approximators/test_summaries.py renamed to tests/test_approximators/test_summarize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55

66
def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
77
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
8-
summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
8+
summaries = approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))})
99
assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
1010

1111

1212
def test_no_summary_network(approximator_with_summaries, monkeypatch):
1313
monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
1414

1515
with pytest.raises(ValueError):
16-
approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
16+
approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))})
1717

1818

1919
def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
2020
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
2121

2222
with pytest.raises(ValueError):
23-
approximator_with_summaries.summaries({})
23+
approximator_with_summaries.summarize({})

0 commit comments

Comments
 (0)