Hi Guys,
Thanks for checking out the repo. Since you may still run into problems due to various hardware and software settings, here are a few links to help resolve potential issues:
- Although the MultiheadAttention layer has been available since PyTorch 1.1, please make sure to use PyTorch 1.6 or above; older versions have problems with the attention mask implementation (used for enforcing causality), as discussed in: `attn_mask` in nn.MultiheadAttention is additive pytorch/pytorch#21518. A minimal mask sketch is shown after this list.
- There's a small bug in the original TF implementation of SASRec, which has since been fixed: fix a small bug for generating masks kang205/SASRec#15. Since we are using PyTorch's official MultiheadAttention implementation, a similar problem should not exist here.
- Please have the model output logits and use torch.nn.BCEWithLogitsLoss, rather than applying sigmoid and using torch.nn.BCELoss for training; depending on the PyTorch version, the second approach may trigger a bug: CUDA assertion error binary_cross_entropy loss NVIDIA/pix2pixHD#9. See the loss sketch after this list.
- The current version converges more slowly than the original TF implementation; I'm still checking the details to find the root cause. Please help out if you happen to be interested and have the bandwidth for small fixes :)
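
For reference, here is a minimal sketch (not necessarily the repo's exact code) of building an additive causal mask for nn.MultiheadAttention on PyTorch 1.6+; the `-inf` entries above the diagonal block attention to future positions:

```python
import torch

seq_len = 5  # hypothetical sequence length, for illustration only

# Additive float mask: 0.0 on/below the diagonal, -inf strictly above it,
# so position t can only attend to positions <= t (enforcing causality).
attn_mask = torch.triu(
    torch.full((seq_len, seq_len), float('-inf')), diagonal=1
)

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(seq_len, 1, 8)  # (seq_len, batch, embed_dim) layout
out, _ = mha(x, x, x, attn_mask=attn_mask)
```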
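And a small sketch of the recommended loss setup; the tensor names below are illustrative, not the repo's actual variables:

```python
import torch

# Raw (pre-sigmoid) scores; in practice these come from the model.
pos_logits = torch.randn(32, requires_grad=True)  # observed (positive) items
neg_logits = torch.randn(32, requires_grad=True)  # sampled negative items

# Recommended: feed logits directly, sigmoid is applied internally.
criterion = torch.nn.BCEWithLogitsLoss()
loss = criterion(pos_logits, torch.ones_like(pos_logits)) \
     + criterion(neg_logits, torch.zeros_like(neg_logits))
loss.backward()

# Avoid: torch.nn.BCELoss() on torch.sigmoid(logits), which can hit the
# CUDA assertion error linked above on some PyTorch versions.
```

BCEWithLogitsLoss fuses the sigmoid and the binary cross-entropy into one numerically stable operation, which is why it sidesteps the issue.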
Stay Healthy,
Zan