This repository was archived by the owner on May 6, 2025. It is now read-only.

Strange result with a simple 2-layers NN #138

@jecampagne

Description


Hello,

Here is a code snippet:

import jax
import jax.numpy as jnp
from jax import jit
from jax import grad
from jax.example_libraries import optimizers

from jax.config import config
config.update("jax_enable_x64", True)  # double precision for the matrix operations

import numpy as np

import neural_tangents as nt
from neural_tangents import stax
###########

key = jax.random.PRNGKey(0)  #initial seed

# Some dimensions
d=15
N=6
ns=165
n_test=1_000
batch_size=5

# A fixed unit vector beta, drawn once and for all
beta = jax.random.normal(key, shape=(1,d))
norm = jnp.linalg.norm(beta, axis=1)
beta =  beta / norm

# Utils to generate a dataset
def gen_x(key=None, r=1.0, d=20, ns=50):
    x = jax.random.normal(key, shape=(ns,d))
    norm = jnp.linalg.norm(x, axis=1)
    x_normed = r * x / norm.reshape(x.shape[0],1)
    return x_normed


def gen_y(key, X, beta, sigma_eps=0.5):
    """Target generation."""
    Xbeta = X @ beta.T  # <beta, Xi>
    y = jnp.sin(Xbeta)
    noise = jax.random.normal(key,shape=(X.shape[0],1)) * sigma_eps
    return y + noise

# The MSE loss
loss = lambda fx, y_hat: 0.5*jnp.mean((fx - y_hat) ** 2)
# apply_fn is defined below; the lambda only looks it up when grad_loss is called
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    
# Test dataset
key, x_key, y_key = jax.random.split(key, 3)
X_test = gen_x(x_key, r=np.sqrt(d), d=d, ns=n_test)
Y_test = gen_y(y_key,X_test, beta, sigma_eps=0.) # no error
    
# Train dataset
key, x_key, y_key = jax.random.split(key, 3)
X_train = gen_x(x_key, r=np.sqrt(d), d=d, ns=ns)
Y_train = gen_y(y_key,X_train, beta, sigma_eps=0.5)
        
               
# 2-layer NN for regression with 1 output
init_fn, apply_fn, kernel_fn = stax.serial(
                stax.Dense(N, W_std=1., parameterization='standard'), 
                stax.Relu(),
                stax.Dense(1, W_std=1., parameterization='standard')
)

# Finite-width NTK, batched since the number of samples can be large
emp_ntk_kernel_fn = nt.batch(nt.empirical_ntk_fn(apply_fn), device_count=-1, batch_size=batch_size)
            
# Initialize the parameters and compute the NTK_train_train / NTK_test_train kernel matrices
_, params = init_fn(key, (-1, d))
kntk_emp_train_train = emp_ntk_kernel_fn(X_train, None, params)
kntk_emp_test_train  = emp_ntk_kernel_fn(X_test, X_train, params)
            
predict_fn = nt.predict.gradient_descent_mse(kntk_emp_train_train, Y_train,  diag_reg=1.e-9)

# Initial (t=0) network outputs on the train & test sets
fx_train_0 = apply_fn(params, X_train)
fx_test_0  = apply_fn(params, X_test)
            
# t = infinity inference (= min-norm ridgeless regression)
fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,  kntk_emp_test_train)

# The MSE loss on Train & Test datasets
loss(fx_train_inf, Y_train), loss(fx_test_inf, Y_test)

I get (DeviceArray(0., dtype=float64), DeviceArray(nan, dtype=float64)).

However, since Nd = 90 (the number of parameters of the first Dense layer, without bias) is smaller than the number of samples (165), I would expect the train MSE not to be 0 (I am not in the overparametrized regime) and the test MSE not to diverge, since Nd ≠ ns.
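As a side check, here is a minimal sketch (reusing the params pytree returned by init_fn in the snippet above) to count the actual number of trainable parameters; note that the second Dense layer contributes a few more weights on top of the Nd = 90 of the first one.

# Side check: count the trainable parameters of the finite network,
# using the `params` pytree returned by init_fn above.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print("total number of trainable parameters:", n_params)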

So I am puzzled; I have certainly missed something. What I want to do is compute the MSE of the t = infinity inference with the finite-width Neural Tangent Kernel. By the way, I am trying to reproduce, more or less, the results of Figures 1 & 2 of https://arxiv.org/pdf/2007.12826.pdf by Andrea Montanari and Yiqiao Zhong.
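In case it helps, here is a minimal diagnostic sketch, assuming the t = None solution of gradient_descent_mse follows the usual closed form f_inf = f_0 + K_test_train K_train_train^{-1} (Y_train - f_0), and reusing the variables from the snippet above: it checks the conditioning of the empirical train-train NTK and redoes the test prediction with an explicit solve.

# Diagnostic: explicit closed form of the t -> infinity gradient-descent solution.
# The small ridge added by hand is only an approximation of diag_reg
# (neural_tangents may scale diag_reg internally).
K = kntk_emp_train_train + 1e-9 * jnp.eye(ns)
print("condition number of K_train_train:", jnp.linalg.cond(K))

alpha = jnp.linalg.solve(K, Y_train - fx_train_0)        # shape (ns, 1)
fx_test_check = fx_test_0 + kntk_emp_test_train @ alpha  # shape (n_test, 1)
print("test MSE (explicit solve):", loss(fx_test_check, Y_test))

Since the empirical NTK has rank at most the number of parameters, which here is well below ns = 165, K_train_train is singular and the solve is very sensitive to the tiny diag_reg; a huge condition number here would be a plausible source of the NaN on the test set.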
