Improved Integrators #300

Merged · 19 commits · Feb 7, 2025

7 changes: 3 additions & 4 deletions bayesflow/__init__.py
@@ -10,12 +10,11 @@
workflows,
utils,
)

from .workflows import BasicWorkflow
from .approximators import ContinuousApproximator
from .adapters import Adapter
from .approximators import ContinuousApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow


def setup():
@@ -38,7 +37,7 @@ def setup():

from bayesflow.utils import logging

logging.info(f"Using backend {keras.backend.backend()!r}")
logging.debug(f"Using backend {keras.backend.backend()!r}")


# call and clean up namespace
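Note on the `__init__.py` change: the backend announcement moves from INFO to DEBUG, so it no longer shows up with default logging settings. A minimal sketch of how to surface it again, assuming bayesflow's `utils.logging` wrapper routes through Python's standard `logging` module under a logger named `"bayesflow"` (both are assumptions; adjust the logger name if the library wires it differently):

```python
import logging

# Assumption: bayesflow logs via the stdlib logging module under a
# logger named "bayesflow". Configure before importing the package,
# since setup() runs (and logs) at import time.
logging.basicConfig()
logging.getLogger("bayesflow").setLevel(logging.DEBUG)

import bayesflow  # the "Using backend ..." message should now appear
```
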
149 changes: 111 additions & 38 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -1,19 +1,20 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import (
expand_right_as,
find_network,
integrate,
jacobian_trace,
keras_kwargs,
optimal_transport,
serialize_value_or_type,
deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
from .integrators import EulerIntegrator
from .integrators import RK2Integrator
from .integrators import RK4Integrator


@serializable(package="bayesflow.networks")
@@ -30,47 +31,71 @@ def __init__(
self,
subnet: str | type = "mlp",
base_distribution: str = "normal",
integrator: str = "euler",
use_optimal_transport: bool = False,
loss_fn: str = "mse",
integrate_kwargs: dict[str, any] = None,
optimal_transport_kwargs: dict[str, any] = None,
**kwargs,
):
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))

self.use_optimal_transport = use_optimal_transport
self.optimal_transport_kwargs = optimal_transport_kwargs or {
"method": "sinkhorn",
"cost": "euclidean",
"regularization": 0.1,
"max_steps": 1000,
"tolerance": 1e-4,
}

if integrate_kwargs is None:
integrate_kwargs = {
"method": "rk45",
"steps": "adaptive",
"tolerance": 1e-3,
"min_steps": 10,
"max_steps": 100,
}

self.integrate_kwargs = integrate_kwargs

if optimal_transport_kwargs is None:
optimal_transport_kwargs = {
"method": "sinkhorn",
"cost": "euclidean",
"regularization": 0.1,
"max_steps": 100,
"tolerance": 1e-4,
}

self.loss_fn = keras.losses.get(loss_fn)

self.optimal_transport_kwargs = optimal_transport_kwargs

self.seed_generator = keras.random.SeedGenerator()

match integrator:
case "euler":
self.integrator = EulerIntegrator(subnet, **kwargs)
case "rk2":
self.integrator = RK2Integrator(subnet, **kwargs)
case "rk4":
self.integrator = RK4Integrator(subnet, **kwargs)
case _:
raise NotImplementedError(f"No support for {integrator} integration")
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

# serialization: store all parameters necessary to call __init__
self.config = {
"base_distribution": base_distribution,
"integrator": integrator,
"use_optimal_transport": use_optimal_transport,
"optimal_transport_kwargs": optimal_transport_kwargs,
"integrate_kwargs": integrate_kwargs,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
super().build(xz_shape)
self.integrator.build(xz_shape, conditions_shape)
super().build(xz_shape, conditions_shape=conditions_shape)

self.output_projector.units = xz_shape[-1]
input_shape = list(xz_shape)

# construct time vector
input_shape[-1] += 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]

input_shape = tuple(input_shape)

self.subnet.build(input_shape)
out_shape = self.subnet.compute_output_shape(input_shape)
self.output_projector.build(out_shape)

def get_config(self):
base_config = super().get_config()
@@ -81,32 +106,80 @@ def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
t = keras.ops.convert_to_tensor(t)
t = expand_right_as(t, xz)
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))

if conditions is None:
xtc = keras.ops.concatenate([xz, t], axis=-1)
else:
xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)

return self.output_projector(self.subnet(xtc, training=training), training=training)

def _velocity_trace(
self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
) -> (Tensor, Tensor):
def f(x):
return self.velocity(x, t, conditions=conditions, training=training)

v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

return v, keras.ops.expand_dims(trace, axis=-1)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
steps = kwargs.get("steps", 100)

if density:
z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True)
log_prob = self.base_distribution.log_prob(z)
log_density = log_prob + trace

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))

z = state["xz"]
log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

return z, log_density

z = self.integrator(x, conditions=conditions, steps=steps, density=False)
def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": x}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))

z = state["xz"]

return z

def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
steps = kwargs.get("steps", 100)

if density:
x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True)
log_prob = self.base_distribution.log_prob(z)
log_density = log_prob - trace

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))

x = state["xz"]
log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

return x, log_density

x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True)
def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": z}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))

x = state["xz"]

return x

def compute_metrics(
@@ -118,7 +191,7 @@ def compute_metrics(
else:
# not pre-configured, resample
x1 = x
x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)

if self.use_optimal_transport:
x1, x0, conditions = optimal_transport(
@@ -133,9 +206,9 @@

base_metrics = super().compute_metrics(x1, conditions, stage)

predicted_velocity = self.integrator.velocity(x, t, conditions)
predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
loss = self.loss_fn(target_velocity, predicted_velocity)
loss = keras.ops.mean(loss)

return base_metrics | {"loss": loss}
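
Net effect of this file's changes: the velocity network now lives directly on `FlowMatching` (built via `find_network` plus a `Dense` output projector), and both directions route through the generic `integrate` utility, which advances a dict-valued state using a user-supplied `deltas` function, exactly as in the `_forward`/`_inverse` bodies above. A minimal sketch of that calling convention on a toy ODE — the linear field stands in for the learned `velocity`, and the fixed-step `"euler"` method with integer `steps` is an assumption about `integrate`'s options alongside the `"rk45"`/`"adaptive"` defaults shown in the constructor:

```python
import keras
from bayesflow.utils import integrate

# Toy velocity field for dx/dt = -x; in FlowMatching, deltas wraps
# self.velocity(xz, t, conditions=...) as in the diff above.
def deltas(t, xz):
    return {"xz": -xz}

state = {"xz": keras.ops.ones((4, 2))}
state = integrate(
    deltas,
    state,
    start_time=0.0,
    stop_time=1.0,
    method="euler",  # assumed option; the new default is "rk45"
    steps=100,       # assumed fixed-step option; default is "adaptive"
)
print(state["xz"])  # every entry ≈ exp(-1) ≈ 0.368
```
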
3 changes: 0 additions & 3 deletions bayesflow/networks/flow_matching/integrators/__init__.py

This file was deleted.

79 changes: 0 additions & 79 deletions bayesflow/networks/flow_matching/integrators/euler.py

This file was deleted.

11 changes: 0 additions & 11 deletions bayesflow/networks/flow_matching/integrators/integrator.py

This file was deleted.

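For reference, the deleted `EulerIntegrator`, `RK2Integrator`, and `RK4Integrator` classes each hard-coded one stepping scheme around the network; under the new convention a fixed-step scheme is just a loop over `deltas`. A self-contained sketch of the Euler case, as an illustration of the pattern rather than the deleted file's code:

```python
def euler_integrate(deltas, state, start_time, stop_time, steps):
    """Fixed-step Euler over a dict-valued state, mirroring the
    deltas/state convention used in flow_matching.py above."""
    dt = (stop_time - start_time) / steps
    t = start_time
    for _ in range(steps):
        d = deltas(t, **state)  # one delta per state entry
        state = {k: v + dt * d[k] for k, v in state.items()}
        t += dt
    return state
```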