Skip to content

Commit

Permalink
notebook: point est., adapt to use ConfigurableMLP
Browse files Browse the repository at this point in the history
  • Loading branch information
vpratz committed Mar 9, 2024
1 parent d4b5c24 commit 962dde9
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 194 deletions.
15 changes: 13 additions & 2 deletions bayesflow/helper_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,15 @@ class ConfigurableMLP(tf.keras.Model):
"""Implements a simple configurable MLP with optional residual connections and dropout."""

def __init__(
self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
self,
input_dim,
hidden_dim=512,
output_dim=None,
num_hidden=2,
activation="relu",
residual=True,
dropout_rate=0.05,
**kwargs,
):
"""
Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
Expand All @@ -605,6 +613,8 @@ def __init__(
The input dimensionality
hidden_dim : int, optional, default: 512
The dimensionality of the hidden layers
output_dim : int, optional, default: None
The output dimensionality. If None is passed, `output_dim` is set to `input_dim`
num_hidden : int, optional, default: 2
The number of hidden layers (minimum: 1)
activation : string, optional, default: 'relu'
Expand All @@ -618,6 +628,7 @@ def __init__(
super().__init__(**kwargs)

self.input_dim = input_dim
self.output_dim = input_dim if output_dim is None else output_dim
self.model = tf.keras.Sequential(
[tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
)
Expand All @@ -630,7 +641,7 @@ def __init__(
dropout_rate=dropout_rate,
)
)
self.model.add(tf.keras.layers.Dense(input_dim))
self.model.add(tf.keras.layers.Dense(self.output_dim))

def call(self, inputs, **kwargs):
return self.model(inputs, **kwargs)
Expand Down
288 changes: 96 additions & 192 deletions examples/Amortized_Point_Estimation.ipynb

Large diffs are not rendered by default.

0 comments on commit 962dde9

Please sign in to comment.