Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jun 12, 2024
1 parent 8396e41 commit 047204d
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions bayesflow/experimental/configurators/configurator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@

import keras
from keras import ops

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import concatenate_tensors

from .base_configurator import BaseConfigurator


class Configurator(BaseConfigurator):
def __init__(self, inference_variables: list[str], inference_conditions: list[str] = None, summary_variables: list[str] = None, summary_conditions: list[str] = None):
def __init__(
self,
inference_variables: list[str],
inference_conditions: list[str] = None,
summary_variables: list[str] = None,
summary_conditions: list[str] = None
):
self.inference_variables = inference_variables
self.inference_conditions = inference_conditions or []
self.summary_variables = summary_variables or []
self.summary_conditions = summary_conditions or []

def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor:
try:
data["inference_variables"] = keras.ops.concatenate([val for key, val in data.items() if key in self.inference_variables], axis=-1)
data["inference_variables"] = concatenate_tensors(data, self.inference_variables)
except ValueError as e:
raise ValueError(f"Cannot trivially concatenate inference variables.") from e

Expand All @@ -29,17 +36,19 @@ def configure_inference_conditions(self, data: dict[str, Tensor]) -> Tensor:
data["inference_conditions"] = data["summary_outputs"]
else:
try:
specified_conditions = keras.ops.concatenate([val for key, val in data.items() if key in self.inference_conditions], axis=-1)
specified_conditions = concatenate_tensors(data, self.inference_conditions)
except ValueError as e:
raise ValueError(f"Cannot trivially concatenate inference conditions.") from e

if "summary_outputs" not in data:
# case 3: just the specified conditions
# case 3: just the direct inference conditions
data["inference_conditions"] = specified_conditions
else:
# case 4: summaries and specified conditions
# case 4: summaries and direct inference conditions
try:
data["inference_conditions"] = keras.ops.concatenate([data["summary_outputs"], specified_conditions], axis=-1)
data["inference_conditions"] = ops.concatenate(
[data["summary_outputs"], specified_conditions], axis=-1
)
except ValueError as e:
raise ValueError(f"Cannot trivially concatenate summary outputs to inference conditions.") from e

Expand All @@ -48,15 +57,15 @@ def configure_summary_variables(self, data: dict[str, Tensor]) -> Tensor:
return

try:
data["summary_variables"] = keras.ops.concatenate([val for key, val in data.items() if key in self.summary_variables], axis=-1)
data["summary_variables"] = concatenate_tensors(data, self.summary_variables)
except ValueError as e:
raise ValueError(f"Cannot trivially concatenate summary variables.") from e
raise ValueError(f"Cannot trivially concatenate summary variables along last axis.") from e

def configure_summary_conditions(self, data: dict[str, Tensor]) -> Tensor:
if not self.summary_conditions:
return

try:
data["summary_conditions"] = keras.ops.concatenate([val for key, val in data.items() if key in self.summary_conditions], axis=-1)
data["summary_conditions"] = concatenate_tensors(data, self.summary_conditions)
except ValueError as e:
raise ValueError(f"Cannot trivially concatenate summary conditions.") from e
raise ValueError(f"Cannot trivially concatenate summary conditions along last axis.") from e

0 comments on commit 047204d

Please sign in to comment.