
Commit 6508273

Add new ConfigurableMLP
1 parent 0b866a9 commit 6508273

1 file changed, 63 insertions(+), 0 deletions(-)


bayesflow/helper_networks.py

@@ -588,3 +588,66 @@ def _multi_conv(self, x, **kwargs):
         """Applies the convolutions with different sizes and concatenates outputs."""

         return tf.concat([conv(x, **kwargs) for conv in self.convs], axis=-1)
+
+
+class ConfigurableMLP(tf.keras.Model):
+    """Implements a simple configurable MLP with optional residual connections and dropout."""
+
+    def __init__(
+        self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
+    ):
+        """
+        Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
+
+        Parameters:
+        -----------
+        input_dim    : int
+            The input dimensionality
+        hidden_dim   : int, optional, default: 512
+            The dimensionality of the hidden layers
+        num_hidden   : int, optional, default: 2
+            The number of hidden layers (minimum: 1)
+        activation   : string, optional, default: 'relu'
+            The activation function of the dense layers
+        residual     : bool, optional, default: True
+            Use residual connections in the MLP
+        dropout_rate : float, optional, default: 0.05
+            Dropout rate for the hidden layers in the MLP
+        """
+
+        super().__init__(**kwargs)
+
+        self.input_dim = input_dim
+        self.model = tf.keras.Sequential(
+            [tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
+        )
+        for _ in range(num_hidden):
+            self.model.add(
+                ConfigurableHiddenBlock(
+                    hidden_dim,
+                    activation=activation,
+                    residual=residual,
+                    dropout_rate=dropout_rate,
+                )
+            )
+        self.model.add(tf.keras.layers.Dense(input_dim))
+
+    def call(self, inputs, **kwargs):
+        return self.model(inputs, **kwargs)
+
+
+class ConfigurableHiddenBlock(tf.keras.Model):
+    def __init__(self, num_units, activation="relu", residual=True, dropout_rate=0.0):
+        super().__init__()
+
+        self.act_fn = tf.keras.activations.get(activation)
+        self.residual = residual
+        self.dense_with_dropout = tf.keras.Sequential(
+            [tf.keras.layers.Dense(num_units, activation=None), tf.keras.layers.Dropout(dropout_rate)]
+        )
+
+    def call(self, inputs, **kwargs):
+        x = self.dense_with_dropout(inputs, **kwargs)
+        if self.residual:
+            x = x + inputs
+        return self.act_fn(x)
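As a minimal usage sketch (not part of the commit; the import path follows the changed file above, and the data, shapes, and hyperparameter values are illustrative):

    import tensorflow as tf
    from bayesflow.helper_networks import ConfigurableMLP  # path assumed from the file changed above

    # Toy batch: 64 samples with input dimensionality 10 (illustrative values).
    x = tf.random.normal((64, 10))

    # Internally: Dense(hidden_dim) + Dropout, then num_hidden ConfigurableHiddenBlocks,
    # then a final Dense projecting back to input_dim.
    mlp = ConfigurableMLP(input_dim=10, hidden_dim=128, num_hidden=2,
                          activation="relu", residual=True, dropout_rate=0.05)

    out = mlp(x, training=True)  # 'training' is forwarded via **kwargs, so dropout is active
    print(out.shape)             # (64, 10): the final layer restores input_dim

Note that ConfigurableHiddenBlock applies the activation after the residual addition (a post-activation residual block), so with residual=True each block's input and output widths must match; the leading Dense(hidden_dim) layer in ConfigurableMLP ensures this before the first block.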
