-
-from typing import Tuple, Union
-
 import keras
 from keras.saving import (
     register_keras_serializable,
 )
-from scipy.integrate import solve_ivp
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs
+from bayesflow.utils import find_network, jacobian_trace, keras_kwargs, optimal_transport
 
 from ..inference_network import InferenceNetwork
 
 
 @register_keras_serializable(package="bayesflow.networks")
 class FlowMatching(InferenceNetwork):
-    def __init__(self, network: str = "resnet", **kwargs):
-        super().__init__(**keras_kwargs(kwargs))
-        self.network = find_network(network, **kwargs)
+    def __init__(self, subnet: str = "resnet", base_distribution: str = "normal", **kwargs):
+        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
+        self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+
+        output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
+        output_projector_kwargs.setdefault("bias_initializer", "zeros")
+        self.output_projector = keras.layers.Dense(None, **output_projector_kwargs)
+
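+    # the projector's width is only known once the data dimension is seen, so
+    # its units are set here and a dummy pass through call() creates the weights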
+    def build(self, xz_shape, conditions_shape=None):
+        super().build(xz_shape)
+
+        self.output_projector.units = xz_shape[-1]
+
+        xz = keras.ops.zeros(xz_shape)
+        if conditions_shape is None:
+            conditions = None
+        else:
+            conditions = keras.ops.zeros(conditions_shape)
+
+        self.call(xz, conditions=conditions, steps=1)
+
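+    # single entry point: the forward pass maps data to latents, the inverse
+    # pass maps latents back to data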
+    def call(
+        self,
+        xz: Tensor,
+        conditions: Tensor = None,
+        inverse: bool = False,
+        **kwargs,
+    ):
+        if inverse:
+            return self._inverse(xz, conditions=conditions, **kwargs)
+        return self._forward(xz, conditions=conditions, **kwargs)
+
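+    # accept scalar, per-batch, or already-expanded time and broadcast it to
+    # shape (batch_size, 1) for concatenation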
+    def velocity(self, x: Tensor, t: int | float | Tensor, conditions: Tensor = None) -> Tensor:
+        t = keras.ops.convert_to_tensor(t, dtype=x.dtype)
+        match keras.ops.ndim(t):
+            case 0:
+                t = keras.ops.full((keras.ops.shape(x)[0], 1), t, dtype=x.dtype)
+            case 1:
+                t = keras.ops.expand_dims(t, 1)
 
-    def velocity(self, x: Tensor, t: Tensor, conditions: any = None):
         if conditions is None:
-            xtc = keras.ops.concatenate([x, t], axis=1)
+            xtc = keras.ops.concatenate([x, t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([x, t, conditions], axis=1)
+            xtc = keras.ops.concatenate([x, t, conditions], axis=-1)
 
-        return self.network(xtc)
+        return self.output_projector(self.subnet(xtc))
 
-    def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-        def dfdt(t: float, x: Tensor):
-            t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-            return self.velocity(x, t, conditions)
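+    # fixed-step Euler integration replaces scipy's solve_ivp, keeping the
+    # ODE solve backend-native and differentiable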
+    def _forward(
+        self, x: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        steps = kwargs.get("steps", 100)
+        z = keras.ops.copy(x)
+        t = keras.ops.ones((keras.ops.shape(x)[0], 1), dtype=x.dtype)
+        dt = -1.0 / steps
 
-        return solve_ivp(dfdt, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True)[1]
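+        # with density=True, also accumulate the Jacobian trace along the path
+        # (Hutchinson-style estimate; cf. the removed hutchinson_trace draft below)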
+        if density:
+            trace = keras.ops.zeros(keras.ops.shape(x)[0], dtype=x.dtype)
+
+            def f(arg):
+                return self.velocity(arg, t, conditions)
+
+            for _ in range(steps):
+                v, tr = jacobian_trace(f, z, kwargs.get("trace_samples", 100))
+                z += dt * v
+                trace += dt * tr
+                t += dt  # advance integration time alongside the state
+
+            log_prob = self.base_distribution.log_prob(z)
+
+            log_density = log_prob + trace
+
+            return z, log_density
+        else:
+            for _ in range(steps):
+                v = self.velocity(z, t, conditions)
+                z += dt * v
+                t += dt
 
-    def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-        def dfdt(t: float, x: Tensor):
-            t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-            return self.velocity(x, t, conditions)
+        return z
 
-        return solve_ivp(dfdt, t_span=(0.0, 1.0), y0=z, method=method, vectorized=True)[1]
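+    # mirror image of _forward: integrate latents to data with t from 0 to 1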
+    def _inverse(
+        self, z: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        steps = kwargs.get("steps", 100)
+        x = keras.ops.copy(z)
+        t = keras.ops.zeros((keras.ops.shape(x)[0], 1), dtype=x.dtype)
+        dt = 1.0 / steps
+
+        if density:
+            trace = keras.ops.zeros(keras.ops.shape(x)[0], dtype=x.dtype)
+
+            def f(arg):
+                return self.velocity(arg, t, conditions)
+
+            for _ in range(steps):
+                v, tr = jacobian_trace(f, x, kwargs.get("trace_samples", 100))
+                x += dt * v
+                trace += dt * tr
+                t += dt  # advance integration time alongside the state
+
+            log_prob = self.base_distribution.log_prob(z)
+
+            log_density = log_prob - trace
+
+            return x, log_density
+        else:
+            for _ in range(steps):
+                v = self.velocity(x, t, conditions)
+                x += dt * v
+                t += dt
+
+        return x
+
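+    # flow matching objective: regress the predicted velocity at a random time t
+    # onto the straight-line target velocity x1 - x0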
+    def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
+        x1 = data["inference_variables"]
+        c = data.get("inference_conditions")
+
+        x0 = self.base_distribution.sample(keras.ops.shape(x1))
+
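+        # pair base draws with data via (mini-batch) optimal transport so the
+        # interpolation paths cross less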
+        x0, x1 = optimal_transport(x0, x1)
 
-    def compute_loss(self, x=None, **kwargs):
-        x0, x1, *conditions = x
         t = keras.random.uniform((keras.ops.shape(x0)[0], 1))
 
         x = t * x1 + (1 - t) * x0
-        xtc = keras.ops.concatenate([x, t, *conditions], axis=-1)
 
-        predicted_velocity = self.network(xtc)
+        predicted_velocity = self.velocity(x, t, c)
         target_velocity = x1 - x0
 
-        return keras.losses.mean_squared_error(predicted_velocity, target_velocity)
-
-
-# @register_keras_serializable(package="bayesflow.networks")
-# class FlowMatching(InferenceNetwork):
-#     def __init__(self, network: keras.Layer, **kwargs):
-#         super().__init__(**kwargs)
-#         self.network = network
-#
-#     @classmethod
-#     def new(cls, network: str = "resnet", base_distribution: str = "normal"):
-#         # TODO: we probably want to provide a factory method like this, since the other networks use it
-#         # for high-level input parameters
-#         # network = find_network(network)
-#         return cls(network, base_distribution=base_distribution)
-#
-#     @classmethod
-#     def from_config(cls, config: dict, custom_objects=None) -> "FlowMatching":
-#         # TODO: the base distribution must be savable and loadable
-#         # ideally we also don't want to have to manually deserialize it in every subclass of InferenceNetwork
-#         base_distribution = deserialize_keras_object(config.pop("base_distribution"))
-#         network = deserialize_keras_object(config.pop("network"))
-#         return cls(network, base_distribution=base_distribution, **config)
-#
-#     def get_config(self) -> dict:
-#         base_config = super().get_config()
-#         config = {"network": serialize_keras_object(self.network)}
-#         return base_config | config
-#
-#     def build(self, input_shape):
-#         self.network.build(input_shape)
-#
-#     def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-#         # implement conditions = None and jacobian = False first
-#         # then work your way up
-#         raise NotImplementedError
-#
-#     def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
-#         raise NotImplementedError
-#
-#     def compute_loss(self, x=None, **kwargs):
-#         # x should ideally contain both x0 and x1,
-#         # where the optimal transport matching already happened in the worker process
-#         # this is possible, but might not be super user-friendly. We will have to see.
-#         x0, x1, t = x
-#
-#         xt = t * x1 + (1 - t) * x0
-#
-#         # get velocity at xt
-#         v = ...
-#
-#         # target velocity:
-#         vstar = x1 - x0
-#
-#         # return mse between v and vstar
-#
-#
-# # TODO: see below for reference implementation
-#
-#
-# class FlowMatching(keras.Model):
-#     def __init__(self, network: keras.Layer, base_distribution):
-#         super().__init__()
-#         self.network = network
-#         self.base_distribution = find_distribution(base_distribution)
-#
-#     def call(self, inferred_variables, inference_conditions):
-#         return self.network(keras.ops.concatenate([inferred_variables, inference_conditions], axis=1))
-#
-#     def compute_loss(self, x=None, y=None, y_pred=None, **kwargs):
-#         return keras.losses.mean_squared_error(y, y_pred)
-#
-#     def velocity(self, x: Tensor, t: Tensor, c: Tensor = None):
-#         if c is None:
-#             xtc = keras.ops.concatenate([x, t], axis=1)
-#         else:
-#             xtc = keras.ops.concatenate([x, t, c], axis=1)
-#
-#         return self.network(xtc)
-#
-#     def forward(self, x, c=None, method="RK45") -> Tensor:
-#         def f(t, x):
-#             t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-#             return self.velocity(x, t, c)
-#
-#         bunch = solve_ivp(f, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True)
-#
-#         return bunch[1]
-#
-#     def inverse(self, x, c=None, method="RK45") -> Tensor:
-#         def f(t, x):
-#             t = keras.ops.full((keras.ops.shape(x)[0], 1), t)
-#             return self.velocity(x, t, c)
-#
-#         bunch = solve_ivp(f, t_span=(0.0, 1.0), y0=x, method=method, vectorized=True)
-#
-#         return bunch[1]
-#
-#     def sample(self, batch_shape: Shape) -> Tensor:
-#         z = self.base_distribution.sample(batch_shape)
-#         return self.inverse(z)
-#
-#     def log_prob(self, x: Tensor, c: Tensor = None) -> Tensor:
-#         raise NotImplementedError(f"Keras does not yet support backend-agnostic Vector-Jacobian Products.")
-#
-#
-# def hutchinson_trace(f: callable, x: Tensor) -> (Tensor, Tensor):
-#     # TODO: test this for all 3 backends
-#     noise = keras.random.normal(keras.ops.shape(x))
-#
-#     match keras.backend.backend():
-#         case "jax":
-#             import jax
-#             fx, jvp = jax.jvp(f, (x,), (noise,))
-#         case "tensorflow":
-#             import tensorflow as tf
-#             with tf.GradientTape(persistent=True) as tape:
-#                 tape.watch(x)
-#                 fx = f(x)
-#             jvp = tape.gradient(fx, x, output_gradients=noise)
-#         case "torch":
-#             import torch
-#             fx, jvp = torch.autograd.functional.jvp(f, x, noise, create_graph=True)
-#         case other:
-#             raise NotImplementedError(f"Backend {other} is not supported for trace estimation.")
-#
-#     trace = keras.ops.sum(jvp * noise, axis=1)
-#
-#     return fx, trace
+        loss = keras.losses.mean_squared_error(predicted_velocity, target_velocity)
+
+        return {"loss": loss}
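
For orientation, a minimal sketch of how the reworked interface might be exercised end to end. The import path, shapes, and step counts below are illustrative assumptions, not part of this commit:

import keras
from bayesflow.networks import FlowMatching  # assumed import path

net = FlowMatching(subnet="resnet", base_distribution="normal")
net.build(xz_shape=(16, 2), conditions_shape=(16, 3))  # hypothetical toy shapes

conditions = keras.random.normal((16, 3))

# sampling: draw latents from the base distribution, then integrate t: 0 -> 1
z = net.base_distribution.sample((16, 2))
x = net(z, conditions=conditions, inverse=True, steps=100)

# density: integrate t: 1 -> 0 while accumulating the Jacobian trace
z, log_density = net(x, conditions=conditions, density=True, steps=100)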