Skip to content

Commit 9071be4

Browse files
Improved Integrators (#300)
* remove old integrators * change logging level of debug infos * add utils/integrate.py * update usage of integration in flow matching * fix #288 * set logging level to debug for tests * add seed parameter to jacobian trace for stochastic evaluation under compiled backends * fix shape of trace in flow matching * fix integration for negative step size (mostly) * Add user-defined loss functions due to empirically better performance of some non-MSE losses sometimes * fix negative step size integration fix small deviations in some backends add (required) min and max step number selection for dynamic integration improve dispatch remove step size selection (users can compute this if they need it, but exposing the argument is ambiguous w.r.t. fixed vs adaptive integration) * speed up build test * fix density computation add default integrate kwargs * reduce default number of max steps for dynamic step size integration * allow specifying steps = "dynamic" instead of just "adaptive" * add integrate kwargs to serialization * add todo for density test * improve time broadcasting * fix tensorflow incompatible types --------- Co-authored-by: stefanradev93 <[email protected]>
1 parent a12e7b3 commit 9071be4

File tree

12 files changed

+396
-311
lines changed

12 files changed

+396
-311
lines changed

bayesflow/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
workflows,
1111
utils,
1212
)
13-
14-
from .workflows import BasicWorkflow
15-
from .approximators import ContinuousApproximator
1613
from .adapters import Adapter
14+
from .approximators import ContinuousApproximator
1715
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
1816
from .simulators import make_simulator
17+
from .workflows import BasicWorkflow
1918

2019

2120
def setup():
@@ -38,7 +37,7 @@ def setup():
3837

3938
from bayesflow.utils import logging
4039

41-
logging.info(f"Using backend {keras.backend.backend()!r}")
40+
logging.debug(f"Using backend {keras.backend.backend()!r}")
4241

4342

4443
# call and clean up namespace
Lines changed: 111 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from collections.abc import Sequence
2+
23
import keras
34
from keras.saving import register_keras_serializable as serializable
45

56
from bayesflow.types import Shape, Tensor
67
from bayesflow.utils import (
78
expand_right_as,
9+
find_network,
10+
integrate,
11+
jacobian_trace,
812
keras_kwargs,
913
optimal_transport,
1014
serialize_value_or_type,
1115
deserialize_value_or_type,
1216
)
1317
from ..inference_network import InferenceNetwork
14-
from .integrators import EulerIntegrator
15-
from .integrators import RK2Integrator
16-
from .integrators import RK4Integrator
1718

1819

1920
@serializable(package="bayesflow.networks")
@@ -30,47 +31,71 @@ def __init__(
3031
self,
3132
subnet: str | type = "mlp",
3233
base_distribution: str = "normal",
33-
integrator: str = "euler",
3434
use_optimal_transport: bool = False,
35+
loss_fn: str = "mse",
36+
integrate_kwargs: dict[str, any] = None,
3537
optimal_transport_kwargs: dict[str, any] = None,
3638
**kwargs,
3739
):
3840
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
3941

4042
self.use_optimal_transport = use_optimal_transport
41-
self.optimal_transport_kwargs = optimal_transport_kwargs or {
42-
"method": "sinkhorn",
43-
"cost": "euclidean",
44-
"regularization": 0.1,
45-
"max_steps": 1000,
46-
"tolerance": 1e-4,
47-
}
43+
44+
if integrate_kwargs is None:
45+
integrate_kwargs = {
46+
"method": "rk45",
47+
"steps": "adaptive",
48+
"tolerance": 1e-3,
49+
"min_steps": 10,
50+
"max_steps": 100,
51+
}
52+
53+
self.integrate_kwargs = integrate_kwargs
54+
55+
if optimal_transport_kwargs is None:
56+
optimal_transport_kwargs = {
57+
"method": "sinkhorn",
58+
"cost": "euclidean",
59+
"regularization": 0.1,
60+
"max_steps": 100,
61+
"tolerance": 1e-4,
62+
}
63+
64+
self.loss_fn = keras.losses.get(loss_fn)
65+
66+
self.optimal_transport_kwargs = optimal_transport_kwargs
4867

4968
self.seed_generator = keras.random.SeedGenerator()
5069

51-
match integrator:
52-
case "euler":
53-
self.integrator = EulerIntegrator(subnet, **kwargs)
54-
case "rk2":
55-
self.integrator = RK2Integrator(subnet, **kwargs)
56-
case "rk4":
57-
self.integrator = RK4Integrator(subnet, **kwargs)
58-
case _:
59-
raise NotImplementedError(f"No support for {integrator} integration")
70+
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
71+
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
6072

