Skip to content

Commit

Permalink
allow and change default base distribution to None
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 16, 2024
1 parent b908c91 commit ab635f7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
3 changes: 1 addition & 2 deletions bayesflow/networks/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@


class SummaryNetwork(keras.Layer):
def __init__(self, base_distribution: str = "normal", **kwargs):
def __init__(self, base_distribution: str = None, **kwargs):
super().__init__(**keras_kwargs(kwargs))

self.base_distribution = find_distribution(base_distribution)

def build(self, input_shape):
Expand Down
6 changes: 4 additions & 2 deletions bayesflow/utils/dispatch/find_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ def _(name: str, *args, **kwargs):
from bayesflow.distributions import DiagonalNormal

distribution = DiagonalNormal(*args, **kwargs)
case "none":
distribution = None
case other:
raise ValueError(f"Unsupported distribution name '{other}'.")

return distribution


@find_distribution.register
def _(constructor: type, *args, **kwargs):
return constructor(*args, **kwargs)
def _(none: None, *args, **kwargs):
return None

0 comments on commit ab635f7

Please sign in to comment.