
Influence model cannot execute functional call #640

Open
danielkaplan137 opened this issue Jan 20, 2025 · 4 comments
Labels
bug Something isn't working

Comments

@danielkaplan137

While using CgInfluence on a pre-trained CNN, I get the following error from torch:
`raise RuntimeError("The stateless API can't be used with Jitted modules")`

with the full backtrace:
```
influences = influence_model.influences(x, y, x,y, mode="up")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/base_influence_function_model.py", line 162, in influences
    return self._influences(x_test, y_test, x, y, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/base_influence_function_model.py", line 453, in _influences
    return cast(TensorType, sum(tensors))
           ^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 631, in generate_interactions
    yield comp_block.interactions(left_batch, right_batch, mode)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 431, in interactions
    return bilinear_form.grads_inner_prod(left_batch, right_batch, self.gp)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/types.py", line 253, in grads_inner_prod
    left_grad = gradient_provider.flat_grads(left)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 251, in flat_grads
    self.grads(batch).values(), shape=(batch.x.shape[0], -1)
    ^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 190, in grads
    gradient_dict = self._grads(batch.to(self.device))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 120, in _grads
    result: Dict[str, torch.Tensor] = torch.vmap(
                                      ^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/apis.py", line 399, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1449, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_and_value_impl
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/pydvl/influence/torch/base.py", line 116, in _compute_loss
    outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/sci/lib/python3.12/site-packages/torch/_functorch/functional_call.py", line 148, in functional_call
    return nn.utils.stateless._functional_call(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The stateless API can't be used with Jitted modules
```
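
For context, the underlying torch restriction can be reproduced outside of pyDVL in a few lines. This is a minimal sketch, not pyDVL code; the `nn.Linear` model is just a stand-in:

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(4, 2)
scripted = torch.jit.script(model)  # JIT compilation yields a ScriptModule

params = dict(model.named_parameters())
x = torch.randn(1, 4)

functional_call(model, params, (x,))     # works on a plain nn.Module
functional_call(scripted, params, (x,))  # raises RuntimeError: "The stateless
                                         # API can't be used with Jitted modules"
```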

@mdbenito
Collaborator

Can you provide a bit more context to reproduce the issue? A minimal working example with your CNN would help. Also, which version of pydvl are you using?

@schroedk Any ideas?

mdbenito added the bug and awaiting-reply labels on Jan 20, 2025
@schroedk
Collaborator

Hi @danielkaplan137,
Thank you for reporting this issue. Do you have access to the source code of the model and the ability to run a version that is not JIT-compiled?
It seems that the restriction arises from the torch.func stateless API, which is not compatible with JIT-compiled models. Nevertheless, providing a minimal working example would be helpful to determine if modifying the implementation is straightforward.
Thanks :)
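
As a sketch of what "a version that is not JIT-compiled" could look like: if the original class definition is available, the weights of a scripted model can usually be loaded back into a plain `nn.Module`. The file name below is hypothetical:

```python
import torch

scripted = torch.jit.load("model_scripted.pt")  # hypothetical path to the scripted model

eager_model = Net()  # the original, non-scripted model class
eager_model.load_state_dict(scripted.state_dict())
eager_model.eval()

# pass eager_model, not scripted, to CgInfluence
```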

@danielkaplan137
Author

@mdbenito @schroedk
Thanks so much for responding so swiftly, guys. I'm using the latest development version, i.e., 0.9.3.dev0.

My network structure is very simple as well. It contains a backbone and two branches:

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Backbone
        self.seq1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(120, 64),
            nn.ReLU()
        )

        # Branch 1
        self.seq2 = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 2),
            nn.Softmax(dim=1)
        )

        # Branch 2
        self.seq3 = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = self.seq1(x)
        output1 = self.seq2(x)
        output2 = self.seq3(x)
        return [output2, output1]
```

I'm not sure which part is JIT-compiled. All of it? Is there anything I can do with my already-trained networks?
Happy to hear what can be done! Thanks!
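
A quick way to check which parts are JIT-compiled, sketched below with a hypothetical `model` variable, is to scan all submodules for ScriptModule instances:

```python
import torch

def find_scripted_submodules(model: torch.nn.Module) -> list[str]:
    """Return the names of all submodules that are TorchScript modules."""
    return [
        name or "<root>"
        for name, module in model.named_modules()
        if isinstance(module, torch.jit.ScriptModule)
    ]

print(find_scripted_submodules(model))  # an empty list means nothing is scripted
```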

@danielkaplan137
Author

Here's a simple script to reproduce the issue:

export.zip

mdbenito removed the awaiting-reply label on Jan 22, 2025