Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fusion Summary Networks (WIP) #181

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bayesflow/configurators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base_configurator import BaseConfigurator
from .configurator import Configurator
from .dict_configurator import DictConfigurator
1 change: 0 additions & 1 deletion bayesflow/configurators/configurator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import keras
from keras.saving import register_keras_serializable


from bayesflow.types import Tensor
from bayesflow.utils import filter_concatenate

Expand Down
14 changes: 14 additions & 0 deletions bayesflow/configurators/dict_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from keras.saving import register_keras_serializable

from bayesflow.types import Tensor

from .configurator import Configurator


@register_keras_serializable(package="bayesflow.configurators")
class DictConfigurator(Configurator):
    """Configurator whose summary variables arrive as a dict of named tensors.

    The original defined a no-op ``__init__(self, *args, **kwargs)`` that only
    forwarded to ``super().__init__`` — pure boilerplate, removed here; the
    inherited constructor is used directly and remains call-compatible.
    """

    def configure_summary_variables(self, data: dict[str, Tensor]) -> dict[str, Tensor] | None:
        """Select the entries of ``data`` named in ``self.summary_variables``.

        :param data: Mapping from variable name to tensor.
        :return: Sub-dict of ``data`` restricted to the configured summary
            variable names (empty dict if none match).
        """
        # Keep only the keys the configurator treats as summary variables.
        return {k: v for k, v in data.items() if k in self.summary_variables}
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .fusion import LateFusionSummaryNetwork
from .inference_network import InferenceNetwork
from .mlp import MLP
from .lstnet import LSTNet
Expand Down
1 change: 1 addition & 0 deletions bayesflow/networks/fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .late_fusion_summary_network import LateFusionSummaryNetwork
67 changes: 67 additions & 0 deletions bayesflow/networks/fusion/late_fusion_summary_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import keras

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs


class LateFusionSummaryNetwork(keras.Layer):
    """Late-fusion summary network.

    Routes each named data source through its own summary network and fuses
    the per-source summaries by concatenation along the last axis.
    """

    def __init__(self, summary_networks: dict[str, SummaryNetwork], **kwargs):
        """
        :param summary_networks: Mapping from data-source name to the summary
            network responsible for that source. Fusion concatenates the
            per-source outputs in the mapping's iteration order.
        :param kwargs: Forwarded to ``keras.Layer`` (filtered by ``keras_kwargs``).
        """
        super().__init__(**keras_kwargs(kwargs))

        self.num_data_sources = len(summary_networks)
        self.summary_networks = summary_networks

    def build(self, input_shape, **kwargs):
        # NOTE(review): the same input_shape is passed to every sub-network —
        # confirm all data sources share one shape, or build lazily per source.
        for summary_network in self.summary_networks.values():
            summary_network.build(input_shape, **kwargs)

    def call(self, x: Tensor, **kwargs) -> Tensor:
        """
        :param x: Mapping from source name to a tensor of shape
            (batch_size, set_size, input_dim). Despite the ``Tensor``
            annotation, the body indexes ``x`` by source name, so a dict of
            tensors is what is actually expected — TODO confirm and retype.

        :param kwargs: Additional keyword arguments; only ``training`` is used.

        :return: Tensor of shape (batch_size, total_output_dim): the
            concatenation of all per-source summaries.
        """
        # BUG FIX: the original wrote `outputs = [] * self.num_data_sources`,
        # which is just an empty list ([] * n == []), so `outputs[i] = ...`
        # raised IndexError on the first iteration. Append instead.
        outputs = []
        training = kwargs.get("training", False)  # hoisted loop invariant
        # Pass each data source through its respective summary network.
        for source_name, summary_network in self.summary_networks.items():
            data_source = {"summary_variables": x[source_name]}
            outputs.append(summary_network(data_source, training=training))

        # Concatenate the outputs of the individual summary networks.
        return keras.ops.concatenate(outputs, axis=-1)

    def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
        """Compute per-source metrics, then fuse outputs and aggregate losses.

        :param data: Must contain key ``"summary_variables"``, itself a mapping
            from source name to that source's summary variables.
        :param stage: ``"training"`` or an evaluation-stage name.
        :return: Dict with fused ``"outputs"``, summed ``"loss"``, and (outside
            training) per-source extra metrics prefixed with the source name.
        :raises ValueError: if the per-source outputs cannot be concatenated
            along the last axis (mismatched leading shapes).
        """
        metrics_sources = {}

        summary_variables = data["summary_variables"]

        # Pass all data sources through their respective summary network.
        # NOTE(review): sub-networks are called with `training=...` while this
        # method's own parameter is named `stage` — confirm the sub-network
        # compute_metrics signature actually accepts `training`.
        for source_name, summary_network in self.summary_networks.items():
            data_source = {"summary_variables": summary_variables[source_name]}
            metrics_sources[source_name] = summary_network.compute_metrics(data_source, training=stage == "training")

        # Merge all information (outputs, loss, additional metrics).
        metrics_out = {}

        # Fuse (concatenate) the outputs of the individual summary networks;
        # surface a clearer error when the shapes are incompatible.
        try:
            outputs = [metrics["outputs"] for metrics in metrics_sources.values()]
            metrics_out["outputs"] = keras.ops.concatenate(outputs, axis=-1)
        except ValueError as e:
            shapes = [metrics["outputs"].shape for metrics in metrics_sources.values()]
            raise ValueError(f"Cannot trivially concatenate outputs with shapes {shapes}") from e

        # Sum up any losses of the individual summary networks.
        metrics_out["loss"] = keras.ops.sum([metrics["loss"] for metrics in metrics_sources.values()], axis=0)

        # Gather remaining metrics (only relevant if not training).
        if stage != "training":
            for source_name, source_metrics in metrics_sources.items():
                for metric_name, metric_value in source_metrics.items():
                    if metric_name not in ["loss", "outputs"]:
                        metrics_out[f"{source_name}_{metric_name}"] = metric_value

        return metrics_out
12 changes: 3 additions & 9 deletions bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,24 @@
filter_concatenate,
filter_kwargs,
keras_kwargs,
stack_dicts,
process_output,
stack_dicts,
)

from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
from .git import (
issue_url,
pull_url,
repo_url,
)

from .io import warning

from .jacobian_trace import jacobian_trace

from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net

from .optimal_transport import optimal_transport

from .tensor_utils import (
broadcast_right,
broadcast_right_as,
expand_right,
expand_right_as,
expand_right_to,
tile_axis,
expand_tile,
tile_axis,
)
8 changes: 4 additions & 4 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import inspect
import logging
from collections.abc import Sequence

import keras
from keras import ops
import numpy as np

from collections.abc import Sequence
from keras import ops

from bayesflow.types import Shape, Tensor

Expand Down
Loading
Loading