Influence model cannot execute functional call #640
Comments
Can you provide a bit more context to reproduce the issue, ideally a minimal working example with your CNN? Also, which version of pydvl are you using? @schroedk, any ideas?
Hi @danielkaplan137,
@mdbenito @schroedk My network structure is also very simple: a `class Net(nn.Module)` with a backbone and two branches.
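A network of the shape described above, with a shared backbone feeding two heads, might look like the plain (non-jitted) PyTorch module below. This is a hypothetical sketch: the layer sizes, the attribute names `backbone`, `branch1` and `branch2`, and the single-channel 28x28 input are placeholders, not the reporter's actual model.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical backbone + two-branch CNN (all sizes are placeholders)."""

    def __init__(self, n_classes: int = 10) -> None:
        super().__init__()
        # Backbone: shared convolutional feature extractor.
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),  # -> (N, 8, 4, 4) regardless of input size
            nn.Flatten(),             # -> (N, 128)
        )
        # Branch 1: classification head on the shared features.
        self.branch1 = nn.Linear(8 * 4 * 4, n_classes)
        # Branch 2: a second head on the same features (here one scalar per sample).
        self.branch2 = nn.Linear(8 * 4 * 4, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.backbone(x)
        return self.branch1(features), self.branch2(features)


model = Net()
out1, out2 = model(torch.randn(2, 1, 28, 28))
print(out1.shape, out2.shape)  # torch.Size([2, 10]) torch.Size([2, 1])
```

Because this is ordinary `nn.Module` code, nothing in it is jitted; TorchScript only enters the picture if the model is passed through `torch.jit.script` or `torch.jit.trace`, or loaded with `torch.jit.load`.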
I'm not sure which part is jitted. All of it? Is there anything I can do with my already trained networks?
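One way to see which parts of a model are TorchScript objects is to walk `named_modules()` and test each entry against `torch.jit.ScriptModule`, the common base class of what `torch.jit.script`, `torch.jit.trace` and `torch.jit.load` return. If the network was scripted or traced after training, a common workaround is to rebuild the original eager class and copy the trained weights over with `state_dict()` / `load_state_dict()`. The sketch below assumes the eager class definition is still available; `make_model` and `jitted_parts` are hypothetical helper names, and the small `nn.Sequential` stands in for the real network.

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for the eager (non-jitted) model definition.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def jitted_parts(model: nn.Module) -> list[str]:
    """Names of all submodules that are TorchScript modules ("" is the root)."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.jit.ScriptModule)
    ]


torch.manual_seed(0)
# Stands in for a trained model obtained via torch.jit.script/trace/load.
trained = torch.jit.script(make_model())
print(jitted_parts(trained))  # the root and every child are ScriptModules

# Workaround: rebuild the eager module and copy the trained weights into it.
eager = make_model()
eager.load_state_dict(trained.state_dict())
print(jitted_parts(eager))  # []

# Both modules now hold the same weights and compute the same outputs.
x = torch.randn(3, 4)
print(torch.allclose(eager(x), trained(x)))  # True
```

The eager copy is an ordinary `nn.Module`, so it can be passed to code that relies on `torch.func.functional_call`.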
Here's some simple code to reproduce the issue.
While using `CgInfluence` on a pre-trained CNN, I get the following error from torch:
`raise RuntimeError("The stateless API can't be used with Jitted modules")`
with the full traceback:
```
    influences = influence_model.influences(x, y, x,y, mode="up")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/base_influence_function_model.py", line 162, in influences
    return self._influences(x_test, y_test, x, y, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/base_influence_function_model.py", line 453, in _influences
    return cast(TensorType, sum(tensors))
                            ^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 631, in generate_interactions
    yield comp_block.interactions(left_batch, right_batch, mode)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 431, in interactions
    return bilinear_form.grads_inner_prod(left_batch, right_batch, self.gp)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 253, in grads_inner_prod
    left_grad = gradient_provider.flat_grads(left)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 251, in flat_grads
    self.grads(batch).values(), shape=(batch.x.shape[0], -1)
    ^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 190, in grads
    gradient_dict = self._grads(batch.to(self.device))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 120, in _grads
    result: Dict[str, torch.Tensor] = torch.vmap(
                                      ^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/apis.py", line 399, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1449, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_and_value_impl
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 116, in _compute_loss
    outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/functional_call.py", line 148, in functional_call
    return nn.utils.stateless._functional_call(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
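The last pyDVL frame is `functional_call(self.model, params, ...)` in `pydvl/influence/torch/base.py`, and PyTorch's stateless API refuses any module that is a TorchScript object, so the error can be reproduced without pyDVL. A minimal sketch, assuming PyTorch 2.x where `torch.func.functional_call` is available; the small `nn.Sequential` is a hypothetical stand-in for the pre-trained CNN and `try_functional_call` is a hypothetical helper.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
# Hypothetical stand-in for the pre-trained CNN.
eager = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
scripted = torch.jit.script(eager)  # a TorchScript (jitted) version of the same model
x = torch.randn(3, 4)


def try_functional_call(model: nn.Module) -> str:
    """Run the same kind of call pyDVL makes and report the outcome."""
    params = dict(model.named_parameters())
    try:
        functional_call(model, params, (x,))
        return "ok"
    except RuntimeError as err:
        return f"RuntimeError: {err}"


print(try_functional_call(eager))     # ok
print(try_functional_call(scripted))  # RuntimeError: The stateless API can't be used with Jitted modules
```

`torch.jit.load` also returns a TorchScript module, so a trained model restored from a TorchScript file would hit the same check.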