Strange result with a simple 2-layer NN #138
Hello,
Here is a snippet:
import jax
import jax.numpy as jnp
from jax import jit
from jax import grad
from jax.example_libraries import optimizers
from jax.config import config
config.update("jax_enable_x64", True)  # double precision for the matrix operations
import numpy as np
import neural_tangents as nt
from neural_tangents import stax
###########
key = jax.random.PRNGKey(0) #initial seed
# Some dimensions
d=15
N=6
ns=165
n_test=1_000
batch_size=5
# A vector beta, drawn once and for all
beta = jax.random.normal(key, shape=(1,d))
norm = jnp.linalg.norm(beta, axis=1)
beta = beta / norm
# Utils to generate a dataset
def gen_x(key=None, r=1.0, d=20, ns=50):
    x = jax.random.normal(key, shape=(ns, d))
    norm = jnp.linalg.norm(x, axis=1)
    x_normed = r * x / norm.reshape(x.shape[0], 1)
    return x_normed
def gen_y(key, X, beta, sigma_eps=0.5):
    """Target generation: y = sin(<beta, x>) + noise."""
    Xbeta = X @ beta.T  # <beta, Xi>
    y = jnp.sin(Xbeta)
    noise = jax.random.normal(key, shape=(X.shape[0], 1)) * sigma_eps
    return y + noise
# The MSE loss
loss = lambda fx, y_hat: 0.5*jnp.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))  # apply_fn is defined below; grad_loss is not used afterwards
#Test Dataset
key, x_key, y_key = jax.random.split(key, 3)
X_test = gen_x(x_key, r=np.sqrt(d), d=d, ns=n_test)
Y_test = gen_y(y_key,X_test, beta, sigma_eps=0.) # no error
# Train Dataset
key, x_key, y_key = jax.random.split(key, 3)
X_train = gen_x(x_key, r=np.sqrt(d), d=d, ns=ns)
Y_train = gen_y(y_key,X_train, beta, sigma_eps=0.5)
# 2-layer NN for regression with 1 output
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(N, W_std=1., parameterization='standard'),
    stax.Relu(),
    stax.Dense(1, W_std=1., parameterization='standard')
)
# Finite-width empirical NTK, batched since the number of samples can be large
emp_ntk_kernel_fn = nt.batch(nt.empirical_ntk_fn(apply_fn), device_count=-1, batch_size=batch_size)
# Initialize the parameters and the NTK_train_train / NTK_test_train kernel matrices
_, params = init_fn(key, (-1, d))
kntk_emp_train_train = emp_ntk_kernel_fn(X_train, None, params)
kntk_emp_test_train = emp_ntk_kernel_fn(X_test, X_train, params)
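# (Shapes, as I understand them: kntk_emp_train_train is (ns, ns) = (165, 165) and
#  kntk_emp_test_train is (n_test, ns) = (1000, 165), since empirical_ntk_fn traces
#  over the single output dimension by default.)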
predict_fn = nt.predict.gradient_descent_mse(kntk_emp_train_train, Y_train, diag_reg=1.e-9)
# First (t=0) inference of the network on the train & test sets
fx_train_0 = apply_fn(params, X_train)
fx_test_0 = apply_fn(params, X_test)
# Inference at t=infinity (= min-norm ridgeless regression)
fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0, kntk_emp_test_train)
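# Cross-check (my assumption, not something I verified in the library docs): for the MSE
# loss, the t=None (infinite-time) solution should coincide with a kernel regression solve.
# The explicit 1e-9 * jnp.eye(ns) below is only a rough stand-in for diag_reg, which I
# believe neural_tangents rescales internally, so the numbers may differ slightly.
alpha_check = jnp.linalg.solve(kntk_emp_train_train + 1e-9 * jnp.eye(ns), Y_train - fx_train_0)
fx_train_check = fx_train_0 + kntk_emp_train_train @ alpha_check
fx_test_check = fx_test_0 + kntk_emp_test_train @ alpha_check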
# The MSE loss on Train & Test datasets
loss(fx_train_inf, Y_train), loss(fx_test_inf, Y_test)

I get (DeviceArray(0., dtype=float64), DeviceArray(nan, dtype=float64)).
But since Nd = 90 (the number of parameters of the first Dense layer, without bias) is smaller than the number of samples (165), I would expect the train MSE not to be 0 (I am not in the overparametrized regime) and the test MSE not to diverge, as Nd ≠ ns.
So I am puzzled; I have certainly missed something. What I want to do is compute the infinite-time MSE predictions with the finite-width Neural Tangent Kernel. By the way, I am trying to reproduce, more or less, the results of Figures 1 & 2 of https://arxiv.org/pdf/2007.12826.pdf by Andrea Montanari and Yiqiao Zhong.
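For what it is worth, here is a small diagnostic I can append to the snippet (just a sketch): counting the trainable parameters in params and looking at the spectrum of the empirical train-train NTK. That 165 x 165 Gram matrix is the Jacobian times its transpose, so its rank is at most the number of parameters; if that number is below 165 the matrix is singular and the solve with diag_reg=1.e-9 could be badly conditioned, which might be where the exact 0 and the nan come from.

# Count the trainable parameters in the params pytree
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
# Spectrum of the empirical train-train NTK: its rank is at most n_params,
# so it is singular whenever n_params < ns
eigvals = jnp.linalg.eigvalsh(kntk_emp_train_train)
print(n_params, eigvals.min(), eigvals.max())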