Skip to content

Commit

Permalink
Merge pull request #145 from vpratz/Development
Browse files Browse the repository at this point in the history
add notebook: amortized point estimation
  • Loading branch information
stefanradev93 authored Mar 9, 2024
2 parents 6508273 + 962dde9 commit 65961be
Show file tree
Hide file tree
Showing 2 changed files with 652 additions and 2 deletions.
15 changes: 13 additions & 2 deletions bayesflow/helper_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,15 @@ class ConfigurableMLP(tf.keras.Model):
"""Implements a simple configurable MLP with optional residual connections and dropout."""

def __init__(
self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
self,
input_dim,
hidden_dim=512,
output_dim=None,
num_hidden=2,
activation="relu",
residual=True,
dropout_rate=0.05,
**kwargs,
):
"""
Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
Expand All @@ -605,6 +613,8 @@ def __init__(
The input dimensionality
hidden_dim : int, optional, default: 512
The dimensionality of the hidden layers
output_dim : int, optional, default: None
The output dimensionality. If None is passed, `output_dim` is set to `input_dim`
num_hidden : int, optional, default: 2
The number of hidden layers (minimum: 1)
activation : string, optional, default: 'relu'
Expand All @@ -618,6 +628,7 @@ def __init__(
super().__init__(**kwargs)

self.input_dim = input_dim
self.output_dim = input_dim if output_dim is None else output_dim
self.model = tf.keras.Sequential(
[tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
)
Expand All @@ -630,7 +641,7 @@ def __init__(
dropout_rate=dropout_rate,
)
)
self.model.add(tf.keras.layers.Dense(input_dim))
self.model.add(tf.keras.layers.Dense(self.output_dim))

def call(self, inputs, **kwargs):
return self.model(inputs, **kwargs)
Expand Down
639 changes: 639 additions & 0 deletions examples/Amortized_Point_Estimation.ipynb

Large diffs are not rendered by default.

0 comments on commit 65961be

Please sign in to comment.