Commit 4aac652

drop summary conditions, add deconfiguration for sampling

LarsKue committed Jul 5, 2024
1 parent 4ee1aa8 · commit 4aac652

Showing 5 changed files with 63 additions and 78 deletions.
bayesflow/approximators/approximator.py (4 additions, 3 deletions)

@@ -68,15 +68,16 @@ def __init__(self, **kwargs):
             inference_variables = kwargs.pop("inference_variables")
             inference_conditions = kwargs.pop("inference_conditions", None)
             summary_variables = kwargs.pop("summary_variables", None)
-            summary_conditions = kwargs.pop("summary_conditions", None)

             kwargs["configurator"] = Configurator(
-                inference_variables, inference_conditions, summary_variables, summary_conditions
+                inference_variables,
+                inference_conditions,
+                summary_variables,
             )
         else:
             # the user passed a configurator, so we should not configure a default one
             # check if the user also passed args for the default configurator
-            keys = ["inference_variables", "inference_conditions", "summary_variables", "summary_conditions"]
+            keys = ["inference_variables", "inference_conditions", "summary_variables"]
             if any(key in kwargs for key in keys):
                 raise ValueError(
                     "Received an ambiguous set of arguments: You are passing a configurator explicitly, "
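For orientation, a minimal sketch of how the trimmed-down keyword set would be used after this change. The network class, the variable names, and the exact constructor signature are illustrative assumptions, not part of this diff:

    from bayesflow.approximators import Approximator
    from bayesflow.networks import CouplingFlow

    # The default configurator is now assembled from three variable groups;
    # "summary_conditions" is no longer an accepted key (and passing any of
    # these keys alongside an explicit configurator= still raises ValueError).
    approximator = Approximator(
        inference_network=CouplingFlow(),
        inference_variables=["theta"],
        inference_conditions=["x"],
    )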
bayesflow/approximators/base_approximator.py (32 additions, 55 deletions)

@@ -10,7 +10,7 @@
 from bayesflow.configurators import BaseConfigurator
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
 from bayesflow.types import Shape, Tensor
-from bayesflow.utils import keras_kwargs, repeat_tensor, process_output
+from bayesflow.utils import keras_kwargs


 @register_keras_serializable(package="bayesflow.approximators")
@@ -27,68 +27,44 @@ def __init__(
         self.summary_network = summary_network
         self.configurator = configurator

-    def sample(self, data: dict[str, Tensor], num_samples: int = 500, as_numpy: bool = True) -> Tensor:
-        """Generates ``num_samples`` samples from the approximate distribution. Will typically be called only on
-        trained models.
-        Parameters
-        ----------
-        data: dict[str, Tensor]
-            The data dictionary containing all keys used when constructing the Approximator except
-            ``inference_variables``, which is assumed to be absent during inference and will be ignored
-            if present.
-        num_samples: int, optional, default - 500
-            The number of samples per data set / instance in the data dictionary.
-        as_numpy: bool, optional, default - True
-            An optional flag to convert the samples to a numpy array before returning.
-        Returns
-        -------
-        samples: Tensor
-            A tensor of shape (num_data_sets, num_samples, num_inference_variables) if data contains
-            multiple data sets / instances, or of shape (num_samples, num_inference_variables) if data
-            contains a single data set (i.e., a leading axis with one element in the corresponding
-            conditioning variables).
-        """
-
-        data = data.copy()
+    def sample(self, num_samples: int = 1, data: dict[str, Tensor] = None) -> dict[str, Tensor]:
+        if data is None:
+            data = {}
+        else:
+            data = data.copy()

         if self.summary_network is None:
-            data["inference_conditions"] = self.configurator.configure_inference_conditions(data)
+            inference_conditions = self.configurator.configure_inference_conditions(data)
+            samples = self.inference_network.sample(num_samples, conditions=inference_conditions)

-        else:
-            data["summary_conditions"] = self.configurator.configure_summary_conditions(data)
-            data["summary_variables"] = self.configurator.configure_summary_variables(data)
-            summary_metrics = self.summary_network.compute_metrics(data, stage="inference")
-            data["summary_outputs"] = summary_metrics.get("outputs")
+            return self.configurator.deconfigure(samples)

-            data["inference_conditions"] = self.configurator.configure_inference_conditions(data)
+        data["summary_variables"] = self.configurator.configure_summary_variables(data)
+        data["summary_outputs"] = self.summary_network(data["summary_variables"])

-        data["inference_conditions"] = repeat_tensor(data["inference_conditions"], num_repeats=num_samples, axis=1)
-        samples = self.inference_network.sample(num_samples, data["inference_conditions"])
+        inference_conditions = self.configurator.configure_inference_conditions(data)

-        return process_output(samples, convert_to_numpy=as_numpy)
+        samples = self.inference_network.sample(num_samples, conditions=inference_conditions)

-    def log_prob(self, data: dict[str, Tensor], as_numpy: bool = True) -> Tensor:
-        """TODO"""
+        return self.configurator.deconfigure(samples)

+    def log_prob(self, data: dict[str, Tensor]) -> Tensor:
         data = data.copy()

         if self.summary_network is None:
             data["inference_conditions"] = self.configurator.configure_inference_conditions(data)
             data["inference_variables"] = self.configurator.configure_inference_variables(data)

-        else:
-            data["summary_conditions"] = self.configurator.configure_summary_conditions(data)
-            data["summary_variables"] = self.configurator.configure_summary_variables(data)
-            summary_metrics = self.summary_network.compute_metrics(data, stage="inference")
-            data["summary_outputs"] = summary_metrics.get("outputs")
+            return self.inference_network.log_prob(data)

-            data["inference_conditions"] = self.configurator.configure_inference_conditions(data)
+        data["summary_variables"] = self.configurator.configure_summary_variables(data)
+        summary_metrics = self.summary_network.compute_metrics(data, stage="inference")
+        data["summary_outputs"] = summary_metrics.get("outputs")

         data["inference_conditions"] = self.configurator.configure_inference_conditions(data)
         data["inference_variables"] = self.configurator.configure_inference_variables(data)
-        log_density = self.inference_network.log_prob(data["inference_variables"], data["inference_conditions"])

-        return process_output(log_density, convert_to_numpy=as_numpy)
+        return self.inference_network.log_prob(data)

     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "BaseApproximator":
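A hedged sketch of the reworked calling convention above: sample() now takes num_samples first, accepts data=None for the no-condition case, and returns deconfigured (named) samples as a dict instead of a raw tensor, while the as_numpy flag and the repeat_tensor broadcasting are gone. The variable names here are assumptions for illustration:

    # Conditional sampling: pass the observed conditions in the data dict.
    samples = approximator.sample(num_samples=100, data={"x": observed_x})
    # deconfigure() maps the flat inference tensor back to named variables,
    # so samples is a dict[str, Tensor], e.g. samples["theta"].

    # log_prob() now returns the raw tensor; callers convert to numpy themselves.
    log_density = approximator.log_prob(data={"x": observed_x, "theta": theta})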
@@ -135,10 +111,8 @@ def evaluate(self, *args, **kwargs):
         if val_logs is None:
             # https://github.com/keras-team/keras/issues/19835
             warnings.warn(
-                "Found no validation logs due to a bug in keras. "
-                "Applying workaround, but incorrect loss values may be logged. "
-                "If possible, increase the size of your dataset, "
-                "or lower the number of validation steps used."
+                "Found no validation logs due to a bug in keras. Applying workaround, but incorrect loss values may be "
+                "logged. If possible, increase the size of your dataset, or lower the number of validation steps used."
             )

             val_logs = {}
Expand All @@ -157,7 +131,6 @@ def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> d
             return self.inference_network.compute_metrics(data, stage=stage)

         data["summary_variables"] = self.configurator.configure_summary_variables(data)
-        data["summary_conditions"] = self.configurator.configure_summary_conditions(data)

         summary_metrics = self.summary_network.compute_metrics(data, stage=stage)

@@ -196,9 +169,13 @@ def fit(self, *args, **kwargs):
     def compile(
         self, inference_metrics: Sequence[keras.Metric] = None, summary_metrics: Sequence[keras.Metric] = None, **kwargs
     ) -> None:
-        self.inference_network._metrics = inference_metrics or []
-
-        if self.summary_network is not None:
-            self.summary_network._metrics = summary_metrics or []
+        if inference_metrics:
+            self.inference_network._metrics = inference_metrics
+
+        if summary_metrics:
+            if self.summary_network is None:
+                warnings.warn("Ignoring summary metrics because there is no summary network.")
+            else:
+                self.summary_network._metrics = summary_metrics

         return super().compile(**kwargs)
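A small usage sketch of the revised compile() semantics: metrics are attached only when provided (pre-existing metrics are no longer clobbered with empty lists), and summary metrics without a summary network now produce a warning instead of being silently dropped. The metric and optimizer choices are assumptions:

    import keras

    approximator.compile(
        optimizer="adam",  # forwarded to keras.Model.compile via **kwargs
        inference_metrics=[keras.metrics.MeanSquaredError()],
        # summary_metrics=[...] would trigger the warning above
        # when self.summary_network is None.
    )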
tests/conftest.py (27 additions, 7 deletions)

@@ -19,7 +19,7 @@ def pytest_make_parametrize_id(config, val, argname):
return f"{argname}={repr(val)}"


@pytest.fixture(params=[1, 2, 3, 4], scope="session")
@pytest.fixture(params=[128, 256], scope="session", autouse=True)
def batch_size(request):
return request.param

@@ -28,7 +28,7 @@ def batch_size(request):
 def coupling_flow():
     from bayesflow.networks import CouplingFlow

-    return CouplingFlow(depth=2, subnet_kwargs=dict(depth=2, width=32))
+    return CouplingFlow(depth=4, subnet_kwargs=dict(depth=4, width=256))


 @pytest.fixture(params=["two_moons"], scope="session")
@@ -40,7 +40,7 @@ def dataset(request):
 def flow_matching():
     from bayesflow.networks import FlowMatching

-    return FlowMatching(network_kwargs=dict(depth=2, width=32))
+    return FlowMatching(network_kwargs=dict(depth=12, width=256))


 @pytest.fixture(params=["coupling_flow", "flow_matching"], scope="function")
@@ -53,13 +53,18 @@ def network(request):
     return request.getfixturevalue(request.param)


-@pytest.fixture(autouse=True, scope="function")
+@pytest.fixture(scope="function", autouse=True)
 def random_seed():
     seed = 0
     keras.utils.set_random_seed(seed)
     return seed


+@pytest.fixture(params=["two_moons"], scope="session")
+def simulator(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(params=[None], scope="function")
 def summary_network(request):
     if request.param is None:
@@ -68,10 +73,25 @@ def summary_network(request):


@pytest.fixture(scope="session")
def two_moons(batch_size):
def training_dataset(simulator, batch_size):
from bayesflow.datasets import OfflineDataset

num_batches = 128
samples = simulator.sample((num_batches * batch_size,))
return OfflineDataset(samples, batch_size=batch_size)


@pytest.fixture(scope="session")
def two_moons(batch_size):
from bayesflow.simulators import TwoMoonsSimulator

simulator = TwoMoonsSimulator()
samples = simulator.sample((4 * batch_size,))
return TwoMoonsSimulator()


@pytest.fixture(scope="session")
def validation_dataset(simulator, batch_size):
from bayesflow.datasets import OfflineDataset

num_batches = 16
samples = simulator.sample((num_batches * batch_size,))
return OfflineDataset(samples, batch_size=batch_size)
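A hedged sketch of how the new session-scoped fixtures are meant to compose in a test: simulator feeds both dataset fixtures, which wrap 128 training batches and 16 validation batches in an OfflineDataset. The approximator fixture and the fit() call are assumptions for illustration, not part of this diff:

    def test_fit(approximator, training_dataset, validation_dataset):
        # Relies on BaseApproximator.fit forwarding to keras.Model.fit.
        approximator.fit(training_dataset, validation_data=validation_dataset, epochs=1)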
tests/test_configurators/conftest.py (0 additions, 3 deletions)

@@ -25,7 +25,6 @@ def random_data(request, batch_size, set_size, num_features):
"var2": keras.random.normal((batch_size, set_size, num_features)),
"var3": keras.random.normal((batch_size, set_size, num_features)),
"summary_inputs": keras.random.normal((batch_size, set_size, num_features)),
"summary_conditions": keras.random.normal((batch_size, set_size, num_features)),
}
if request.param:
data["summary_outputs"] = keras.random.normal((batch_size, set_size, num_features))
@@ -38,7 +37,6 @@ def test_params(request):
"inference_variables": ["var1"],
"inference_conditions": ["var2", "var3"],
"summary_variables": ["var1"],
"summary_conditions": ["var2"],
}
if request.param:
args["inference_conditions"].append("summary_outputs")
@@ -53,5 +51,4 @@ def configurator(request, test_params):
         inference_variables=test_params["inference_variables"],
         inference_conditions=test_params["inference_conditions"],
         summary_variables=test_params["summary_variables"],
-        summary_conditions=test_params["summary_conditions"],
     )
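Mirroring the fixture above, a minimal sketch of the slimmed-down Configurator signature, assuming Configurator is importable from bayesflow.configurators (only BaseConfigurator is shown imported in this diff):

    from bayesflow.configurators import Configurator

    configurator = Configurator(
        inference_variables=["var1"],
        inference_conditions=["var2", "var3"],
        summary_variables=["var1"],
    )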
tests/test_configurators/test_configurators.py (0 additions, 10 deletions)

@@ -38,13 +38,3 @@ def test_summary_variables_shape(random_data, configurator):
     filtered_data = configurator.configure_summary_variables(random_data)
     expected = keras.ops.concatenate([random_data[v] for v in configurator.summary_variables], axis=-1)
     assert filtered_data.shape == expected.shape
-
-
-def test_summary_conditions_shape(random_data, configurator):
-    # Tests for correct output shape when querying summary conditions
-    if not configurator.summary_conditions:
-        assert configurator.configure_summary_conditions(random_data) is None
-    else:
-        filtered_data = configurator.configure_summary_conditions(random_data)
-        expected = keras.ops.concatenate([random_data[v] for v in configurator.summary_conditions], axis=-1)
-        assert filtered_data.shape == expected.shape
