Skip to content

Commit 65961be

Browse files
Merge pull request #145 from vpratz/Development
add notebook: amortized point estimation
2 parents 6508273 + 962dde9 commit 65961be

File tree

2 files changed

+652
-2
lines changed

2 files changed

+652
-2
lines changed

bayesflow/helper_networks.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,15 @@ class ConfigurableMLP(tf.keras.Model):
594594
"""Implements a simple configurable MLP with optional residual connections and dropout."""
595595

596596
def __init__(
597-
self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
597+
self,
598+
input_dim,
599+
hidden_dim=512,
600+
output_dim=None,
601+
num_hidden=2,
602+
activation="relu",
603+
residual=True,
604+
dropout_rate=0.05,
605+
**kwargs,
598606
):
599607
"""
600608
Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -605,6 +613,8 @@ def __init__(
605613
The input dimensionality
606614
hidden_dim : int, optional, default: 512
607615
The dimensionality of the hidden layers
616+
output_dim : int, optional, default: None
617+
The output dimensionality. If None is passed, `output_dim` is set to `input_dim`
608618
num_hidden : int, optional, default: 2
609619
The number of hidden layers (minimum: 1)
610620
activation : string, optional, default: 'relu'
@@ -618,6 +628,7 @@ def __init__(
618628
super().__init__(**kwargs)
619629

620630
self.input_dim = input_dim
631+
self.output_dim = input_dim if output_dim is None else output_dim
621632
self.model = tf.keras.Sequential(
622633
[tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
623634
)
@@ -630,7 +641,7 @@ def __init__(
630641
dropout_rate=dropout_rate,
631642
)
632643
)
633-
self.model.add(tf.keras.layers.Dense(input_dim))
644+
self.model.add(tf.keras.layers.Dense(self.output_dim))
634645

635646
def call(self, inputs, **kwargs):
636647
return self.model(inputs, **kwargs)

examples/Amortized_Point_Estimation.ipynb

Lines changed: 639 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)