diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 06bb95ea8..2a10f6151 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -6,9 +6,8 @@ class SummaryNetwork(keras.Layer): - def __init__(self, base_distribution: str = "normal", **kwargs): + def __init__(self, base_distribution: str = None, **kwargs): super().__init__(**keras_kwargs(kwargs)) - self.base_distribution = find_distribution(base_distribution) def build(self, input_shape): diff --git a/bayesflow/utils/dispatch/find_distribution.py b/bayesflow/utils/dispatch/find_distribution.py index 2e244c684..aa16d6d4f 100644 --- a/bayesflow/utils/dispatch/find_distribution.py +++ b/bayesflow/utils/dispatch/find_distribution.py @@ -13,6 +13,8 @@ def _(name: str, *args, **kwargs): from bayesflow.distributions import DiagonalNormal distribution = DiagonalNormal(*args, **kwargs) + case "none": + distribution = None case other: raise ValueError(f"Unsupported distribution name '{other}'.") @@ -20,5 +22,5 @@ def _(name: str, *args, **kwargs): @find_distribution.register -def _(constructor: type, *args, **kwargs): - return constructor(*args, **kwargs) +def _(none: None, *args, **kwargs): + return None