Skip to content

Commit

Permalink
fix summary networks
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 15, 2024
1 parent c8b76e3 commit b908c91
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
3 changes: 1 addition & 2 deletions bayesflow/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs
from .invariant_module import InvariantModule
from .equivariant_module import EquivariantModule

Expand Down Expand Up @@ -46,7 +45,7 @@ def __init__(
#TODO
"""

super().__init__(**keras_kwargs(kwargs))
super().__init__(**kwargs)

# Stack of equivariant modules for a many-to-many learnable transformation
self.equivariant_modules = keras.Sequential()
Expand Down
3 changes: 1 addition & 2 deletions bayesflow/networks/lstnet/lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from keras.saving import register_keras_serializable

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs

from .skip_recurrent import SkipRecurrentNet
from ..mlp import MLP
Expand Down Expand Up @@ -39,7 +38,7 @@ def __init__(
skip_steps: int = 4,
**kwargs,
):
super().__init__(**keras_kwargs(kwargs))
super().__init__(**kwargs)

# Convolutional backbone -> can be extended with inception-like structure
if not isinstance(filters, (list, tuple)):
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class SummaryNetwork(keras.Layer):
def __init_(self, base_distribution: str = "normal", **kwargs):
def __init__(self, base_distribution: str = "normal", **kwargs):
super().__init__(**keras_kwargs(kwargs))

self.base_distribution = find_distribution(base_distribution)
Expand Down
4 changes: 3 additions & 1 deletion bayesflow/networks/transformers/set_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

from bayesflow.types import Tensor

from ..summary_network import SummaryNetwork

from .sab import SetAttentionBlock
from .isab import InducedSetAttentionBlock
from .pma import PoolingByMultiHeadAttention


@register_keras_serializable(package="bayesflow.networks")
class SetTransformer(keras.Layer):
class SetTransformer(SummaryNetwork):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
Expand Down

0 comments on commit b908c91

Please sign in to comment.