1
- import keras
2
1
import pytest
3
2
4
- import bayesflow as bf
3
+
4
+ @pytest .fixture ()
5
+ def batch_size ():
6
+ return 8
5
7
6
8
7
9
@pytest .fixture ()
@@ -11,40 +13,43 @@ def summary_network():
11
13
12
14
@pytest .fixture ()
13
15
def inference_network ():
14
- network = keras . Sequential ([ keras . layers . Dense ( 10 )])
15
- network . compile ( loss = "mse" )
16
- return network
16
+ from bayesflow . networks import CouplingFlow
17
+
18
+ return CouplingFlow ()
17
19
18
20
19
21
@pytest .fixture ()
20
22
def approximator (inference_network , summary_network ):
21
- return bf .Approximator (
23
+ from bayesflow import Approximator
24
+
25
+ return Approximator (
22
26
inference_network = inference_network ,
23
27
summary_network = summary_network ,
24
- inference_variables = [],
25
- inference_conditions = [],
26
- summary_variables = [],
27
- summary_conditions = [],
28
+ inference_variables = ["mean" , "std" ],
29
+ inference_conditions = ["x" ],
28
30
)
29
31
30
32
31
33
@pytest .fixture ()
32
- def dataset ():
33
- batch_size = 16
34
- batches_per_epoch = 4
35
- parameter_sets = batch_size * batches_per_epoch
36
- observations_per_parameter_set = 32
34
+ def simulator ():
35
+ from bayesflow .simulators import NormalSimulator
37
36
38
- mean = keras .random .normal (mean = 0.0 , stddev = 0.1 , shape = (parameter_sets , 2 ))
39
- std = keras .ops .exp (keras .random .normal (mean = 0.0 , stddev = 0.1 , shape = (parameter_sets , 2 )))
37
+ return NormalSimulator ()
40
38
41
- mean = keras .ops .repeat (mean [:, None ], observations_per_parameter_set , 1 )
42
- std = keras .ops .repeat (std [:, None ], observations_per_parameter_set , 1 )
43
39
44
- noise = keras .random .normal (shape = (parameter_sets , observations_per_parameter_set , 2 ))
40
+ @pytest .fixture ()
41
+ def train_dataset (simulator , batch_size ):
42
+ from bayesflow import OfflineDataset
43
+
44
+ num_batches = 4
45
+ data = simulator .sample ((num_batches * batch_size ,))
46
+ return OfflineDataset (data , workers = 4 , max_queue_size = num_batches , batch_size = batch_size )
45
47
46
- x = mean + std * noise
47
48
48
- data = dict (observables = dict (x = x ), parameters = dict (mean = mean , std = std ))
49
+ @pytest .fixture ()
50
+ def validation_dataset (simulator , batch_size ):
51
+ from bayesflow import OfflineDataset
49
52
50
- return bf .datasets .OfflineDataset (data , batch_size = batch_size , batches_per_epoch = batches_per_epoch )
53
+ num_batches = 2
54
+ data = simulator .sample ((num_batches * batch_size ,))
55
+ return OfflineDataset (data , workers = 4 , max_queue_size = num_batches , batch_size = batch_size )
0 commit comments