|
| 1 | +import keras |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +from bayesflow.types import Tensor |
| 5 | + |
| 6 | + |
| 7 | +def jacobian_trace(f: callable, x: Tensor, samples: int = 1) -> (Tensor, Tensor): |
| 8 | + """ |
| 9 | + Returns an unbiased estimate of the trace of the Jacobian of f, using Hutchinson's estimator. |
| 10 | +
|
| 11 | + :param f: The function to be differentiated. |
| 12 | + Must take x as its only argument and return a single output Tensor. |
| 13 | +
|
| 14 | + :param x: Tensor of shape (n, d) |
| 15 | + The input tensor to f. |
| 16 | +
|
| 17 | + :param samples: The number of random samples to use for the estimate. |
| 18 | + If this exceeds the dimensionality of f(x) or you pass None, we |
| 19 | + will instead perform an exact computation which takes that many samples. |
| 20 | + Default: 1 |
| 21 | +
|
| 22 | + :return: Tensor of shape (n,) |
| 23 | + An unbiased estimate of the trace of the Jacobian of f. |
| 24 | + """ |
| 25 | + |
| 26 | + batch_size, dims = keras.ops.shape(x) |
| 27 | + |
| 28 | + match keras.backend.backend(): |
| 29 | + case "jax": |
| 30 | + import jax |
| 31 | + |
| 32 | + fx, vjp_fn = jax.vjp(f, x) |
| 33 | + vjp_fn = jax.jit(vjp_fn) |
| 34 | + |
| 35 | + trace = keras.ops.zeros((batch_size,), dtype=x.dtype) |
| 36 | + |
| 37 | + # TODO: can we use jax.vmap to avoid the for loop? |
| 38 | + |
| 39 | + if samples is None or dims <= samples: |
| 40 | + # exact |
| 41 | + for dim in range(dims): |
| 42 | + projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype) |
| 43 | + projector = projector.at[:, dim].set(1.0) |
| 44 | + |
| 45 | + vjp = vjp_fn(projector)[0] |
| 46 | + |
| 47 | + trace += vjp[:, dim] |
| 48 | + else: |
| 49 | + # estimate |
| 50 | + for sample in range(samples): |
| 51 | + projector = keras.random.normal((batch_size, dims), dtype=x.dtype) |
| 52 | + |
| 53 | + vjp = vjp_fn(projector)[0] |
| 54 | + |
| 55 | + trace += keras.ops.sum(vjp * projector, axis=1) |
| 56 | + |
| 57 | + case "tensorflow": |
| 58 | + import tensorflow as tf |
| 59 | + |
| 60 | + with tf.GradientTape(persistent=True) as tape: |
| 61 | + tape.watch(x) |
| 62 | + fx = f(x) |
| 63 | + |
| 64 | + trace = keras.ops.zeros((batch_size,)) |
| 65 | + |
| 66 | + # TODO: can we use tf.gradients to avoid the for loop? |
| 67 | + |
| 68 | + if samples is None or dims <= samples: |
| 69 | + # exact |
| 70 | + for dim in range(dims): |
| 71 | + projector = np.zeros((batch_size, dims), dtype=keras.backend.standardize_dtype(x.dtype)) |
| 72 | + projector[:, dim] = 1.0 |
| 73 | + projector = keras.ops.convert_to_tensor(projector) |
| 74 | + |
| 75 | + vjp = tape.gradient(fx, x, projector) |
| 76 | + |
| 77 | + trace += vjp[:, dim] |
| 78 | + else: |
| 79 | + # estimate |
| 80 | + for _ in range(samples): |
| 81 | + projector = keras.random.normal((batch_size, dims), dtype=x.dtype) |
| 82 | + |
| 83 | + vjp = tape.gradient(fx, x, projector) |
| 84 | + |
| 85 | + trace += keras.ops.sum(vjp * projector, axis=1) / samples |
| 86 | + case "torch": |
| 87 | + import torch |
| 88 | + |
| 89 | + with torch.enable_grad(): |
| 90 | + x.requires_grad = True |
| 91 | + fx = f(x) |
| 92 | + |
| 93 | + trace = keras.ops.zeros(keras.ops.shape(x)[0]) |
| 94 | + |
| 95 | + # TODO: can we use is_grads_batched to avoid the for loop? |
| 96 | + |
| 97 | + if samples is None or dims <= samples: |
| 98 | + # exact |
| 99 | + for dim in range(dims): |
| 100 | + projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype) |
| 101 | + projector[:, dim] = 1.0 |
| 102 | + |
| 103 | + vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0] |
| 104 | + |
| 105 | + trace += vjp[:, dim] |
| 106 | + else: |
| 107 | + # estimate |
| 108 | + for _ in range(samples): |
| 109 | + projector = keras.random.normal((batch_size, dims), dtype=x.dtype) |
| 110 | + |
| 111 | + vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0] |
| 112 | + |
| 113 | + trace += keras.ops.sum(vjp * projector, axis=1) / samples |
| 114 | + case other: |
| 115 | + raise NotImplementedError(f"Jacobian trace computation is currently not supported for backend '{other}'.") |
| 116 | + |
| 117 | + return fx, trace |
0 commit comments