@@ -739,7 +739,7 @@ <h1>Source code for bayesflow.helper_networks</h1><div class="highlight"><pre>
739
739
740
740
< span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> )</ span >
741
741
742
- < span class ="c1 "> # Initialize scale and bias with zeros and ones if no batch for initalization was provided.</ span >
742
+ < span class ="c1 "> # Initialize scale and bias with zeros and ones if no batch for initialization was provided.</ span >
743
743
< span class ="k "> if</ span > < span class ="n "> act_norm_init</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
744
744
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> scale</ span > < span class ="o "> =</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> Variable</ span > < span class ="p "> (</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> ones</ span > < span class ="p "> ((</ span > < span class ="n "> latent_dim</ span > < span class ="p "> ,)),</ span > < span class ="n "> trainable</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span > < span class ="n "> name</ span > < span class ="o "> =</ span > < span class ="s2 "> "act_norm_scale"</ span > < span class ="p "> )</ span >
745
745
@@ -1037,7 +1037,15 @@ <h1>Source code for bayesflow.helper_networks</h1><div class="highlight"><pre>
1037
1037
< div class ="viewcode-block " id ="ConfigurableMLP.__init__ ">
1038
1038
< a class ="viewcode-back " href ="../../api/bayesflow.helper_networks.html#bayesflow.helper_networks.ConfigurableMLP.__init__ "> [docs]</ a >
1039
1039
< span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span >
1040
- < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> input_dim</ span > < span class ="p "> ,</ span > < span class ="n "> hidden_dim</ span > < span class ="o "> =</ span > < span class ="mi "> 512</ span > < span class ="p "> ,</ span > < span class ="n "> num_hidden</ span > < span class ="o "> =</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span > < span class ="n "> activation</ span > < span class ="o "> =</ span > < span class ="s2 "> "relu"</ span > < span class ="p "> ,</ span > < span class ="n "> residual</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span > < span class ="n "> dropout_rate</ span > < span class ="o "> =</ span > < span class ="mf "> 0.05</ span > < span class ="p "> ,</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span >
1040
+ < span class ="bp "> self</ span > < span class ="p "> ,</ span >
1041
+ < span class ="n "> input_dim</ span > < span class ="p "> ,</ span >
1042
+ < span class ="n "> hidden_dim</ span > < span class ="o "> =</ span > < span class ="mi "> 512</ span > < span class ="p "> ,</ span >
1043
+ < span class ="n "> output_dim</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
1044
+ < span class ="n "> num_hidden</ span > < span class ="o "> =</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span >
1045
+ < span class ="n "> activation</ span > < span class ="o "> =</ span > < span class ="s2 "> "relu"</ span > < span class ="p "> ,</ span >
1046
+ < span class ="n "> residual</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1047
+ < span class ="n "> dropout_rate</ span > < span class ="o "> =</ span > < span class ="mf "> 0.05</ span > < span class ="p "> ,</ span >
1048
+ < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> ,</ span >
1041
1049
< span class ="p "> ):</ span >
1042
1050
< span class ="w "> </ span > < span class ="sd "> """</ span >
1043
1051
< span class ="sd "> Creates an instance of a flexible and simple MLP with optional residual connections and dropout.</ span >
@@ -1048,6 +1056,8 @@ <h1>Source code for bayesflow.helper_networks</h1><div class="highlight"><pre>
1048
1056
< span class ="sd "> The input dimensionality</ span >
1049
1057
< span class ="sd "> hidden_dim : int, optional, default: 512</ span >
1050
1058
< span class ="sd "> The dimensionality of the hidden layers</ span >
1059
+ < span class ="sd "> output_dim : int, optional, default: None</ span >
1060
+ < span class ="sd "> The output dimensionality. If None is passed, `output_dim` is set to `input_dim`</ span >
1051
1061
< span class ="sd "> num_hidden : int, optional, default: 2</ span >
1052
1062
< span class ="sd "> The number of hidden layers (minimum: 1)</ span >
1053
1063
< span class ="sd "> activation : string, optional, default: 'relu'</ span >
@@ -1061,6 +1071,7 @@ <h1>Source code for bayesflow.helper_networks</h1><div class="highlight"><pre>
1061
1071
< span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> )</ span >
1062
1072
1063
1073
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> input_dim</ span > < span class ="o "> =</ span > < span class ="n "> input_dim</ span >
1074
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> output_dim</ span > < span class ="o "> =</ span > < span class ="n "> input_dim</ span > < span class ="k "> if</ span > < span class ="n "> output_dim</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="k "> else</ span > < span class ="n "> output_dim</ span >
1064
1075
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> =</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> keras</ span > < span class ="o "> .</ span > < span class ="n "> Sequential</ span > < span class ="p "> (</ span >
1065
1076
< span class ="p "> [</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> keras</ span > < span class ="o "> .</ span > < span class ="n "> layers</ span > < span class ="o "> .</ span > < span class ="n "> Dense</ span > < span class ="p "> (</ span > < span class ="n "> hidden_dim</ span > < span class ="p "> ,</ span > < span class ="n "> activation</ span > < span class ="o "> =</ span > < span class ="n "> activation</ span > < span class ="p "> ),</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> keras</ span > < span class ="o "> .</ span > < span class ="n "> layers</ span > < span class ="o "> .</ span > < span class ="n "> Dropout</ span > < span class ="p "> (</ span > < span class ="n "> dropout_rate</ span > < span class ="p "> )]</ span >
1066
1077
< span class ="p "> )</ span >
@@ -1073,7 +1084,7 @@ <h1>Source code for bayesflow.helper_networks</h1><div class="highlight"><pre>
1073
1084
< span class ="n "> dropout_rate</ span > < span class ="o "> =</ span > < span class ="n "> dropout_rate</ span > < span class ="p "> ,</ span >
1074
1085
< span class ="p "> )</ span >
1075
1086
< span class ="p "> )</ span >
1076
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> add</ span > < span class ="p "> (</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> keras</ span > < span class ="o "> .</ span > < span class ="n "> layers</ span > < span class ="o "> .</ span > < span class ="n "> Dense</ span > < span class ="p "> (</ span > < span class ="n "> input_dim </ span > < span class ="p "> ))</ span > </ div >
1087
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> add</ span > < span class ="p "> (</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> keras</ span > < span class ="o "> .</ span > < span class ="n "> layers</ span > < span class ="o "> .</ span > < span class ="n "> Dense</ span > < span class ="p "> (</ span > < span class ="bp " > self </ span > < span class =" o " > . </ span > < span class =" n "> output_dim </ span > < span class ="p "> ))</ span > </ div >
1077
1088
1078
1089
1079
1090
< div class ="viewcode-block " id ="ConfigurableMLP.call ">
0 commit comments