Skip to content

Commit

Permalink
Change coupling defaults to usable defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jun 20, 2024
1 parent 793eaa3 commit 5aa1ac3
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class CouplingFlow(InferenceNetwork):
def __init__(
self,
depth: int = 6,
subnet: str = "default",
subnet: str | keras.Layer = "mlp",
transform: str = "affine",
permutation: str | None = None,
permutation: str | None = "random",
use_actnorm: bool = True,
base_distribution: str = "normal",
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class DualCoupling(InvertibleLayer):
def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs):
def __init__(self, subnet: str = "mlp", transform: str = "affine", **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.coupling1 = SingleCoupling(subnet, transform, **kwargs)
self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SingleCoupling(InvertibleLayer):
"""
def __init__(
self,
subnet: str = "resnet",
subnet: str = "mlp",
transform: str = "affine",
**kwargs
):
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/dispatch/find_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _(name: str, **kwargs):
case "mlp" | "default":
from bayesflow.networks import MLP
network = MLP(**kwargs)
#TODO - remove, since MLP encompasses the functionality
# TODO - remove, since MLP encompasses the functionality
case "resnet":
from bayesflow.networks import ResNet
network = ResNet(**kwargs)
Expand Down

0 comments on commit 5aa1ac3

Please sign in to comment.