11import pytest
22
33from bayesflow .networks import MLP
4+ from bayesflow .metrics import RootMeanSquaredError
45
56
67@pytest .fixture ()
@@ -12,6 +13,7 @@ def diffusion_model_edm_F():
1213 integrate_kwargs = {"method" : "rk45" , "steps" : 250 },
1314 noise_schedule = "edm" ,
1415 prediction_type = "F" ,
16+ metrics = [RootMeanSquaredError ()],
1517 )
1618
1719
@@ -82,22 +84,32 @@ def flow_matching():
8284 return FlowMatching (
8385 subnet = MLP ([8 , 8 ]),
8486 integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
87+ metrics = [RootMeanSquaredError ()],
8588 )
8689
8790
8891@pytest .fixture ()
8992def consistency_model ():
9093 from bayesflow .networks import ConsistencyModel
9194
92- return ConsistencyModel (total_steps = 100 , subnet = MLP ([8 , 8 ]))
95+ return ConsistencyModel (
96+ total_steps = 100 ,
97+ subnet = MLP ([8 , 8 ]),
98+ metrics = [RootMeanSquaredError ()],
99+ )
93100
94101
95102@pytest .fixture ()
96103def affine_coupling_flow ():
97104 from bayesflow .networks import CouplingFlow
98105
99106 return CouplingFlow (
100- depth = 2 , subnet = "mlp" , subnet_kwargs = dict (widths = [8 , 8 ]), transform = "affine" , transform_kwargs = dict (clamp = 1.8 )
107+ depth = 2 ,
108+ subnet = "mlp" ,
109+ subnet_kwargs = dict (widths = [8 , 8 ]),
110+ transform = "affine" ,
111+ transform_kwargs = dict (clamp = 1.8 ),
112+ metrics = [RootMeanSquaredError ()],
101113 )
102114
103115
@@ -106,15 +118,24 @@ def spline_coupling_flow():
106118 from bayesflow .networks import CouplingFlow
107119
108120 return CouplingFlow (
109- depth = 2 , subnet = "mlp" , subnet_kwargs = dict (widths = [8 , 8 ]), transform = "spline" , transform_kwargs = dict (bins = 8 )
121+ depth = 2 ,
122+ subnet = "mlp" ,
123+ subnet_kwargs = dict (widths = [8 , 8 ]),
124+ transform = "spline" ,
125+ transform_kwargs = dict (bins = 8 ),
126+ metrics = [RootMeanSquaredError ()],
110127 )
111128
112129
113130@pytest .fixture ()
114131def free_form_flow ():
115132 from bayesflow .experimental import FreeFormFlow
116133
117- return FreeFormFlow (encoder_subnet = MLP ([16 , 16 ]), decoder_subnet = MLP ([16 , 16 ]))
134+ return FreeFormFlow (
135+ encoder_subnet = MLP ([16 , 16 ]),
136+ decoder_subnet = MLP ([16 , 16 ]),
137+ metrics = [RootMeanSquaredError ()],
138+ )
118139
119140
120141@pytest .fixture ()
@@ -236,35 +257,35 @@ def generative_inference_network(request):
236257def time_series_network (summary_dim ):
237258 from bayesflow .networks import TimeSeriesNetwork
238259
239- return TimeSeriesNetwork (summary_dim = summary_dim )
260+ return TimeSeriesNetwork (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
240261
241262
242263@pytest .fixture (scope = "function" )
243264def time_series_transformer (summary_dim ):
244265 from bayesflow .networks import TimeSeriesTransformer
245266
246- return TimeSeriesTransformer (summary_dim = summary_dim )
267+ return TimeSeriesTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
247268
248269
249270@pytest .fixture (scope = "function" )
250271def fusion_transformer (summary_dim ):
251272 from bayesflow .networks import FusionTransformer
252273
253- return FusionTransformer (summary_dim = summary_dim )
274+ return FusionTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
254275
255276
256277@pytest .fixture (scope = "function" )
257278def set_transformer (summary_dim ):
258279 from bayesflow .networks import SetTransformer
259280
260- return SetTransformer (summary_dim = summary_dim )
281+ return SetTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
261282
262283
263284@pytest .fixture (scope = "function" )
264285def deep_set (summary_dim ):
265286 from bayesflow .networks import DeepSet
266287
267- return DeepSet (summary_dim = summary_dim )
288+ return DeepSet (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
268289
269290
270291@pytest .fixture (
0 commit comments