Improved Integrators (#300)
* remove old integrators

* change logging level of debug messages

* add utils/integrate.py

* update usage of integration in flow matching

* fix #288

* set logging level to debug for tests

* add seed parameter to jacobian trace for stochastic evaluation under compiled backends

* fix shape of trace in flow matching

* fix integration for negative step size (mostly)

* add user-defined loss functions, since some non-MSE losses empirically perform better in certain settings

* fix negative step size integration
fix small deviations in some backends
add (required) min and max step number selection for dynamic integration
improve dispatch
remove step size selection (users can compute this if they need it, but exposing the argument is ambiguous w.r.t. fixed vs adaptive integration)

* speed up build test

* fix density computation
add default integrate kwargs

* reduce default number of max steps for dynamic step size integration

* allow specifying steps = "dynamic" instead of just "adaptive"

* add integrate kwargs to serialization

* add todo for density test

* improve time broadcasting

* fix tensorflow incompatible types

---------

Co-authored-by: stefanradev93 <[email protected]>
LarsKue and stefanradev93 authored Feb 7, 2025
1 parent a12e7b3 commit 9071be4
Showing 12 changed files with 396 additions and 311 deletions.
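The commit's central change: the per-scheme integrator classes (Euler, RK2, RK4) are replaced by a single bayesflow.utils.integrate entry point. Below is a minimal sketch of how it is called, inferred from the defaults and call sites in this diff; the toy ODE and the tensor shapes are illustrative assumptions, not part of the commit.

import keras

from bayesflow.utils import integrate

# Toy vector field dx/dt = -x. The callable receives the current time and
# the entries of the state dict, and returns a delta for every state key.
def deltas(t, xz):
    return {"xz": -xz}

state = {"xz": keras.ops.ones((4, 2))}

# Adaptive step-size integration with the (required) step-count bounds;
# per the commit message, steps="dynamic" is an accepted alias.
state = integrate(
    deltas,
    state,
    start_time=0.0,
    stop_time=1.0,
    method="rk45",
    steps="adaptive",
    tolerance=1e-3,
    min_steps=10,
    max_steps=100,
)

x_final = state["xz"]  # integrated state at stop_time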
7 changes: 3 additions & 4 deletions bayesflow/__init__.py
@@ -10,12 +10,11 @@
    workflows,
    utils,
)

-from .workflows import BasicWorkflow
-from .approximators import ContinuousApproximator
from .adapters import Adapter
+from .approximators import ContinuousApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
+from .workflows import BasicWorkflow


def setup():
@@ -38,7 +37,7 @@ def setup():

    from bayesflow.utils import logging

logging.info(f"Using backend {keras.backend.backend()!r}")
logging.debug(f"Using backend {keras.backend.backend()!r}")


# call and clean up namespace
149 changes: 111 additions & 38 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -1,19 +1,20 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import (
+    expand_right_as,
+    find_network,
+    integrate,
+    jacobian_trace,
    keras_kwargs,
    optimal_transport,
    serialize_value_or_type,
    deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
-from .integrators import EulerIntegrator
-from .integrators import RK2Integrator
-from .integrators import RK4Integrator


@serializable(package="bayesflow.networks")
@@ -30,47 +31,71 @@ def __init__(
        self,
        subnet: str | type = "mlp",
        base_distribution: str = "normal",
-        integrator: str = "euler",
        use_optimal_transport: bool = False,
+        loss_fn: str = "mse",
+        integrate_kwargs: dict[str, any] = None,
        optimal_transport_kwargs: dict[str, any] = None,
        **kwargs,
    ):
        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))

        self.use_optimal_transport = use_optimal_transport
-        self.optimal_transport_kwargs = optimal_transport_kwargs or {
-            "method": "sinkhorn",
-            "cost": "euclidean",
-            "regularization": 0.1,
-            "max_steps": 1000,
-            "tolerance": 1e-4,
-        }

+        if integrate_kwargs is None:
+            integrate_kwargs = {
+                "method": "rk45",
+                "steps": "adaptive",
+                "tolerance": 1e-3,
+                "min_steps": 10,
+                "max_steps": 100,
+            }
+
+        self.integrate_kwargs = integrate_kwargs
+
+        if optimal_transport_kwargs is None:
+            optimal_transport_kwargs = {
+                "method": "sinkhorn",
+                "cost": "euclidean",
+                "regularization": 0.1,
+                "max_steps": 100,
+                "tolerance": 1e-4,
+            }
+
+        self.loss_fn = keras.losses.get(loss_fn)
+
+        self.optimal_transport_kwargs = optimal_transport_kwargs

        self.seed_generator = keras.random.SeedGenerator()

-        match integrator:
-            case "euler":
-                self.integrator = EulerIntegrator(subnet, **kwargs)
-            case "rk2":
-                self.integrator = RK2Integrator(subnet, **kwargs)
-            case "rk4":
-                self.integrator = RK4Integrator(subnet, **kwargs)
-            case _:
-                raise NotImplementedError(f"No support for {integrator} integration")
+        self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+        self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

        # serialization: store all parameters necessary to call __init__
        self.config = {
            "base_distribution": base_distribution,
-            "integrator": integrator,
            "use_optimal_transport": use_optimal_transport,
            "optimal_transport_kwargs": optimal_transport_kwargs,
+            "integrate_kwargs": integrate_kwargs,
            **kwargs,
        }
        self.config = serialize_value_or_type(self.config, "subnet", subnet)

    def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
-        super().build(xz_shape)
-        self.integrator.build(xz_shape, conditions_shape)
+        super().build(xz_shape, conditions_shape=conditions_shape)
+
+        self.output_projector.units = xz_shape[-1]
+        input_shape = list(xz_shape)
+
+        # construct time vector
+        input_shape[-1] += 1
+        if conditions_shape is not None:
+            input_shape[-1] += conditions_shape[-1]
+
+        input_shape = tuple(input_shape)
+
+        self.subnet.build(input_shape)
+        out_shape = self.subnet.compute_output_shape(input_shape)
+        self.output_projector.build(out_shape)

    def get_config(self):
        base_config = super().get_config()
@@ -81,32 +106,80 @@ def from_config(cls, config):
        config = deserialize_value_or_type(config, "subnet")
        return cls(**config)

+    def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
+        t = keras.ops.convert_to_tensor(t)
+        t = expand_right_as(t, xz)
+        t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
+
+        if conditions is None:
+            xtc = keras.ops.concatenate([xz, t], axis=-1)
+        else:
+            xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)
+
+        return self.output_projector(self.subnet(xtc, training=training), training=training)
+
+    def _velocity_trace(
+        self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
+    ) -> (Tensor, Tensor):
+        def f(x):
+            return self.velocity(x, t, conditions=conditions, training=training)
+
+        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
+
+        return v, keras.ops.expand_dims(trace, axis=-1)
+
    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
-        steps = kwargs.get("steps", 100)

        if density:
-            z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True)
-            log_prob = self.base_distribution.log_prob(z)
-            log_density = log_prob + trace
+
+            def deltas(t, xz):
+                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+                return {"xz": v, "trace": trace}
+
+            state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
+            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+
+            z = state["xz"]
+            log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

            return z, log_density

-        z = self.integrator(x, conditions=conditions, steps=steps, density=False)
+        def deltas(t, xz):
+            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+
+        state = {"xz": x}
+        state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+
+        z = state["xz"]

        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
-        steps = kwargs.get("steps", 100)

        if density:
-            x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True)
-            log_prob = self.base_distribution.log_prob(z)
-            log_density = log_prob - trace
+
+            def deltas(t, xz):
+                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+                return {"xz": v, "trace": trace}
+
+            state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
+            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+
+            x = state["xz"]
+            log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

            return x, log_density

-        x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True)
+        def deltas(t, xz):
+            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+
+        state = {"xz": z}
+        state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+
+        x = state["xz"]

        return x

    def compute_metrics(
@@ -118,7 +191,7 @@ def compute_metrics(
        else:
            # not pre-configured, resample
            x1 = x
-            x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
+            x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)

        if self.use_optimal_transport:
            x1, x0, conditions = optimal_transport(
@@ -133,9 +206,9 @@

        base_metrics = super().compute_metrics(x1, conditions, stage)

-        predicted_velocity = self.integrator.velocity(x, t, conditions)
+        predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

-        loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
+        loss = self.loss_fn(target_velocity, predicted_velocity)
        loss = keras.ops.mean(loss)

        return base_metrics | {"loss": loss}
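Taken together, FlowMatching now exposes both the integration scheme and the training loss at construction time. A hypothetical construction call under the new signature, assuming FlowMatching is exported from bayesflow.networks; "log_cosh" is merely one example of a keras.losses.get()-compatible non-MSE identifier:

from bayesflow.networks import FlowMatching

fm = FlowMatching(
    subnet="mlp",
    loss_fn="log_cosh",  # any keras loss identifier; the default is "mse"
    integrate_kwargs={
        "method": "rk45",
        "steps": "adaptive",  # "dynamic" is an accepted alias per the commit message
        "tolerance": 1e-3,
        "min_steps": 10,
        "max_steps": 100,
    },
)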
3 changes: 0 additions & 3 deletions bayesflow/networks/flow_matching/integrators/__init__.py

This file was deleted.

79 changes: 0 additions & 79 deletions bayesflow/networks/flow_matching/integrators/euler.py

This file was deleted.

11 changes: 0 additions & 11 deletions bayesflow/networks/flow_matching/integrators/integrator.py

This file was deleted.
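With the integrator classes deleted, the divergence term needed for density computation is estimated directly in flow_matching.py via bayesflow.utils.jacobian_trace, which now takes a seed so that stochastic trace evaluation stays reproducible under compiled backends. A sketch mirroring the _velocity_trace call above; the toy vector field and shapes are illustrative assumptions.

import keras

from bayesflow.utils import jacobian_trace

def f(x):
    # Toy vector field with a known Jacobian: the trace is 2 * dim per sample.
    return 2.0 * x

x = keras.random.normal((8, 3))
seed = keras.random.SeedGenerator(42)

# return_output=True also returns f(x), so the caller does not have to
# evaluate the (potentially expensive) vector field twice.
fx, trace = jacobian_trace(f, x, seed=seed, return_output=True)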

