This repository was archived by the owner on May 6, 2025. It is now read-only.

batch-decorated kernel function gives out-of-memory error #89

@qixuanf

Description

I have an empirical kernel function that works on a small batch of inputs.
When I decorate it with the batch decorator as shown below, it gives an out-of-memory error. I would like to use the batch decorator to scale the kernel function to more inputs, but it seems to consume more memory than the undecorated function.

import neural_tangents as nt

print(kernel_fn(x1=x, x2=None))   # this works

kernel_fn_batched = nt.batch(kernel_fn, device_count=0, batch_size=1)
print(kernel_fn_batched(x1=x, x2=None))  # gives an out-of-memory error
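
For reference, here is a minimal sketch of the kind of setup I am using; the network, parameters, and inputs below are stand-ins rather than my actual model, and the kernel function is an empirical NTK built with nt.empirical_kernel_fn:

import jax
import neural_tangents as nt
from neural_tangents import stax

# Stand-in network and data; my real kernel_fn is an empirical kernel like this one.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 32))   # 64 inputs of dimension 32
_, params = init_fn(key, x.shape)

kernel_fn = nt.empirical_kernel_fn(apply_fn)

# Unbatched: computes the full kernel in one shot; this is the call that works for me.
print(kernel_fn(x, None, 'ntk', params))

# Batched: device_count=0 disables pmap and processes the batches serially, so I
# expected peak memory to go down; in my actual setup this is the call that raises
# CUDA_ERROR_OUT_OF_MEMORY.
kernel_fn_batched = nt.batch(kernel_fn, device_count=0, batch_size=1)
print(kernel_fn_batched(x, None, 'ntk', params))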

The out-of-memory error trace:

RuntimeError                              Traceback (most recent call last)
<ipython-input-82-4b3cbb07e3d0> in <module>
----> 1 res = p['kernel_fn_1'](x1_or_kernel=x, x2=None)

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in serial_fn(x1_or_kernel, x2, *args, **kwargs)
    453                 **kwargs) -> NTTree[Kernel]:
    454     if utils.is_nt_tree_of(x1_or_kernel, np.ndarray):
--> 455       return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    456     elif utils.is_nt_tree_of(x1_or_kernel, Kernel):
    457       if x2 is not None:

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in serial_fn_x1(x1, x2, *args, **kwargs)
    383       return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
    384 
--> 385     _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
    386     return flatten(kernel, x2_is_none)
    387 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in _scan(f, init, xs)
    140   for flat_x in zip(*flat_xs):
    141     x = tree_unflatten(tree_def, flat_x)
--> 142     carry, y = f(carry, x)
    143     ys += [y]
    144 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in row_fn(_, x1)
    372 
    373     def row_fn(_, x1):
--> 374       return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
    375 
    376     def col_fn(x1, x2):

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in _scan(f, init, xs)
    140   for flat_x in zip(*flat_xs):
    141     x = tree_unflatten(tree_def, flat_x)
--> 142     carry, y = f(carry, x)
    143     ys += [y]
    144 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in col_fn(x1, x2)
    381           **dict((k, (kwargs1[k], kwargs2[k])) for k in kwargs1)
    382       }
--> 383       return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
    384 
    385     _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in f_pmapped(x_or_kernel, *args, **kwargs)
    735       # Broadcast `np.ndarray` arguments and apply the new function to them.
    736       args_np = tree_map(broadcast, args_np)
--> 737       return _f(x_or_kernel, *args_np, **kwargs_np)
    738 
    739     return f_pmapped

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    131   def reraise_with_filtered_traceback(*args, **kwargs):
    132     try:
--> 133       return fun(*args, **kwargs)
    134     except Exception as e:
    135       if not is_under_reraiser(e):

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    369         return cache_miss(*args, **kwargs)[0]  # probably won't return
    370     else:
--> 371       return cpp_jitted_f(*args, **kwargs)
    372   f_jitted._cpp_jitted_f = cpp_jitted_f
    373 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py in cache_miss(*args, **kwargs)
    276       _check_arg(arg)
    277     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 278     out_flat = xla.xla_call(
    279         flat_fun,
    280         *args_flat,

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1187 
   1188   def bind(self, fun, *args, **params):
-> 1189     return call_bind(self, fun, *args, **params)
   1190 
   1191   def process(self, trace, fun, tracers, params):

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1178   tracers = map(top_trace.full_raise, args)
   1179   with maybe_new_sublevel(top_trace):
-> 1180     outs = primitive.process(top_trace, fun, tracers, params)
   1181   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1182 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1190 
   1191   def process(self, trace, fun, tracers, params):
-> 1192     return trace.process_call(self, fun, tracers, params)
   1193 
   1194   def post_process(self, trace, out_tracers, params):

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    581 
    582   def process_call(self, primitive, f, tracers, params):
--> 583     return primitive.impl(f, *tracers, **params)
    584   process_map = process_call
    585 

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    561                                *unsafe_map(arg_spec, args))
    562   try:
--> 563     return compiled_fun(*args)
    564   except FloatingPointError:
    565     assert FLAGS.jax_debug_nans  # compiled_fun can only raise in this case

~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
    809   device, = compiled.local_devices()
    810   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 811   out_bufs = compiled.execute(input_bufs)
    812   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
    813   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory

The version of neural-tangents I am using is 0.3.5.
