From 9071be4e79c6374e1353a4dd62b70e646f7b48cb Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 7 Feb 2025 18:32:31 +0100 Subject: [PATCH] Improved Integrators (#300) * remove old integrators * change logging level of debug infos * add utils/integrate.py * update usage of integration in flow matching * fix #288 * set logging level to debug for tests * add seed parameter to jacobian trace for stochastic evaluation under compiled backends * fix shape of trace in flow matching * fix integration for negative step size (mostly) * Add user-defined loss functions due to empirically better performance of some non-MSE losses sometimes * fix negative step size integration fix small deviations in some backends add (required) min and max step number selection for dynamic integration improve dispatch remove step size selection (users can compute this if they need it, but exposing the argument is ambiguous w.r.t. fixed vs adaptive integration) * speed up build test * fix density computation add default integrate kwargs * reduce default number of max steps for dynamic step size integration * allow specifying steps = "dynamic" instead of just "adaptive" * add integrate kwargs to serialization * add todo for density test * improve time broadcasting * fix tensorflow incompatible types --------- Co-authored-by: stefanradev93 --- bayesflow/__init__.py | 7 +- .../networks/flow_matching/flow_matching.py | 149 ++++++++--- .../flow_matching/integrators/__init__.py | 3 - .../flow_matching/integrators/euler.py | 79 ------ .../flow_matching/integrators/integrator.py | 11 - .../flow_matching/integrators/runge_kutta.py | 82 ------ .../integrators/runge_kutta_4.py | 86 ------ bayesflow/utils/__init__.py | 3 + bayesflow/utils/integrate.py | 251 ++++++++++++++++++ bayesflow/utils/jacobian/jacobian_trace.py | 21 +- tests/conftest.py | 5 +- .../test_networks/test_inference_networks.py | 10 +- 12 files changed, 396 insertions(+), 311 deletions(-) delete mode 100644 bayesflow/networks/flow_matching/integrators/__init__.py delete mode 100644 bayesflow/networks/flow_matching/integrators/euler.py delete mode 100644 bayesflow/networks/flow_matching/integrators/integrator.py delete mode 100644 bayesflow/networks/flow_matching/integrators/runge_kutta.py delete mode 100644 bayesflow/networks/flow_matching/integrators/runge_kutta_4.py create mode 100644 bayesflow/utils/integrate.py diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py index 5375e232e..027d36191 100644 --- a/bayesflow/__init__.py +++ b/bayesflow/__init__.py @@ -10,12 +10,11 @@ workflows, utils, ) - -from .workflows import BasicWorkflow -from .approximators import ContinuousApproximator from .adapters import Adapter +from .approximators import ContinuousApproximator from .datasets import OfflineDataset, OnlineDataset, DiskDataset from .simulators import make_simulator +from .workflows import BasicWorkflow def setup(): @@ -38,7 +37,7 @@ def setup(): from bayesflow.utils import logging - logging.info(f"Using backend {keras.backend.backend()!r}") + logging.debug(f"Using backend {keras.backend.backend()!r}") # call and clean up namespace diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 118d3546f..19bd214d0 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -1,19 +1,20 @@ from collections.abc import Sequence + import keras from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import ( expand_right_as, + find_network, + integrate, + jacobian_trace, keras_kwargs, optimal_transport, serialize_value_or_type, deserialize_value_or_type, ) from ..inference_network import InferenceNetwork -from .integrators import EulerIntegrator -from .integrators import RK2Integrator -from .integrators import RK4Integrator @serializable(package="bayesflow.networks") @@ -30,47 +31,71 @@ def __init__( self, subnet: str | type = "mlp", base_distribution: str = "normal", - integrator: str = "euler", use_optimal_transport: bool = False, + loss_fn: str = "mse", + integrate_kwargs: dict[str, any] = None, optimal_transport_kwargs: dict[str, any] = None, **kwargs, ): super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs)) self.use_optimal_transport = use_optimal_transport - self.optimal_transport_kwargs = optimal_transport_kwargs or { - "method": "sinkhorn", - "cost": "euclidean", - "regularization": 0.1, - "max_steps": 1000, - "tolerance": 1e-4, - } + + if integrate_kwargs is None: + integrate_kwargs = { + "method": "rk45", + "steps": "adaptive", + "tolerance": 1e-3, + "min_steps": 10, + "max_steps": 100, + } + + self.integrate_kwargs = integrate_kwargs + + if optimal_transport_kwargs is None: + optimal_transport_kwargs = { + "method": "sinkhorn", + "cost": "euclidean", + "regularization": 0.1, + "max_steps": 100, + "tolerance": 1e-4, + } + + self.loss_fn = keras.losses.get(loss_fn) + + self.optimal_transport_kwargs = optimal_transport_kwargs self.seed_generator = keras.random.SeedGenerator() - match integrator: - case "euler": - self.integrator = EulerIntegrator(subnet, **kwargs) - case "rk2": - self.integrator = RK2Integrator(subnet, **kwargs) - case "rk4": - self.integrator = RK4Integrator(subnet, **kwargs) - case _: - raise NotImplementedError(f"No support for {integrator} integration") + self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) + self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros") # serialization: store all parameters necessary to call __init__ self.config = { "base_distribution": base_distribution, - "integrator": integrator, "use_optimal_transport": use_optimal_transport, "optimal_transport_kwargs": optimal_transport_kwargs, + "integrate_kwargs": integrate_kwargs, **kwargs, } self.config = serialize_value_or_type(self.config, "subnet", subnet) def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: - super().build(xz_shape) - self.integrator.build(xz_shape, conditions_shape) + super().build(xz_shape, conditions_shape=conditions_shape) + + self.output_projector.units = xz_shape[-1] + input_shape = list(xz_shape) + + # construct time vector + input_shape[-1] += 1 + if conditions_shape is not None: + input_shape[-1] += conditions_shape[-1] + + input_shape = tuple(input_shape) + + self.subnet.build(input_shape) + out_shape = self.subnet.compute_output_shape(input_shape) + self.output_projector.build(out_shape) def get_config(self): base_config = super().get_config() @@ -81,32 +106,80 @@ def from_config(cls, config): config = deserialize_value_or_type(config, "subnet") return cls(**config) + def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor: + t = keras.ops.convert_to_tensor(t) + t = expand_right_as(t, xz) + t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,)) + + if conditions is None: + xtc = keras.ops.concatenate([xz, t], axis=-1) + else: + xtc = keras.ops.concatenate([xz, t, conditions], axis=-1) + + return self.output_projector(self.subnet(xtc, training=training), training=training) + + def _velocity_trace( + self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False + ) -> (Tensor, Tensor): + def f(x): + return self.velocity(x, t, conditions=conditions, training=training) + + v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True) + + return v, keras.ops.expand_dims(trace, axis=-1) + def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: - steps = kwargs.get("steps", 100) - if density: - z, trace = self.integrator(x, conditions=conditions, steps=steps, density=True) - log_prob = self.base_distribution.log_prob(z) - log_density = log_prob + trace + + def deltas(t, xz): + v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training) + return {"xz": v, "trace": trace} + + state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))} + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + + z = state["xz"] + log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1) + return z, log_density - z = self.integrator(x, conditions=conditions, steps=steps, density=False) + def deltas(t, xz): + return {"xz": self.velocity(xz, t, conditions=conditions, training=training)} + + state = {"xz": x} + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + + z = state["xz"] + return z def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: - steps = kwargs.get("steps", 100) - if density: - x, trace = self.integrator(z, conditions=conditions, steps=steps, density=True, inverse=True) - log_prob = self.base_distribution.log_prob(z) - log_density = log_prob - trace + + def deltas(t, xz): + v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training) + return {"xz": v, "trace": trace} + + state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))} + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + + x = state["xz"] + log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1) + return x, log_density - x = self.integrator(z, conditions=conditions, steps=steps, density=False, inverse=True) + def deltas(t, xz): + return {"xz": self.velocity(xz, t, conditions=conditions, training=training)} + + state = {"xz": z} + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + + x = state["xz"] + return x def compute_metrics( @@ -118,7 +191,7 @@ def compute_metrics( else: # not pre-configured, resample x1 = x - x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator) + x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator) if self.use_optimal_transport: x1, x0, conditions = optimal_transport( @@ -133,9 +206,9 @@ def compute_metrics( base_metrics = super().compute_metrics(x1, conditions, stage) - predicted_velocity = self.integrator.velocity(x, t, conditions) + predicted_velocity = self.velocity(x, t, conditions, training=stage == "training") - loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity) + loss = self.loss_fn(target_velocity, predicted_velocity) loss = keras.ops.mean(loss) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/flow_matching/integrators/__init__.py b/bayesflow/networks/flow_matching/integrators/__init__.py deleted file mode 100644 index 1686a89bc..000000000 --- a/bayesflow/networks/flow_matching/integrators/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .euler import EulerIntegrator -from .runge_kutta import RK2Integrator -from .runge_kutta_4 import RK4Integrator diff --git a/bayesflow/networks/flow_matching/integrators/euler.py b/bayesflow/networks/flow_matching/integrators/euler.py deleted file mode 100644 index 424d725ab..000000000 --- a/bayesflow/networks/flow_matching/integrators/euler.py +++ /dev/null @@ -1,79 +0,0 @@ -import keras -from bayesflow.types import Tensor, Shape -from bayesflow.utils import find_network, jacobian_trace, keras_kwargs, expand_right_as, tile_axis -from .integrator import Integrator - - -class EulerIntegrator(Integrator): - """ - TODO: docstring - """ - - def __init__(self, subnet: str | type = "mlp", **kwargs): - super().__init__(**keras_kwargs(kwargs)) - self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) - self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") - - def build(self, xz_shape: Shape, conditions_shape: Shape = None): - self.output_projector.units = xz_shape[-1] - input_shape = list(xz_shape) - - # construct time vector - input_shape[-1] += 1 - if conditions_shape is not None: - input_shape[-1] += conditions_shape[-1] - - input_shape = tuple(input_shape) - - self.subnet.build(input_shape) - out_shape = self.subnet.compute_output_shape(input_shape) - self.output_projector.build(out_shape) - - def velocity(self, x: Tensor, t: int | float | Tensor, conditions: Tensor = None, **kwargs): - if not keras.ops.is_tensor(t): - t = keras.ops.convert_to_tensor(t, dtype=x.dtype) - if keras.ops.ndim(t) == 0: - t = keras.ops.full((keras.ops.shape(x)[0],), t, dtype=keras.ops.dtype(x)) - - t = expand_right_as(t, x) - if keras.ops.ndim(x) == 3: - t = tile_axis(t, n=keras.ops.shape(x)[1], axis=1) - - if conditions is None: - xtc = keras.ops.concatenate([x, t], axis=-1) - else: - xtc = keras.ops.concatenate([x, t, conditions], axis=-1) - - return self.output_projector(self.subnet(xtc, **kwargs)) - - def call( - self, - x: Tensor, - conditions: Tensor = None, - steps: int = 100, - density: bool = False, - inverse: bool = False, - **kwargs, - ): - z = keras.ops.copy(x) - t = 1.0 if not inverse else 0.0 - dt = -1.0 / steps if not inverse else 1.0 / steps - - def f(arg): - return self.velocity(arg, t, conditions, **kwargs) - - if density: - trace = keras.ops.zeros(keras.ops.shape(x)[:-1], dtype=x.dtype) - for _ in range(steps): - v, tr = jacobian_trace(f, z, max_steps=kwargs.get("trace_steps", 5), return_output=True) - z += dt * v - trace += dt * tr - t += dt - return z, trace - - for _ in range(steps): - v = self.velocity(z, t, conditions, **kwargs) - z += dt * v - t += dt - - return z diff --git a/bayesflow/networks/flow_matching/integrators/integrator.py b/bayesflow/networks/flow_matching/integrators/integrator.py deleted file mode 100644 index 791fda62f..000000000 --- a/bayesflow/networks/flow_matching/integrators/integrator.py +++ /dev/null @@ -1,11 +0,0 @@ -import keras -from bayesflow.types import Tensor -from bayesflow.utils import keras_kwargs - - -class Integrator(keras.Layer): - def __init__(self, **kwargs): - super().__init__(**keras_kwargs(kwargs)) - - def call(self, x: Tensor, steps: int, conditions: Tensor = None, dynamic: bool = False): - raise NotImplementedError diff --git a/bayesflow/networks/flow_matching/integrators/runge_kutta.py b/bayesflow/networks/flow_matching/integrators/runge_kutta.py deleted file mode 100644 index 1264e5d0c..000000000 --- a/bayesflow/networks/flow_matching/integrators/runge_kutta.py +++ /dev/null @@ -1,82 +0,0 @@ -import keras -from bayesflow.types import Tensor, Shape -from bayesflow.utils import find_network, jacobian_trace, keras_kwargs, expand_right_as, tile_axis -from .integrator import Integrator - - -class RK2Integrator(Integrator): - """ - TODO: docstring - """ - - def __init__(self, subnet: str | type = "mlp", **kwargs): - super().__init__(**keras_kwargs(kwargs)) - self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) - self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") - - def build(self, xz_shape: Shape, conditions_shape: Shape = None): - self.output_projector.units = xz_shape[-1] - input_shape = list(xz_shape) - - # construct time vector - input_shape[-1] += 1 - if conditions_shape is not None: - input_shape[-1] += conditions_shape[-1] - - input_shape = tuple(input_shape) - - self.subnet.build(input_shape) - out_shape = self.subnet.compute_output_shape(input_shape) - self.output_projector.build(out_shape) - - def velocity(self, x: Tensor, t: int | float | Tensor, conditions: Tensor = None, **kwargs): - if not keras.ops.is_tensor(t): - t = keras.ops.convert_to_tensor(t, dtype=x.dtype) - if keras.ops.ndim(t) == 0: - t = keras.ops.full((keras.ops.shape(x)[0],), t, dtype=keras.ops.dtype(x)) - - t = expand_right_as(t, x) - if keras.ops.ndim(x) == 3: - t = tile_axis(t, n=keras.ops.shape(x)[1], axis=1) - - if conditions is None: - xtc = keras.ops.concatenate([x, t], axis=-1) - else: - xtc = keras.ops.concatenate([x, t, conditions], axis=-1) - - return self.output_projector(self.subnet(xtc, **kwargs)) - - def call( - self, - x: Tensor, - conditions: Tensor = None, - steps: int = 100, - density: bool = False, - inverse: bool = False, - **kwargs, - ): - z = keras.ops.copy(x) - t = 1.0 if not inverse else 0.0 - dt = -1.0 / steps if not inverse else 1.0 / steps - - def f(arg): - k1 = self.velocity(arg, t, conditions, **kwargs) - k2 = self.velocity(arg + (dt / 2.0 * k1), t + (dt / 2.0), conditions, **kwargs) - return k2 - - if density: - trace = keras.ops.zeros(keras.ops.shape(x)[:-1], dtype=x.dtype) - for _ in range(steps): - k2, tr = jacobian_trace(f, z, max_steps=kwargs.get("trace_steps", 5), return_output=True) - z += dt * k2 - trace += dt * tr - t += dt - return z, trace - - for _ in range(steps): - k1 = self.velocity(z, t, conditions, **kwargs) - k2 = self.velocity(z + (dt / 2.0 * k1), t + (dt / 2.0), conditions, **kwargs) - z += dt * k2 - t += dt - - return z diff --git a/bayesflow/networks/flow_matching/integrators/runge_kutta_4.py b/bayesflow/networks/flow_matching/integrators/runge_kutta_4.py deleted file mode 100644 index 98ea9ee07..000000000 --- a/bayesflow/networks/flow_matching/integrators/runge_kutta_4.py +++ /dev/null @@ -1,86 +0,0 @@ -import keras -from bayesflow.types import Tensor, Shape -from bayesflow.utils import find_network, jacobian_trace, keras_kwargs, expand_right_as, tile_axis -from .integrator import Integrator - - -class RK4Integrator(Integrator): - """ - TODO: docstring - """ - - def __init__(self, subnet: str | type = "mlp", **kwargs): - super().__init__(**keras_kwargs(kwargs)) - self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) - self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros") - - def build(self, xz_shape: Shape, conditions_shape: Shape = None): - self.output_projector.units = xz_shape[-1] - input_shape = list(xz_shape) - - # construct time vector - input_shape[-1] += 1 - if conditions_shape is not None: - input_shape[-1] += conditions_shape[-1] - - input_shape = tuple(input_shape) - - self.subnet.build(input_shape) - out_shape = self.subnet.compute_output_shape(input_shape) - self.output_projector.build(out_shape) - - def velocity(self, x: Tensor, t: int | float | Tensor, conditions: Tensor = None, **kwargs): - if not keras.ops.is_tensor(t): - t = keras.ops.convert_to_tensor(t, dtype=x.dtype) - if keras.ops.ndim(t) == 0: - t = keras.ops.full((keras.ops.shape(x)[0],), t, dtype=keras.ops.dtype(x)) - - t = expand_right_as(t, x) - if keras.ops.ndim(x) == 3: - t = tile_axis(t, n=keras.ops.shape(x)[1], axis=1) - - if conditions is None: - xtc = keras.ops.concatenate([x, t], axis=-1) - else: - xtc = keras.ops.concatenate([x, t, conditions], axis=-1) - - return self.output_projector(self.subnet(xtc, **kwargs)) - - def call( - self, - x: Tensor, - conditions: Tensor = None, - steps: int = 100, - density: bool = False, - inverse: bool = False, - **kwargs, - ): - z = keras.ops.copy(x) - t = 1.0 if not inverse else 0.0 - dt = -1.0 / steps if not inverse else 1.0 / steps - - def f(arg): - k1 = self.velocity(arg, t, conditions, **kwargs) - k2 = self.velocity(arg + (dt / 2.0 * k1), t + (dt / 2.0), conditions, **kwargs) - k3 = self.velocity(arg + (dt / 2.0 * k2), t + (dt / 2.0), conditions, **kwargs) - k4 = self.velocity(arg + (dt * k3), t + dt, conditions, **kwargs) - return (k1 + (2 * k2) + (2 * k3) + k4) / 6.0 - - if density: - trace = keras.ops.zeros(keras.ops.shape(x)[:-1], dtype=x.dtype) - for _ in range(steps): - v4, tr = jacobian_trace(f, z, max_steps=kwargs.get("trace_steps", 5), return_output=True) - z += dt * v4 - trace += dt * tr - t += dt - return z, trace - - for _ in range(steps): - k1 = self.velocity(z, t, conditions, **kwargs) - k2 = self.velocity(z + (dt / 2.0 * k1), t + (dt / 2.0), conditions, **kwargs) - k3 = self.velocity(z + (dt / 2.0 * k2), t + (dt / 2.0), conditions, **kwargs) - k4 = self.velocity(z + (dt * k3), t + dt, conditions, **kwargs) - z += (dt / 6.0) * (k1 + (2 * k2) + (2 * k3) + k4) - t += dt - - return z diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index a4b084388..6e7057d61 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -22,6 +22,9 @@ repo_url, ) from .hparam_utils import find_batch_size, find_memory_budget +from .integrate import ( + integrate, +) from .io import ( pickle_load, format_bytes, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py new file mode 100644 index 000000000..ab652f401 --- /dev/null +++ b/bayesflow/utils/integrate.py @@ -0,0 +1,251 @@ +from collections.abc import Callable +from functools import partial + +import keras + +from bayesflow.types import Tensor +from bayesflow.utils import filter_kwargs +from . import logging + +ArrayLike = int | float | Tensor + + +def euler_step( + fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + step_size: ArrayLike, + tolerance: ArrayLike = 1e-6, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + use_adaptive_step_size: bool = False, +) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + k1 = fn(time, **filter_kwargs(state, fn)) + + if use_adaptive_step_size: + intermediate_state = state.copy() + for key, delta in k1.items(): + intermediate_state[key] = state[key] + step_size * delta + + k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + + # check all keys are equal + if set(k1.keys()) != set(k2.keys()): + raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") + + # compute next step size + intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1]) + new_step_size = step_size * tolerance / (intermediate_error + 1e-9) + + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + + # consolidate step size + new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) + else: + new_step_size = step_size + + # apply updates + new_state = state.copy() + for key in k1.keys(): + new_state[key] = state[key] + step_size * k1[key] + + new_time = time + step_size + + return new_state, new_time, new_step_size + + +def rk45_step( + fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + last_step_size: ArrayLike, + tolerance: ArrayLike = 1e-6, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + use_adaptive_step_size: bool = False, +) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + step_size = last_step_size + + k1 = fn(time, **filter_kwargs(state, fn)) + + intermediate_state = state.copy() + for key, delta in k1.items(): + intermediate_state[key] = state[key] + 0.5 * step_size * delta + + k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + + intermediate_state = state.copy() + for key, delta in k2.items(): + intermediate_state[key] = state[key] + 0.5 * step_size * delta + + k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + + intermediate_state = state.copy() + for key, delta in k3.items(): + intermediate_state[key] = state[key] + step_size * delta + + k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + + if use_adaptive_step_size: + intermediate_state = state.copy() + for key, delta in k4.items(): + intermediate_state[key] = state[key] + 0.5 * step_size * delta + + k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + + # check all keys are equal + if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]): + raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") + + # compute next step size + intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()]) + new_step_size = step_size * tolerance / (intermediate_error + 1e-9) + + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + + # consolidate step size + new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) + else: + new_step_size = step_size + + # apply updates + new_state = state.copy() + for key in k1.keys(): + new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key]) + + new_time = time + step_size + + return new_state, new_time, new_step_size + + +def integrate_fixed( + fn: Callable, + state: dict[str, ArrayLike], + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + if steps <= 0: + raise ValueError("Number of steps must be positive.") + + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + step_size = (stop_time - start_time) / steps + + time = start_time + + def body(_loop_var, _loop_state): + _state, _time = _loop_state + _state, _time, _ = step_fn(_state, _time, step_size) + + return _state, _time + + state, time = keras.ops.fori_loop(0, steps, body, (state, time)) + + return state + + +def integrate_adaptive( + fn: Callable, + state: dict[str, ArrayLike], + start_time: ArrayLike, + stop_time: ArrayLike, + min_steps: int = 10, + max_steps: int = 1000, + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + if max_steps <= min_steps: + raise ValueError("Maximum number of steps must be greater than minimum number of steps.") + + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) + + def cond(_state, _time, _step_size, _step): + # while step < min_steps or time_remaining > 0 and step < max_steps + + # time remaining after the next step + time_remaining = keras.ops.abs(stop_time - (_time + _step_size)) + + return keras.ops.logical_or( + keras.ops.all(_step < min_steps), + keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)), + ) + + def body(_state, _time, _step_size, _step): + _step = _step + 1 + + # time remaining after the next step + time_remaining = stop_time - (_time + _step_size) + + min_step_size = time_remaining / (max_steps - _step) + max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) + + # reorder + min_step_size, max_step_size = ( + keras.ops.minimum(min_step_size, max_step_size), + keras.ops.maximum(min_step_size, max_step_size), + ) + + _state, _time, _step_size = step_fn( + _state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size + ) + + return _state, _time, _step_size, _step + + # select initial step size conservatively + step_size = (stop_time - start_time) / max_steps + + step = 0 + time = start_time + + state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step]) + + # do the last step + step_size = stop_time - time + state, _, _ = step_fn(state, time, step_size) + step = step + 1 + + logging.debug("Finished integration after {} steps.", step) + + return state + + +def integrate( + fn: Callable, + state: dict[str, ArrayLike], + start_time: ArrayLike, + stop_time: ArrayLike, + min_steps: int = 10, + max_steps: int = 10_000, + steps: int = "adaptive", + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + match steps: + case "adaptive" | "dynamic": + return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) + case int(): + return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) + case _: + raise RuntimeError("Type or value of `steps` not understood.") diff --git a/bayesflow/utils/jacobian/jacobian_trace.py b/bayesflow/utils/jacobian/jacobian_trace.py index bafffb016..a81448263 100644 --- a/bayesflow/utils/jacobian/jacobian_trace.py +++ b/bayesflow/utils/jacobian/jacobian_trace.py @@ -1,13 +1,19 @@ from collections.abc import Callable + import keras from bayesflow.types import Tensor - from .jacobian import jacobian from .vjp import vjp -def jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, max_steps: int = None, return_output: bool = False): +def jacobian_trace( + f: Callable[[Tensor], Tensor], + x: Tensor, + max_steps: int = None, + return_output: bool = False, + seed: int | keras.random.SeedGenerator = None, +): """Compute or estimate the trace of the Jacobian matrix of f. :param f: The function to be differentiated. @@ -25,6 +31,9 @@ def jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, max_steps: int = No Whether to return the output of f(x) along with the trace of the Jacobian. Default: False + :param seed: int or keras SeedGenerator + The seed to use for hutchinson trace estimation. Only has an effect when max_steps < d. + :return: 2-tuple of tensors: 1. The output of f(x) (if return_output is True) 2. Tensor of shape (n,) @@ -36,7 +45,7 @@ def jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, max_steps: int = No fx, jac = jacobian(f, x, return_output=True) trace = keras.ops.trace(jac, axis1=-2, axis2=-1) else: - fx, trace = _hutchinson(f, x, steps=max_steps, return_output=True) + fx, trace = _hutchinson(f, x, steps=max_steps, return_output=True, seed=seed) if return_output: return fx, trace @@ -44,7 +53,9 @@ def jacobian_trace(f: Callable[[Tensor], Tensor], x: Tensor, max_steps: int = No return trace -def _hutchinson(f: callable, x: Tensor, steps: int = 1, return_output: bool = False): +def _hutchinson( + f: callable, x: Tensor, steps: int = 1, return_output: bool = False, seed: int | keras.random.SeedGenerator = None +): """Estimate the trace of the Jacobian matrix of f using Hutchinson's algorithm. :param f: The function to be differentiated. @@ -67,7 +78,7 @@ def _hutchinson(f: callable, x: Tensor, steps: int = 1, return_output: bool = Fa fx, vjp_fn = vjp(f, x, return_output=True) for _ in range(steps): - projector = keras.random.normal(shape) + projector = keras.random.normal(shape, seed=seed) trace += keras.ops.sum(vjp_fn(projector) * projector, axis=-1) if return_output: diff --git a/tests/conftest.py b/tests/conftest.py index d7933545e..54a067d5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ +import logging + import keras import pytest - BACKENDS = ["jax", "numpy", "tensorflow", "torch"] +logging.getLogger("bayesflow").setLevel(logging.DEBUG) + def pytest_runtest_setup(item): """Skips backends by test markers. Unmarked tests are treated as backend-agnostic""" diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 33395cdc5..44d93e617 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -12,7 +12,10 @@ def test_build(inference_network, random_samples, random_conditions): assert inference_network.built is False - inference_network(random_samples, conditions=random_conditions) + samples_shape = keras.ops.shape(random_samples) + conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None + + inference_network.build(samples_shape, conditions_shape=conditions_shape) assert inference_network.built is True @@ -22,7 +25,9 @@ def test_build(inference_network, random_samples, random_conditions): def test_variable_batch_size(inference_network, random_samples, random_conditions): # build with one batch size - inference_network(random_samples, conditions=random_conditions) + samples_shape = keras.ops.shape(random_samples) + conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None + inference_network.build(samples_shape, conditions_shape=conditions_shape) # run with another batch size batch_sizes = np.random.choice(10, replace=False, size=3) @@ -78,6 +83,7 @@ def test_cycle_consistency(inference_network, random_samples, random_conditions) assert allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3) +# TODO: make this backend-agnostic @pytest.mark.torch def test_density_numerically(inference_network, random_samples, random_conditions): import torch