This repository was archived by the owner on May 6, 2025. It is now read-only.

Questions about memory consumption of infinitely wide NTK #166

@jasonli0707

Description

I am working on a simple MNIST example, and I cannot compute the NTK on the entire dataset without running out of GPU memory. Below is the code snippet I used:

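```python
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets  # examples/datasets.py from the neural-tangents repo
from jax import random, jit
import jax.numpy as jnp

def FC(depth=1, num_classes=10, W_std=1.0, b_std=0.0):
    # Fully connected ReLU network. The hidden width (1 here) only affects the
    # finite network built by init_fn/apply_fn; kernel_fn is its infinite-width limit.
    layers = [stax.Flatten()]
    for _ in range(depth):
        layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
    layers += [stax.Dense(num_classes, W_std, b_std)]
    return stax.serial(*layers)

# Full MNIST split: 60,000 training and 10,000 test images.
x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', data_dir="./data", permute_train=True)

key = random.PRNGKey(0)
init_fn, apply_fn, kernel_fn = FC()
_, params = init_fn(key, (-1, 784))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))  # the `get` argument ('ntk') must be static

# Batched kernel that moves results to the host as they are computed.
# Note: it is defined here but not used below.
batched_kernel_fn = nt.batch(kernel_fn, 1000, store_on_device=False)

# Full kernels: k_train_train is 60,000 x 60,000 and k_test_train is 10,000 x 60,000.
k_train_train = kernel_fn(x_train, None, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')

# Infinite-time predictions of gradient descent on the MSE loss (t is left as None).
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
```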
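The snippet builds `batched_kernel_fn` but then calls the unbatched `kernel_fn` on all 60,000 training points, so each kernel is produced by one jitted computation. As a rough, library-independent illustration of block-wise evaluation (the general idea behind `nt.batch`, not its actual implementation), the sketch below fills a host-side NumPy array one tile at a time. `blockwise_kernel` and `linear_kernel` are hypothetical names, and `linear_kernel` is only a cheap stand-in for `kernel_fn`, not the NTK.

```python
import numpy as np


def blockwise_kernel(kernel, x1, x2, batch_size):
    """Assemble an (n1, n2) kernel on the host, one tile at a time.

    `kernel(a, b)` must return the (len(a), len(b)) block for inputs a, b.
    Each call only sees a (batch_size x batch_size) tile, so the working
    memory of `kernel` scales with batch_size, but the returned array
    still holds all n1 * n2 entries.
    """
    n1, n2 = x1.shape[0], x2.shape[0]
    if n1 % batch_size or n2 % batch_size:
        # Simplifying assumption of this sketch: batches tile the data exactly.
        raise ValueError("batch_size must divide both input sizes")
    out = np.empty((n1, n2), dtype=np.float32)  # host-side output buffer
    for i in range(0, n1, batch_size):
        for j in range(0, n2, batch_size):
            out[i:i + batch_size, j:j + batch_size] = kernel(
                x1[i:i + batch_size], x2[j:j + batch_size])
    return out


def linear_kernel(a, b):
    """Hypothetical stand-in for kernel_fn: feature-averaged inner products."""
    return a @ b.T / a.shape[1]


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 784)).astype(np.float32)
k = blockwise_kernel(linear_kernel, x, x, batch_size=4)
print(k.shape)  # (8, 8)
```

Tiling bounds the memory used by each kernel evaluation, but the assembled result is still a dense matrix: with `store_on_device=False` it has to fit in host RAM, and `gradient_descent_mse` still receives the full train-train kernel.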

I am running this on two RTX 3090 GPUs, each with 24 GB of memory.
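For scale, here is a back-of-envelope estimate of the dense kernel sizes, assuming float32 entries (JAX's default unless 64-bit mode is enabled) and the full 60,000 / 10,000 MNIST split. `kernel_bytes` is a hypothetical helper used only for this arithmetic.

```python
# Back-of-envelope size of the dense kernel matrices in the snippet.
# Assumes float32 entries and the full 60,000 train / 10,000 test MNIST split.

def kernel_bytes(n_rows: int, n_cols: int, bytes_per_entry: int = 4) -> int:
    """Bytes needed to store one dense n_rows x n_cols kernel matrix."""
    return n_rows * n_cols * bytes_per_entry


N_TRAIN, N_TEST = 60_000, 10_000

train_train = kernel_bytes(N_TRAIN, N_TRAIN)  # k_train_train
test_train = kernel_bytes(N_TEST, N_TRAIN)    # k_test_train

print(f"k_train_train (float32): {train_train / 1e9:.1f} GB")  # 14.4 GB
print(f"k_test_train  (float32): {test_train / 1e9:.1f} GB")   # 2.4 GB
print(f"k_train_train (float64): {kernel_bytes(N_TRAIN, N_TRAIN, 8) / 1e9:.1f} GB")  # 28.8 GB
```

A single 60,000 x 60,000 float32 matrix is already more than half of one 24 GB card, before any intermediate arrays that `kernel_fn` or `gradient_descent_mse` may allocate. A jitted call with no sharding runs on one default device, so the second GPU does not add usable memory to the unbatched calls.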
Is there something I'm doing wrong, or is it normal for the NTK computation to use this much memory?
Thank you!

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
