I am working on a simple MNIST example. I found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from jax import random, jit
import jax.numpy as jnp
def FC(depth=1, num_classes=10, W_std=1.0, b_std=0.0):
  layers = [stax.Flatten()]
  for _ in range(depth):
    layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
  layers += [stax.Dense(num_classes, W_std, b_std)]
  return stax.serial(*layers)
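# Note (my own understanding, please correct me if wrong): kernel_fn from stax.serial is the
# infinite-width kernel, so the hidden width of 1 in stax.Dense(1, ...) should only affect the
# finite-width apply_fn, not the NTK computed below.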
x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', data_dir="./data", permute_train=True)
key = random.PRNGKey(0)
init_fn, apply_fn, kernel_fn = FC()
_, params = init_fn(key, (-1, 784))
apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))  # the 'get' string ('ntk'/'nngp') is a static argument
batched_kernel_fn = nt.batch(kernel_fn, 1000, store_on_device=False)  # compute in batches of 1000, keep results on the host
k_train_train = kernel_fn(x_train, None, 'ntk')   # full train x train NTK in one call
k_test_train = kernel_fn(x_test, x_train, 'ntk')  # full test x train NTK in one call
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
# No t passed, so these are the t -> infinity (fully trained) predictions.
fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
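For context, here is my rough back-of-envelope estimate of the kernel sizes alone (my own numbers, assuming the standard 60,000/10,000 MNIST split and the default float32), which already looks close to the limit of a single 24 GB card before counting any intermediates:

n_train, n_test = 60_000, 10_000  # standard MNIST split
bytes_per_float = 4               # assuming the default float32
print(n_train * n_train * bytes_per_float / 1e9)  # k_train_train: ~14.4 GB
print(n_test * n_train * bytes_per_float / 1e9)   # k_test_train:  ~2.4 GB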
I am running this on two RTX 3090s, each with 24 GB of memory.
Is there something I'm doing wrong, or is it normal for the NTK computation to consume this much memory?
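One thing I notice rereading my snippet: I construct batched_kernel_fn with nt.batch, but the two kernel computations above still call the unbatched kernel_fn directly. Is swapping in the batched version, roughly as sketched below, the intended usage? (This continues the snippet above; the batch size of 1000 and store_on_device=False are just guesses on my part.)

# Sketch, not tested at full scale: the function returned by nt.batch has the same
# signature as kernel_fn, so the two big kernel computations would become:
k_train_train = batched_kernel_fn(x_train, None, 'ntk')
k_test_train = batched_kernel_fn(x_test, x_train, 'ntk')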
Thank you!