Hi,

I was trying to use the S5Block with an Adam optimizer with weight decay. However, I ran into a strange bug where the sizes of the parameters and gradients mismatch. The error only occurs with CUDA tensors/models and only when weight_decay is enabled. Below is a minimal script that reproduces the bug:
```python
from s5 import S5Block
import torch

x = torch.randn(16, 64, 256).cuda()
a = S5Block(256, 128, False).cuda()
a.train()

# h = torch.optim.Adam(a.parameters(), lr=0.001)  # this works
h = torch.optim.Adam(a.parameters(), lr=0.001, weight_decay=0.0001)  # this doesn't work

out = a(x.cuda())
out.sum().backward()
h.step()
```
After a lot of digging I found the part that causes the error: the handling of complex device parameters in `_multi_tensor_adam` is faulty in the latest PyTorch release, 2.0.1. Specifically, in L. 442 of torch/optim/adam.py the wrong variable is used when computing the weight decay.
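For anyone curious about the mechanism, here is a small sketch (my own illustration, not the actual code from torch/optim/adam.py): the multi-tensor path views complex tensors as real before the elementwise updates, and if one operand has been converted while the other has not, the shapes no longer line up.

```python
import torch

# A complex-valued parameter and its gradient, as S5 uses for its
# state-space matrices (illustrative shape only).
p = torch.randn(256, 128, dtype=torch.cfloat)
g = torch.randn(256, 128, dtype=torch.cfloat)

# Viewing a complex tensor as real appends a trailing dimension of size 2.
p_real = torch.view_as_real(p)

# Mixing a viewed-as-real tensor with an unconverted one in an elementwise
# op (e.g. the weight-decay term grad += weight_decay * param) mismatches:
print(p_real.shape, g.shape)  # torch.Size([256, 128, 2]) vs torch.Size([256, 128])
```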
However, this seems to have been fixed since May 9 with this commit, so it should work with a newer PyTorch version. The current release (2.0.1) is still affected.
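Until a fixed release is available, a possible workaround (a sketch I have not verified on 2.0.1, assuming the bug is confined to the multi-tensor implementation) is to force the single-tensor Adam path with `foreach=False`, which is a documented torch.optim.Adam argument and sidesteps `_multi_tensor_adam` entirely. Replacing the optimizer line in the script above:

```python
# Workaround sketch: disable the multi-tensor (foreach) implementation so the
# buggy _multi_tensor_adam code path is never taken.
h = torch.optim.Adam(a.parameters(), lr=0.001, weight_decay=0.0001, foreach=False)
```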
Just posting this here in case anyone else is having this issue.