
Commit e797100 (1 parent: aca3d7a)

implement backend-agnostic jacobian trace for Flow Matching
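
For context (not part of the diff): flow matching obtains log-densities by integrating the instantaneous change-of-variables formula d log p(x(t))/dt = -tr(dv/dx), so every step of the ODE solve needs the trace of the velocity field's Jacobian, which is what this commit provides. A minimal sketch of that use, assuming a hypothetical `velocity` network and a plain Euler scheme rather than bayesflow's actual integrator:

```python
# Hypothetical sketch: tracking log-density while integrating dx/dt = velocity(x).
# `velocity` and the Euler scheme are illustrative assumptions, not this commit's API.
from bayesflow.utils import jacobian_trace


def euler_log_density_step(velocity, x, log_p, step_size):
    # d(log p)/dt = -tr(d velocity / dx), per the continuous change-of-variables formula
    v, trace = jacobian_trace(velocity, x, samples=1)
    return x + step_size * v, log_p - step_size * trace
```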

File tree: 3 files changed, +120 −0 lines changed

bayesflow/utils/__init__.py — 2 additions, 0 deletions

```diff
@@ -4,4 +4,6 @@
     keras_kwargs,
 )
 
+from .jacobian_trace import jacobian_trace
+
 from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
```
(second file; name not captured in this view) — 1 addition, 0 deletions

```diff
@@ -0,0 +1 @@
+from .jacobian_trace import jacobian_trace
```
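
Together with the addition to bayesflow/utils/__init__.py above, this one-line change simply re-exports the new function, so downstream code can import it directly:

```python
from bayesflow.utils import jacobian_trace
```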
(third file; presumably the new jacobian_trace module itself, given the imports above) — 117 additions, 0 deletions

```python
import keras
import numpy as np

from bayesflow.types import Tensor


def jacobian_trace(f: callable, x: Tensor, samples: int = 1) -> tuple[Tensor, Tensor]:
    """
    Returns an unbiased estimate of the trace of the Jacobian of f, using Hutchinson's estimator.

    :param f: The function to be differentiated.
        Must take x as its only argument and return a single output Tensor.

    :param x: Tensor of shape (n, d)
        The input tensor to f.

    :param samples: The number of random samples to use for the estimate.
        If this equals or exceeds the dimensionality of x, or if you pass None,
        an exact computation is performed instead, using one vector-Jacobian
        product per dimension.
        Default: 1

    :return: A tuple (fx, trace), where fx = f(x) and trace is a Tensor of
        shape (n,) holding an unbiased estimate of the trace of the Jacobian of f.
    """
    batch_size, dims = keras.ops.shape(x)

    match keras.backend.backend():
        case "jax":
            import jax

            fx, vjp_fn = jax.vjp(f, x)
            vjp_fn = jax.jit(vjp_fn)

            trace = keras.ops.zeros((batch_size,), dtype=x.dtype)

            # TODO: can we use jax.vmap to avoid the for loop?

            if samples is None or dims <= samples:
                # exact: one vector-Jacobian product per input dimension
                for dim in range(dims):
                    projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype)
                    projector = projector.at[:, dim].set(1.0)

                    vjp = vjp_fn(projector)[0]

                    trace += vjp[:, dim]
            else:
                # estimate: average v^T J v over random Gaussian directions v
                for _ in range(samples):
                    projector = keras.random.normal((batch_size, dims), dtype=x.dtype)

                    vjp = vjp_fn(projector)[0]

                    trace += keras.ops.sum(vjp * projector, axis=1) / samples
        case "tensorflow":
            import tensorflow as tf

            # persistent tape: tape.gradient is called once per projector below
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(x)
                fx = f(x)

            trace = keras.ops.zeros((batch_size,), dtype=x.dtype)

            # TODO: can we use tf.gradients to avoid the for loop?

            if samples is None or dims <= samples:
                # exact: one vector-Jacobian product per input dimension
                for dim in range(dims):
                    projector = np.zeros((batch_size, dims), dtype=keras.backend.standardize_dtype(x.dtype))
                    projector[:, dim] = 1.0
                    projector = keras.ops.convert_to_tensor(projector)

                    vjp = tape.gradient(fx, x, projector)

                    trace += vjp[:, dim]
            else:
                # estimate: average v^T J v over random Gaussian directions v
                for _ in range(samples):
                    projector = keras.random.normal((batch_size, dims), dtype=x.dtype)

                    vjp = tape.gradient(fx, x, projector)

                    trace += keras.ops.sum(vjp * projector, axis=1) / samples
        case "torch":
            import torch

            with torch.enable_grad():
                x.requires_grad = True
                fx = f(x)

            trace = keras.ops.zeros((batch_size,), dtype=x.dtype)

            # TODO: can we use is_grads_batched to avoid the for loop?

            if samples is None or dims <= samples:
                # exact: one vector-Jacobian product per input dimension
                for dim in range(dims):
                    projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype)
                    projector[:, dim] = 1.0

                    vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0]

                    trace += vjp[:, dim]
            else:
                # estimate: average v^T J v over random Gaussian directions v
                for _ in range(samples):
                    projector = keras.random.normal((batch_size, dims), dtype=x.dtype)

                    vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0]

                    trace += keras.ops.sum(vjp * projector, axis=1) / samples
        case other:
            raise NotImplementedError(f"Jacobian trace computation is currently not supported for backend '{other}'.")

    return fx, trace
```
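
Illustrative usage, not from the commit: a sanity check on the linear map f(z) = z @ A.T, whose Jacobian is A at every input, so the true trace equals tr(A) for each batch element. Since the function falls back to the exact path whenever samples >= dims, samples=1 is used here to actually exercise Hutchinson's estimator:

```python
import keras
import numpy as np

from bayesflow.utils import jacobian_trace

# Linear map f(z) = z @ A.T has constant Jacobian A, so tr(J) = tr(A) = 5.0
A = keras.ops.convert_to_tensor(np.array([[2.0, 1.0], [0.0, 3.0]], dtype="float32"))
x = keras.random.normal((4, 2))

f = lambda z: keras.ops.matmul(z, keras.ops.transpose(A))

fx, exact = jacobian_trace(f, x, samples=None)  # exact: one VJP per dimension
fx, noisy = jacobian_trace(f, x, samples=1)     # one Hutchinson sample per element

print(exact)  # ~[5. 5. 5. 5.]
print(noisy)  # scattered around 5.0; E[v^T A v] = tr(A) for standard normal v
```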

0 commit comments