This repository was archived by the owner on May 6, 2025. It is now read-only.
batch-decorated kernel function gives out-of-memory error #89
Open
Labels
bug (Something isn't working)
Description
I have an empirical kernel function that works on a small batch of inputs. When I decorate it with the `batch` decorator as shown below, it gives an out-of-memory error. I would like to use `batch` to scale the kernel function to more inputs, but it seems to consume more memory than not using it.
```python
print(kernel_fn(x1=x, x2=None))  # this one works
kernel_fn_batched = nt.batch(kernel_fn, device_count=0, batch_size=1)
print(kernel_fn_batched(x1=x, x2=None))  # gives out-of-memory error
```
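For context, here is a minimal, self-contained sketch of the pattern being attempted. The model, `x`, and `params` below are hypothetical stand-ins (the original `kernel_fn` and inputs are not shown in the report), and the exact empirical-kernel call signature may differ slightly across neural-tangents versions. One knob worth trying is `store_on_device=False`, which asks `nt.batch` to keep finished batches in host memory rather than accumulating the full kernel on the accelerator:

```python
import jax.numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

# A small fully-connected model; any architecture would do for this sketch.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key = random.PRNGKey(0)
x = random.normal(key, (64, 10))   # hypothetical stand-in for the real inputs
_, params = init_fn(key, x.shape)

# Empirical (finite-width) kernel of the network at `params`.
kernel_fn = nt.empirical_kernel_fn(apply_fn)

# store_on_device=False moves each computed batch to host memory instead of
# keeping the whole kernel on the GPU.
kernel_fn_batched = nt.batch(kernel_fn, device_count=0, batch_size=1,
                             store_on_device=False)
k = kernel_fn_batched(x, None, 'ntk', params)
```

With `device_count=0` batches are computed serially on a single device, so per-batch compute memory should be bounded by `batch_size`; if usage still grows, the accumulation of results on the device is the usual suspect, which is what `store_on_device=False` addresses.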
The out-of-memory error trace:

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-82-4b3cbb07e3d0> in <module>
----> 1 res = p['kernel_fn_1'](x1_or_kernel=x, x2=None)
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in serial_fn(x1_or_kernel, x2, *args, **kwargs)
453 **kwargs) -> NTTree[Kernel]:
454 if utils.is_nt_tree_of(x1_or_kernel, np.ndarray):
--> 455 return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
456 elif utils.is_nt_tree_of(x1_or_kernel, Kernel):
457 if x2 is not None:
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in serial_fn_x1(x1, x2, *args, **kwargs)
383 return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
384
--> 385 _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
386 return flatten(kernel, x2_is_none)
387
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in _scan(f, init, xs)
140 for flat_x in zip(*flat_xs):
141 x = tree_unflatten(tree_def, flat_x)
--> 142 carry, y = f(carry, x)
143 ys += [y]
144
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in row_fn(_, x1)
372
373 def row_fn(_, x1):
--> 374 return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
375
376 def col_fn(x1, x2):
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in _scan(f, init, xs)
140 for flat_x in zip(*flat_xs):
141 x = tree_unflatten(tree_def, flat_x)
--> 142 carry, y = f(carry, x)
143 ys += [y]
144
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in col_fn(x1, x2)
381 **dict((k, (kwargs1[k], kwargs2[k])) for k in kwargs1)
382 }
--> 383 return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
384
385 _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
174 @functools.wraps(f)
175 def h(*args, **kwargs):
--> 176 return g(*args, **kwargs)
177
178 h.__signature__ = inspect.signature(f)
~/anaconda3/envs/gppr/lib/python3.8/site-packages/neural_tangents/utils/batch.py in f_pmapped(x_or_kernel, *args, **kwargs)
735 # Broadcast `np.ndarray` arguments and apply the new function to them.
736 args_np = tree_map(broadcast, args_np)
--> 737 return _f(x_or_kernel, *args_np, **kwargs_np)
738
739 return f_pmapped
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
131 def reraise_with_filtered_traceback(*args, **kwargs):
132 try:
--> 133 return fun(*args, **kwargs)
134 except Exception as e:
135 if not is_under_reraiser(e):
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 return cache_miss(*args, **kwargs)[0] # probably won't return
370 else:
--> 371 return cpp_jitted_f(*args, **kwargs)
372 f_jitted._cpp_jitted_f = cpp_jitted_f
373
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py in cache_miss(*args, **kwargs)
276 _check_arg(arg)
277 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 278 out_flat = xla.xla_call(
279 flat_fun,
280 *args_flat,
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
1187
1188 def bind(self, fun, *args, **params):
-> 1189 return call_bind(self, fun, *args, **params)
1190
1191 def process(self, trace, fun, tracers, params):
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1178 tracers = map(top_trace.full_raise, args)
1179 with maybe_new_sublevel(top_trace):
-> 1180 outs = primitive.process(top_trace, fun, tracers, params)
1181 return map(full_lower, apply_todos(env_trace_todo(), outs))
1182
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1190
1191 def process(self, trace, fun, tracers, params):
-> 1192 return trace.process_call(self, fun, tracers, params)
1193
1194 def post_process(self, trace, out_tracers, params):
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
581
582 def process_call(self, primitive, f, tracers, params):
--> 583 return primitive.impl(f, *tracers, **params)
584 process_map = process_call
585
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
561 *unsafe_map(arg_spec, args))
562 try:
--> 563 return compiled_fun(*args)
564 except FloatingPointError:
565 assert FLAGS.jax_debug_nans # compiled_fun can only raise in this case
~/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
809 device, = compiled.local_devices()
810 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 811 out_bufs = compiled.execute(input_bufs)
812 if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
813 return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
```

The version of neural-tangents I am using is 0.3.5.
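One thing worth ruling out: `Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY` can occur when the GPU is already nearly full at compile time, and by default JAX preallocates most of GPU memory on first use. A sketch of disabling that preallocation (a general JAX setting, not specific to neural-tangents; it must run before JAX is first imported):

```python
import os

# These must be set before jax (or anything that imports jax) is imported;
# otherwise the GPU allocator is already configured.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Alternatively, cap the preallocated fraction of GPU memory:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import jax
import neural_tangents as nt
```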