1
1
from collections.abc import Sequence
from typing import Any

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import (
    expand_right_as,
    find_network,
    integrate,
    jacobian_trace,
    keras_kwargs,
    optimal_transport,
    serialize_value_or_type,
    deserialize_value_or_type,
)

from ..inference_network import InferenceNetwork
17
18
18
19
19
20
@serializable (package = "bayesflow.networks" )
@@ -30,47 +31,71 @@ def __init__(
30
31
self ,
31
32
subnet : str | type = "mlp" ,
32
33
base_distribution : str = "normal" ,
33
- integrator : str = "euler" ,
34
34
use_optimal_transport : bool = False ,
35
+ loss_fn : str = "mse" ,
36
+ integrate_kwargs : dict [str , any ] = None ,
35
37
optimal_transport_kwargs : dict [str , any ] = None ,
36
38
** kwargs ,
37
39
):
38
40
super ().__init__ (base_distribution = base_distribution , ** keras_kwargs (kwargs ))
39
41
40
42
self .use_optimal_transport = use_optimal_transport
41
- self .optimal_transport_kwargs = optimal_transport_kwargs or {
42
- "method" : "sinkhorn" ,
43
- "cost" : "euclidean" ,
44
- "regularization" : 0.1 ,
45
- "max_steps" : 1000 ,
46
- "tolerance" : 1e-4 ,
47
- }
43
+
44
+ if integrate_kwargs is None :
45
+ integrate_kwargs = {
46
+ "method" : "rk45" ,
47
+ "steps" : "adaptive" ,
48
+ "tolerance" : 1e-3 ,
49
+ "min_steps" : 10 ,
50
+ "max_steps" : 100 ,
51
+ }
52
+
53
+ self .integrate_kwargs = integrate_kwargs
54
+
55
+ if optimal_transport_kwargs is None :
56
+ optimal_transport_kwargs = {
57
+ "method" : "sinkhorn" ,
58
+ "cost" : "euclidean" ,
59
+ "regularization" : 0.1 ,
60
+ "max_steps" : 100 ,
61
+ "tolerance" : 1e-4 ,
62
+ }
63
+
64
+ self .loss_fn = keras .losses .get (loss_fn )
65
+
66
+ self .optimal_transport_kwargs = optimal_transport_kwargs
48
67
49
68
self .seed_generator = keras .random .SeedGenerator ()
50
69
51
- match integrator :
52
- case "euler" :
53
- self .integrator = EulerIntegrator (subnet , ** kwargs )
54
- case "rk2" :
55
- self .integrator = RK2Integrator (subnet , ** kwargs )
56
- case "rk4" :
57
- self .integrator = RK4Integrator (subnet , ** kwargs )
58
- case _:
59
- raise NotImplementedError (f"No support for { integrator } integration" )
70
+ self .subnet = find_network (subnet , ** kwargs .get ("subnet_kwargs" , {}))
71
+ self .output_projector = keras .layers .Dense (units = None , bias_initializer = "zeros" )
60
72
61
73
# serialization: store all parameters necessary to call __init__
62
74
self .config = {
63
75
"base_distribution" : base_distribution ,
64
- "integrator" : integrator ,
65
76
"use_optimal_transport" : use_optimal_transport ,
66
77
"optimal_transport_kwargs" : optimal_transport_kwargs ,
78
+ "integrate_kwargs" : integrate_kwargs ,
67
79
** kwargs ,
68
80
}
69
81
self .config = serialize_value_or_type (self .config , "subnet" , subnet )
70
82
71
83
def build (self , xz_shape : Shape , conditions_shape : Shape = None ) -> None :
72
- super ().build (xz_shape )
73
- self .integrator .build (xz_shape , conditions_shape )
84
+ super ().build (xz_shape , conditions_shape = conditions_shape )
85
+
86
+ self .output_projector .units = xz_shape [- 1 ]
87
+ input_shape = list (xz_shape )
88
+
89
+ # construct time vector
90
+ input_shape [- 1 ] += 1
91
+ if conditions_shape is not None :
92
+ input_shape [- 1 ] += conditions_shape [- 1 ]
93
+
94
+ input_shape = tuple (input_shape )
95
+
96
+ self .subnet .build (input_shape )
97
+ out_shape = self .subnet .compute_output_shape (input_shape )
98
+ self .output_projector .build (out_shape )
74
99
75
100
def get_config (self ):
76
101
base_config = super ().get_config ()
@@ -81,32 +106,80 @@ def from_config(cls, config):
81
106
config = deserialize_value_or_type (config , "subnet" )
82
107
return cls (** config )
83
108
109
+ def velocity (self , xz : Tensor , t : float | Tensor , conditions : Tensor = None , training : bool = False ) -> Tensor :
110
+ t = keras .ops .convert_to_tensor (t )
111
+ t = expand_right_as (t , xz )
112
+ t = keras .ops .broadcast_to (t , keras .ops .shape (xz )[:- 1 ] + (1 ,))
113
+
114
+ if conditions is None :
115
+ xtc = keras .ops .concatenate ([xz , t ], axis = - 1 )
116
+ else :
117
+ xtc = keras .ops .concatenate ([xz , t , conditions ], axis = - 1 )
118
+
119
+ return self .output_projector (self .subnet (xtc , training = training ), training = training )
120
+
121
+ def _velocity_trace (
122
+ self , xz : Tensor , t : Tensor , conditions : Tensor = None , max_steps : int = None , training : bool = False
123
+ ) -> (Tensor , Tensor ):
124
+ def f (x ):
125
+ return self .velocity (x , t , conditions = conditions , training = training )
126
+
127
+ v , trace = jacobian_trace (f , xz , max_steps = max_steps , seed = self .seed_generator , return_output = True )
128
+
129
+ return v , keras .ops .expand_dims (trace , axis = - 1 )
130
+
84
131
def _forward (
85
132
self , x : Tensor , conditions : Tensor = None , density : bool = False , training : bool = False , ** kwargs
86
133
) -> Tensor | tuple [Tensor , Tensor ]:
87
- steps = kwargs .get ("steps" , 100 )
88
-
89
134
if density :
90
- z , trace = self .integrator (x , conditions = conditions , steps = steps , density = True )
91
- log_prob = self .base_distribution .log_prob (z )
92
- log_density = log_prob + trace
135
+
136
+ def deltas (t , xz ):
137
+ v , trace = self ._velocity_trace (xz , t , conditions = conditions , training = training )
138
+ return {"xz" : v , "trace" : trace }
139
+
140
+ state = {"xz" : x , "trace" : keras .ops .zeros (keras .ops .shape (x )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (x ))}
141
+ state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** (self .integrate_kwargs | kwargs ))
142
+
143
+ z = state ["xz" ]
144
+ log_density = self .base_distribution .log_prob (z ) + keras .ops .squeeze (state ["trace" ], axis = - 1 )
145
+
93
146
return z , log_density
94
147
95
- z = self .integrator (x , conditions = conditions , steps = steps , density = False )
148
+ def deltas (t , xz ):
149
+ return {"xz" : self .velocity (xz , t , conditions = conditions , training = training )}
150
+
151
+ state = {"xz" : x }
152
+ state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** (self .integrate_kwargs | kwargs ))
153
+
154
+ z = state ["xz" ]
155
+
96
156
return z
97
157
98
158
def _inverse (
99
159
self , z : Tensor , conditions : Tensor = None , density : bool = False , training : bool = False , ** kwargs
100
160
) -> Tensor | tuple [Tensor , Tensor ]:
101
- steps = kwargs .get ("steps" , 100 )
102
-
103
161
if density :
104
- x , trace = self .integrator (z , conditions = conditions , steps = steps , density = True , inverse = True )
105
- log_prob = self .base_distribution .log_prob (z )
106
- log_density = log_prob - trace
162
+
163
+ def deltas (t , xz ):
164
+ v , trace = self ._velocity_trace (xz , t , conditions = conditions , training = training )
165
+ return {"xz" : v , "trace" : trace }
166
+
167
+ state = {"xz" : z , "trace" : keras .ops .zeros (keras .ops .shape (z )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (z ))}
168
+ state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** (self .integrate_kwargs | kwargs ))
169
+
170
+ x = state ["xz" ]
171
+ log_density = self .base_distribution .log_prob (z ) - keras .ops .squeeze (state ["trace" ], axis = - 1 )
172
+
107
173
return x , log_density
108
174
109
- x = self .integrator (z , conditions = conditions , steps = steps , density = False , inverse = True )
175
+ def deltas (t , xz ):
176
+ return {"xz" : self .velocity (xz , t , conditions = conditions , training = training )}
177
+
178
+ state = {"xz" : z }
179
+ state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** (self .integrate_kwargs | kwargs ))
180
+
181
+ x = state ["xz" ]
182
+
110
183
return x
111
184
112
185
def compute_metrics (
@@ -118,7 +191,7 @@ def compute_metrics(
118
191
else :
119
192
# not pre-configured, resample
120
193
x1 = x
121
- x0 = keras . random . normal (keras .ops .shape ( x1 ), dtype = keras . ops . dtype (x1 ), seed = self .seed_generator )
194
+ x0 = self . base_distribution . sample (keras .ops .shape (x1 ), seed = self .seed_generator )
122
195
123
196
if self .use_optimal_transport :
124
197
x1 , x0 , conditions = optimal_transport (
@@ -133,9 +206,9 @@ def compute_metrics(
133
206
134
207
base_metrics = super ().compute_metrics (x1 , conditions , stage )
135
208
136
- predicted_velocity = self .integrator . velocity (x , t , conditions )
209
+ predicted_velocity = self .velocity (x , t , conditions , training = stage == "training" )
137
210
138
- loss = keras . losses . mean_squared_error (target_velocity , predicted_velocity )
211
+ loss = self . loss_fn (target_velocity , predicted_velocity )
139
212
loss = keras .ops .mean (loss )
140
213
141
214
return base_metrics | {"loss" : loss }
0 commit comments