diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index 4dc4f906b..fc57a3931 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -39,9 +39,9 @@ class CouplingFlow(InferenceNetwork): def __init__( self, depth: int = 6, - subnet: str = "default", + subnet: str | keras.Layer = "mlp", transform: str = "affine", - permutation: str | None = None, + permutation: str | None = "random", use_actnorm: bool = True, base_distribution: str = "normal", **kwargs diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py index 84163a8aa..aaa384a76 100644 --- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py @@ -10,7 +10,7 @@ @register_keras_serializable(package="bayesflow.networks.coupling_flow") class DualCoupling(InvertibleLayer): - def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs): + def __init__(self, subnet: str = "mlp", transform: str = "affine", **kwargs): super().__init__(**keras_kwargs(kwargs)) self.coupling1 = SingleCoupling(subnet, transform, **kwargs) self.coupling2 = SingleCoupling(subnet, transform, **kwargs) diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py index b5ce1c5ae..f8665eecc 100644 --- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py @@ -17,7 +17,7 @@ class SingleCoupling(InvertibleLayer): """ def __init__( self, - subnet: str = "resnet", + subnet: str = "mlp", transform: str = "affine", **kwargs ): diff --git a/bayesflow/utils/dispatch/find_network.py b/bayesflow/utils/dispatch/find_network.py index c7006632f..05f2b4b7a 100644 --- a/bayesflow/utils/dispatch/find_network.py +++ b/bayesflow/utils/dispatch/find_network.py @@ -15,7 +15,7 @@ def _(name: str, **kwargs): case "mlp" | "default": from bayesflow.networks import MLP network = MLP(**kwargs) - #TODO - remove, since MLP encompasses the functionality + # TODO - remove, since MLP encompasses the functionality case "resnet": from bayesflow.networks import ResNet network = ResNet(**kwargs)