
Commit

Add new ConfigurableMLP
stefanradev93 committed Mar 7, 2024
1 parent 0b866a9 commit 6508273
Showing 1 changed file with 63 additions and 0 deletions.
bayesflow/helper_networks.py: 63 additions & 0 deletions
@@ -588,3 +588,66 @@ def _multi_conv(self, x, **kwargs):
"""Applies the convolutions with different sizes and concatenates outputs."""

return tf.concat([conv(x, **kwargs) for conv in self.convs], axis=-1)


class ConfigurableMLP(tf.keras.Model):
    """Implements a simple configurable MLP with optional residual connections and dropout."""

    def __init__(
        self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
    ):
"""
Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
Parameters:
-----------
input_dim : int
The input dimensionality
hidden_dim : int, optional, default: 512
The dimensionality of the hidden layers
num_hidden : int, optional, default: 2
The number of hidden layers (minimum: 1)
activation : string, optional, default: 'relu'
The activation function of the dense layers
residual : bool, optional, default: True
Use residual connections in the MLP
dropout_rate : float, optional, default: 0.05
Dropout rate for the hidden layers in the MLP
"""

        super().__init__(**kwargs)

        self.input_dim = input_dim

        # Initial projection from the input dimensionality into the hidden space
        self.model = tf.keras.Sequential(
            [tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
        )
        # Stack of configurable hidden blocks (optionally residual, with dropout)
        for _ in range(num_hidden):
            self.model.add(
                ConfigurableHiddenBlock(
                    hidden_dim,
                    activation=activation,
                    residual=residual,
                    dropout_rate=dropout_rate,
                )
            )
        # Linear output layer mapping back to the input dimensionality
        self.model.add(tf.keras.layers.Dense(input_dim))

    def call(self, inputs, **kwargs):
        return self.model(inputs, **kwargs)


class ConfigurableHiddenBlock(tf.keras.Model):
    """Implements a dense hidden block with optional residual connection and dropout."""

    def __init__(self, num_units, activation="relu", residual=True, dropout_rate=0.0):
        super().__init__()

        self.act_fn = tf.keras.activations.get(activation)
        self.residual = residual
        self.dense_with_dropout = tf.keras.Sequential(
            [tf.keras.layers.Dense(num_units, activation=None), tf.keras.layers.Dropout(dropout_rate)]
        )

    def call(self, inputs, **kwargs):
        # Apply the dense layer and dropout, add the residual skip connection (if enabled),
        # and only then apply the activation function.
        x = self.dense_with_dropout(inputs, **kwargs)
        if self.residual:
            x = x + inputs
        return self.act_fn(x)
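
For reference, a minimal usage sketch of the new class (not part of this commit), assuming TensorFlow 2.x and that ConfigurableMLP is importable from bayesflow.helper_networks; the shapes and hyperparameters below are purely illustrative:

import tensorflow as tf

from bayesflow.helper_networks import ConfigurableMLP

# A width-256 MLP on 10-dimensional inputs with two residual hidden blocks.
mlp = ConfigurableMLP(input_dim=10, hidden_dim=256, num_hidden=2, dropout_rate=0.1)

x = tf.random.normal((32, 10))   # batch of 32 samples
y = mlp(x, training=True)        # dropout is active when training=True
assert y.shape == (32, 10)       # the final Dense layer maps back to input_dim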

