Skip to content

Commit ab635f7

Browse files
committed
allow and change default base distribution to None
1 parent b908c91 commit ab635f7

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77

88
class SummaryNetwork(keras.Layer):
9-
def __init__(self, base_distribution: str = "normal", **kwargs):
9+
def __init__(self, base_distribution: str = None, **kwargs):
1010
super().__init__(**keras_kwargs(kwargs))
11-
1211
self.base_distribution = find_distribution(base_distribution)
1312

1413
def build(self, input_shape):

bayesflow/utils/dispatch/find_distribution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ def _(name: str, *args, **kwargs):
1313
from bayesflow.distributions import DiagonalNormal
1414

1515
distribution = DiagonalNormal(*args, **kwargs)
16+
case "none":
17+
distribution = None
1618
case other:
1719
raise ValueError(f"Unsupported distribution name '{other}'.")
1820

1921
return distribution
2022

2123

2224
@find_distribution.register
23-
def _(constructor: type, *args, **kwargs):
24-
return constructor(*args, **kwargs)
25+
def _(none: None, *args, **kwargs):
26+
return None

0 commit comments

Comments
 (0)