Skip to content

Commit

Permalink
add model comparison example
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Aug 16, 2024
1 parent 1381dd0 commit ab6cfce
Show file tree
Hide file tree
Showing 3 changed files with 354 additions and 9 deletions.
27 changes: 21 additions & 6 deletions bayesflow/approximators/model_comparison_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class ModelComparisonApproximator(Approximator):
def __init__(
self,
*,
num_models: int,
classifier_network: keras.Layer,
data_adapter: DataAdapter,
summary_network: SummaryNetwork = None,
Expand All @@ -36,19 +37,24 @@ def __init__(
self.data_adapter = data_adapter
self.summary_network = summary_network

self.logits_projector = keras.layers.Dense(num_models)

def build(self, data_shapes: Mapping[str, Shape]):
    """Build the approximator by tracing one forward pass on zero-filled dummy inputs.

    A zeros tensor is created for every entry in ``data_shapes`` and passed
    through ``compute_metrics`` in training mode, which forces all sub-layers
    (summary network, classifier, projector) to build their weights.
    """
    dummy_data = {}
    for name, shape in data_shapes.items():
        dummy_data[name] = keras.ops.zeros(shape)
    self.compute_metrics(**dummy_data, stage="training")

@classmethod
def build_data_adapter(
cls,
classifier_variables: Sequence[str],
classifier_conditions: Sequence[str] = None,
summary_variables: Sequence[str] = None,
model_index_name: str = "model_indices",
):
if classifier_conditions is None and summary_variables is None:
raise ValueError("At least one of `classifier_variables` or `summary_variables` must be provided.")

variables = {
"classifier_variables": classifier_variables,
"classifier_conditions": classifier_conditions,
"summary_variables": summary_variables,
"model_indices": [model_index_name],
}
Expand Down Expand Up @@ -93,7 +99,8 @@ def compile(

def compute_metrics(
self,
classifier_variables: Tensor,
*,
classifier_conditions: Tensor = None,
model_indices: Tensor,
summary_variables: Tensor = None,
stage: str = "training",
Expand All @@ -104,11 +111,19 @@ def compute_metrics(
summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
summary_outputs = summary_metrics.pop("outputs")

classifier_variables = keras.ops.concatenate([classifier_variables, summary_outputs], axis=-1)
if classifier_conditions is None:
classifier_conditions = summary_outputs
else:
classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)

# we could move this into its own class
logits = self.classifier_network(classifier_variables)
classifier_metrics = {"loss": keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True)}
logits = self.classifier_network(classifier_conditions)
logits = self.logits_projector(logits)

cross_entropy = keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True)
cross_entropy = keras.ops.mean(cross_entropy)

classifier_metrics = {"loss": cross_entropy}

if stage != "training" and any(self.classifier_network.metrics):
# compute sample-based metrics
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ def filter_kwargs(kwargs: Mapping[str, any], f: callable) -> Mapping[str, any]:
"""Filter keyword arguments for f"""
signature = inspect.signature(f)

if inspect.Parameter.VAR_KEYWORD in signature.parameters:
# the signature has **kwargs
return kwargs
for parameter in signature.parameters.values():
if parameter.kind == inspect.Parameter.VAR_KEYWORD:
# there is a **kwargs parameter, so anything is valid
return kwargs

kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}

Expand Down
Loading

0 comments on commit ab6cfce

Please sign in to comment.