6173
# serialization: store all parameters necessary to call __init__
6274
self.config = {
6375
"base_distribution": base_distribution,
64-
"integrator": integrator,
6576
"use_optimal_transport": use_optimal_transport,
6677
"optimal_transport_kwargs": optimal_transport_kwargs,
78+
"integrate_kwargs": integrate_kwargs,
6779
**kwargs,
6880
}
6981
self.config = serialize_value_or_type(self.config, "subnet", subnet)
7082

7183
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
72-
super().build(xz_shape)
73-
self.integrator.build(xz_shape, conditions_shape)
84+
super().build(xz_shape, conditions_shape=conditions_shape)
85+
86+
self.output_projector.units = xz_shape[-1]
87+
input_shape = list(xz_shape)
88+
89+
# construct time vector
90+
input_shape[-1] += 1
91+
if conditions_shape is not None:
92+
input_shape[-1] += conditions_shape[-1]
93+
94+
input_shape = tuple(input_shape)
95+
96+
self.subnet.build(input_shape)
97+
out_shape = self.subnet.compute_output_shape(input_shape)
98+
self.output_projector.build(out_shape)
7499

75100
def get_config(self):
76101
base_config = super().get_config()
@@ -81,32 +106,80 @@ def from_config(cls, config):
81106
config = deserialize_value_or_type(config, "subnet")
82107
return cls(**config)
83108

109+
def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
110+
t = keras.ops.convert_to_tensor(t)
111+
t = expand_right_as(t, xz)
112+
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
113+
114+
if conditions is None:
115+
xtc = keras.ops.concatenate([xz, t], axis=-1)
116+
else:
117+
xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)
118+
119+
return self.output_projector(self.subnet(xtc, training=training), training=training)
120+
121+
def _velocity_trace(
122+
self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
123+
) -> (Tensor, Tensor):
124+
def f(x):
125+
return self.velocity(x, t, conditions=conditions, training=training)
126+
127+
v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
128+
129+
return v, keras.ops.expand_dims(trace, axis=-1)
130+
84131
def _forward(
85132
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
86133
) -> Tensor | tuple[Tensor, Tensor]:
87-
steps = kwargs.get("steps", 100)
88-
89134
if density:
90-
z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True)
91-
log_prob = self.base_distribution.log_prob(z)
92-
log_density = log_prob + trace
135+
136+
def deltas(t, xz):
137+
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
138+
return {"xz": v, "trace": trace}
139+
140+
state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
141+
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
142+
143+
z = state["xz"]
144+
log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)
145+
93146
return z, log_density
94147

95-
z = self.integrator(x, conditions=conditions, steps=steps, density=False)
148+
def deltas(t, xz):
149+
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
150+
151+
state = {"xz": x}
152+
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
153+
154+
z = state["xz"]
155+
96156
return z
97157

98158
def _inverse(
99159
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
100160
) -> Tensor | tuple[Tensor, Tensor]:
101-
steps = kwargs.get("steps", 100)
102-
103161
if density:
104-
x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True)
105-
log_prob = self.base_distribution.log_prob(z)
106-
log_density = log_prob - trace
162+
163+
def deltas(t, xz):
164+
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
165+
return {"xz": v, "trace": trace}
166+
167+
state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
168+
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
169+
170+
x = state["xz"]
171+
log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)
172+
107173
return x, log_density
108174

109-
x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True)
175+
def deltas(t, xz):
176+
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
177+
178+
state = {"xz": z}
179+
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
180+
181+
x = state["xz"]
182+
110183
return x
111184

112185
def compute_metrics(
@@ -118,7 +191,7 @@ def compute_metrics(
118191
else:
119192
# not pre-configured, resample
120193
x1 = x
121-
x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
194+
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
122195

123196
if self.use_optimal_transport:
124197
x1, x0, conditions = optimal_transport(
@@ -133,9 +206,9 @@ def compute_metrics(
133206

134207
base_metrics = super().compute_metrics(x1, conditions, stage)
135208

136-
predicted_velocity = self.integrator.velocity(x, t, conditions)
209+
predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")
137210

138-
loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
211+
loss = self.loss_fn(target_velocity, predicted_velocity)
139212
loss = keras.ops.mean(loss)
140213

141214
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/integrators/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

bayesflow/networks/flow_matching/integrators/euler.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

bayesflow/networks/flow_matching/integrators/integrator.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)