
Commit 07e60dd

update Flow Matching (WIP)

Also remove flow matching from tests so that tests run for now. Change the semantics of inference networks to return a density instead of the log det of the Jacobian, which is more permissive.
1 parent e797100 commit 07e60dd
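
In practice, the new semantics mean callers receive a log-density directly instead of assembling it from the base distribution and the log det of the Jacobian. A minimal sketch of the calling convention, assuming CouplingFlow is exported from bayesflow.networks and its constructor defaults suffice (both assumptions, not shown in this commit):

    import keras
    from bayesflow.networks import CouplingFlow  # assumed export path

    flow = CouplingFlow()
    x = keras.random.normal((128, 2))  # dummy inference variables
    flow.build(keras.ops.shape(x))

    # new: the network returns the log-density directly
    z, log_density = flow(x, inverse=False, density=True)
    nll = -keras.ops.mean(log_density)

    # old, for comparison: assemble the density from the log det
    # z, log_det = flow(x, inverse=False, jacobian=True)
    # nll = -keras.ops.mean(flow.base_distribution.log_prob(z) + log_det)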

File tree

10 files changed: +168 -196 lines changed

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 17 additions & 12 deletions

@@ -50,13 +50,13 @@ def __init__(

         self.invertible_layers = []
         for i in range(depth):
-            if (p := find_permutation(permutation, **kwargs)) is not None:
+            if (p := find_permutation(permutation, **kwargs.get("permutation_kwargs", {}))) is not None:
                 self.invertible_layers.append(p)

-            self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs))
+            self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))

             if use_actnorm:
-                self.invertible_layers.append(ActNorm(**kwargs))
+                self.invertible_layers.append(ActNorm(**kwargs.get("actnorm_kwargs", {})))

     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
@@ -79,37 +79,42 @@ def call(
         return self._forward(xz, conditions=conditions, **kwargs)

     def _forward(
-        self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs
+        self, x: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
         z = x
         log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
         for layer in self.invertible_layers:
             z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
             log_det += det

-        if jacobian:
-            return z, log_det
+        if density:
+            log_prob = self.base_distribution.log_prob(z)
+            log_density = log_prob + log_det
+            return z, log_density
+
         return z

     def _inverse(
-        self, z: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs
+        self, z: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
         x = z
         log_det = keras.ops.zeros(keras.ops.shape(z)[:-1])
         for layer in reversed(self.invertible_layers):
             x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
             log_det += det

-        if jacobian:
-            return x, log_det
+        if density:
+            log_prob = self.base_distribution.log_prob(z)
+            log_density = log_prob - log_det
+            return x, log_density
+
         return x

     def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
         inference_variables = data["inference_variables"]
         inference_conditions = data.get("inference_conditions")

-        z, log_det = self(inference_variables, conditions=inference_conditions, inverse=False, jacobian=True)
-        log_prob = self.base_distribution.log_prob(z)
-        loss = -keras.ops.mean(log_prob + log_det)
+        z, log_density = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
+        loss = -keras.ops.mean(log_density)

         return {"loss": loss}
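
The constructor change above stops broadcasting one kwargs dict to every layer and instead routes options through nested dicts. A hedged sketch of the resulting call pattern (the nested dicts are left empty here because this commit does not show which options each component accepts):

    from bayesflow.networks import CouplingFlow  # assumed export path

    flow = CouplingFlow(
        permutation_kwargs={},  # forwarded only to find_permutation(...)
        coupling_kwargs={},     # forwarded only to DualCoupling(...)
        actnorm_kwargs={},      # forwarded only to ActNorm(...)
    )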
bayesflow/networks/flow_matching/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,2 +1 @@
-
 from .flow_matching import FlowMatching
bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 113 additions & 155 deletions

@@ -1,182 +1,140 @@
-
-from typing import Tuple, Union
-
 import keras
 from keras.saving import (
     register_keras_serializable,
 )
-from scipy.integrate import solve_ivp

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs
+from bayesflow.utils import find_network, jacobian_trace, keras_kwargs, optimal_transport

 from ..inference_network import InferenceNetwork


 @register_keras_serializable(package="bayesflow.networks")
 class FlowMatching(InferenceNetwork):
-    def __init__(self, network: str = "resnet", **kwargs):
-        super().__init__(**keras_kwargs(kwargs))
-        self.network = find_network(network, **kwargs)
+    def __init__(self, subnet: str = "resnet", base_distribution: str = "normal", **kwargs):
+        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
+        self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+
+        output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
+        output_projector_kwargs.setdefault("bias_initializer", "zeros")
+        self.output_projector = keras.layers.Dense(None, **output_projector_kwargs)
+
+    def build(self, xz_shape, conditions_shape=None):
+        super().build(xz_shape)
+
+        self.output_projector.units = xz_shape[-1]
+
+        xz = keras.ops.zeros(xz_shape)
+        if conditions_shape is None:
+            conditions = None
+        else:
+            conditions = keras.ops.zeros(conditions_shape)
+
+        self.call(xz, conditions=conditions, steps=1)
+
+    def call(
+        self,
+        xz: Tensor,
+        conditions: Tensor = None,
+        inverse: bool = False,
+        **kwargs,
+    ):
+        if inverse:
+            return self._inverse(xz, conditions=conditions, **kwargs)
+        return self._forward(xz, conditions=conditions, **kwargs)
+
+    def velocity(self, x: Tensor, t: int | float | Tensor, conditions: Tensor = None) -> Tensor:
+        t = keras.ops.convert_to_tensor(t, dtype=x.dtype)
+        match keras.ops.ndim(t):
+            case 0:
+                t = keras.ops.full((keras.ops.shape(x)[0], 1), t, dtype=x.dtype)
+            case 1:
+                t = keras.ops.expand_dims(t, 1)

-    def velocity(self, x: Tensor, t: Tensor, conditions: any = None):
         if conditions is None:
-            xtc = keras.ops.concatenate([x, t], axis=1)
+            xtc = keras.ops.concatenate([x, t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([x, t, conditions], axis=1)
+            xtc = keras.ops.concatenate([x, t, conditions], axis=-1)

-        return self.network(xtc)
+        return self.output_projector(self.subnet(xtc))

-    def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-        def dfdt(t: float, x: Tensor):
-            t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-            return self.velocity(x, t, conditions)
+    def _forward(
+        self, x: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        steps = kwargs.get("steps", 100)
+        z = keras.ops.copy(x)
+        t = keras.ops.ones((keras.ops.shape(x)[0], 1), dtype=x.dtype)
+        dt = -1.0 / steps

-        return solve_ivp(dfdt, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True)[1]
+        if density:
+            trace = keras.ops.zeros(keras.ops.shape(x)[0], dtype=x.dtype)
+
+            def f(arg):
+                return self.velocity(arg, t, conditions)
+
+            for _ in range(steps):
+                v, tr = jacobian_trace(f, z, kwargs.get("trace_samples", 100))
+                z += dt * v
+                trace += dt * tr
+
+            log_prob = self.base_distribution.log_prob(z)
+
+            log_density = log_prob + trace
+
+            return z, log_density
+        else:
+            for _ in range(steps):
+                v = self.velocity(z, t, conditions)
+                z += dt * v

-    def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-        def dfdt(t: float, x: Tensor):
-            t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-            return self.velocity(x, t, conditions)
+        return z

-        return solve_ivp(dfdt, t_span=(0.0, 1.0), y0=z, method=method, vectorized=True)[1]
+    def _inverse(
+        self, z: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        steps = kwargs.get("steps", 100)
+        x = keras.ops.copy(z)
+        t = keras.ops.zeros((keras.ops.shape(x)[0], 1), dtype=x.dtype)
+        dt = 1.0 / steps
+
+        if density:
+            trace = keras.ops.zeros(keras.ops.shape(x)[0], dtype=x.dtype)
+
+            def f(arg):
+                return self.velocity(arg, t, conditions)
+
+            for _ in range(steps):
+                v, tr = jacobian_trace(f, x, kwargs.get("trace_samples", 100))
+                x += dt * v
+                trace += dt * tr
+
+            log_prob = self.base_distribution.log_prob(z)
+
+            log_density = log_prob - trace
+
+            return x, log_density
+        else:
+            for _ in range(steps):
+                v = self.velocity(x, t, conditions)
+                x += dt * v
+
+        return x
+
+    def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
+        x1 = data["inference_variables"]
+        c = data.get("inference_conditions")
+
+        x0 = self.base_distribution.sample(keras.ops.shape(x1))
+
+        x0, x1 = optimal_transport(x0, x1)

-    def compute_loss(self, x=None, **kwargs):
-        x0, x1, *conditions = x
         t = keras.random.uniform((keras.ops.shape(x0)[0], 1))

         x = t * x1 + (1 - t) * x0
-        xtc = keras.ops.concatenate([x, t, *conditions], axis=-1)

-        predicted_velocity = self.network(xtc)
+        predicted_velocity = self.velocity(x, t, c)
         target_velocity = x1 - x0

-        return keras.losses.mean_squared_error(predicted_velocity, target_velocity)
-
-
-# @register_keras_serializable(package="bayesflow.networks")
-# class FlowMatching(InferenceNetwork):
-#     def __init__(self, network: keras.Layer, **kwargs):
-#         super().__init__(**kwargs)
-#         self.network = network
-#
-#     @classmethod
-#     def new(cls, network: str = "resnet", base_distribution: str = "normal"):
-#         # TODO: we probably want to provide a factory method like this, since the other networks use it
-#         # for high-level input parameters
-#         # network = find_network(network)
-#         return cls(network, base_distribution=base_distribution)
-#
-#     @classmethod
-#     def from_config(cls, config: dict, custom_objects=None) -> "FlowMatching":
-#         # TODO: the base distribution must be savable and loadable
-#         # ideally we also don't want to have to manually deserialize it in every subclass of InferenceNetwork
-#         base_distribution = deserialize_keras_object(config.pop("base_distribution"))
-#         network = deserialize_keras_object(config.pop("network"))
-#         return cls(network, base_distribution=base_distribution, **config)
-#
-#     def get_config(self) -> dict:
-#         base_config = super().get_config()
-#         config = {"network": serialize_keras_object(self.network)}
-#         return base_config | config
-#
-#     def build(self, input_shape):
-#         self.network.build(input_shape)
-#
-#     def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-#         # implement conditions = None and jacobian = False first
-#         # then work your way up
-#         raise NotImplementedError
-#
-#     def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-#         raise NotImplementedError
-#
-#     def compute_loss(self, x=None, **kwargs):
-#         # x should ideally contain both x0 and x1,
-#         # where the optimal transport matching already happened in the worker process
-#         # this is possible, but might not be super user-friendly. We will have to see.
-#         x0, x1, t = x
-#
-#         xt = t * x1 + (1 - t) * x0
-#
-#         # get velocity at xt
-#         v = ...
-#
-#         # target velocity:
-#         vstar = x1 - x0
-#
-#         # return mse between v and vstar
-#
-#
-#     # TODO: see below for reference implementation
-#
-#
-# class FlowMatching(keras.Model):
-#     def __init__(self, network: keras.Layer, base_distribution):
-#         super().__init__()
-#         self.network = network
-#         self.base_distribution = find_distribution(base_distribution)
-#
-#     def call(self, inferred_variables, inference_conditions):
-#         return self.network(keras.ops.concatenate([inferred_variables, inference_conditions], axis=1))
-#
-#     def compute_loss(self, x=None, y=None, y_pred=None, **kwargs):
-#         return keras.losses.mean_squared_error(y, y_pred)
-#
-#     def velocity(self, x: Tensor, t: Tensor, c: Tensor = None):
-#         if c is None:
-#             xtc = keras.ops.concatenate([x, t], axis=1)
-#         else:
-#             xtc = keras.ops.concatenate([x, t, c], axis=1)
-#
-#         return self.network(xtc)
-#
-#     def forward(self, x, c=None, method="RK45") -> Tensor:
-#         def f(t, x):
-#             t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-#             return self.velocity(x, t, c)
-#
-#         bunch = solve_ivp(f, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True)
-#
-#         return bunch[1]
-#
-#     def inverse(self, x, c=None, method="RK45") -> Tensor:
-#         def f(t, x):
-#             t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-#             return self.velocity(x, t, c)
-#
-#         bunch = solve_ivp(f, t_span=(0.0, 1.0), y0=x, method=method, vectorized=True)
-#
-#         return bunch[1]
-#
-#     def sample(self, batch_shape: Shape) -> Tensor:
-#         z = self.base_distribution.sample(batch_shape)
-#         return self.inverse(z)
-#
-#     def log_prob(self, x: Tensor, c: Tensor = None) -> Tensor:
-#         raise NotImplementedError(f"Keras does not yet support backend-agnostic Vector-Jacobian Products.")
-#
-#
-# def hutchinson_trace(f: callable, x: Tensor) -> (Tensor, Tensor):
-#     # TODO: test this for all 3 backends
-#     noise = keras.random.normal(keras.ops.shape(x))
-#
-#     match keras.backend.backend():
-#         case "jax":
-#             import jax
-#             fx, jvp = jax.jvp(f, (x,), (noise,))
-#         case "tensorflow":
-#             import tensorflow as tf
-#             with tf.GradientTape(persistent=True) as tape:
-#                 tape.watch(x)
-#                 fx = f(x)
-#             jvp = tape.gradient(fx, x, output_gradients=noise)
-#         case "torch":
-#             import torch
-#             fx, jvp = torch.autograd.functional.jvp(f, x, noise, create_graph=True)
-#         case other:
-#             raise NotImplementedError(f"Backend {other} is not supported for trace estimation.")
-#
-#     trace = keras.ops.sum(jvp * noise, axis=1)
-#
-#     return fx, trace
+        loss = keras.losses.mean_squared_error(predicted_velocity, target_velocity)
+
+        return {"loss": loss}
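
The density branch above is a fixed-step Euler discretization of the instantaneous change of variables for a flow defined by dx/dt = v(x, t): the log-density is the base log-probability of the integrated endpoint plus the accumulated Jacobian trace of v along the path. A self-contained sketch of the same scheme on a toy velocity field whose Jacobian trace is known in closed form, so no stochastic estimator is needed (illustrative only, independent of the bayesflow code):

    import numpy as np

    def velocity(x, t):
        # toy field v(x, t) = -x; its Jacobian is -I, so the trace is -dim exactly
        return -x

    def forward_with_density(x, steps=100):
        # integrate from t = 1 down to t = 0, accumulating the trace term
        z = x.copy()
        trace = np.zeros(x.shape[0])
        t, dt = 1.0, -1.0 / steps
        for _ in range(steps):
            z += dt * velocity(z, t)
            trace += dt * (-x.shape[-1])  # exact Jacobian trace of v
            t += dt
        # standard normal base distribution
        log_prob = -0.5 * np.sum(z**2, axis=-1) - 0.5 * x.shape[-1] * np.log(2.0 * np.pi)
        return z, log_prob + trace

    z, log_density = forward_with_density(np.random.randn(8, 2))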

bayesflow/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -7,3 +7,5 @@
 from .jacobian_trace import jacobian_trace

 from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
+
+from .optimal_transport import optimal_transport

bayesflow/utils/jacobian_trace/jacobian_trace.py

Lines changed: 5 additions & 2 deletions

@@ -22,7 +22,9 @@ def jacobian_trace(f: callable, x: Tensor, samples: int = 1) -> (Tensor, Tensor)
     :return: Tensor of shape (n,)
         An unbiased estimate of the trace of the Jacobian of f.
     """
-
+    # copy here to avoid causing outside side effects
+    # TODO: this may not be necessary for every backend
+    x = keras.ops.copy(x)
     batch_size, dims = keras.ops.shape(x)

     match keras.backend.backend():
@@ -86,8 +88,9 @@ def jacobian_trace(f: callable, x: Tensor, samples: int = 1) -> (Tensor, Tensor)
         case "torch":
             import torch

+            x.requires_grad_(True)
+
             with torch.enable_grad():
-                x.requires_grad = True
                 fx = f(x)

             trace = keras.ops.zeros(keras.ops.shape(x)[0])
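
For context, jacobian_trace implements a Hutchinson-style estimator: for random probes v with identity covariance, the expectation of v^T (J v) equals tr(J). A minimal torch-only sketch of the idea (the backend-agnostic utility above is the source of truth; this is just the estimator in isolation):

    import torch

    def hutchinson_trace(f, x, samples=1):
        # unbiased Monte Carlo estimate of tr(J_f(x)) via E[v^T (J v)]
        trace = torch.zeros(x.shape[0])
        fx = None
        for _ in range(samples):
            v = torch.randn_like(x)
            fx, jvp = torch.autograd.functional.jvp(f, x, v, create_graph=True)
            trace += (jvp * v).sum(dim=-1)
        return fx, trace / samples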
bayesflow/utils/optimal_transport/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .optimal_transport import optimal_transport
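
The optimal_transport implementation itself is not included in this diff. For intuition only, a hypothetical minibatch version based on an exact assignment solve (the function name matches, but the approach and code here are assumptions, not this commit's implementation):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def optimal_transport(x0, x1):
        # pair prior samples x0 with data samples x1 to minimize the total
        # squared distance; straighter x0 -> x1 paths give a less ambiguous
        # flow matching regression target
        cost = np.sum((x0[:, None, :] - x1[None, :, :]) ** 2, axis=-1)
        rows, cols = linear_sum_assignment(cost)
        return x0[rows], x1[cols]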
