Explanation of why PyTorch 1.6 or above is required, and other info #1

Description

@pmixer

Hi Guys,

Thanks for checking out the repo. Since you may still run into problems due to varying hardware and software setups, here are a few links to help resolve potential issues:

  • Although the MultiheadAttention layer has been available since PyTorch 1.1, please be sure to use PyTorch 1.6 or above: older versions have problems with the attention mask implementation (used for enforcing causality), as discussed in attn_mask in nn.MultiheadAttention is additive pytorch/pytorch#21518 (see the mask sketch after this list).
  • There was a small bug in the original TensorFlow implementation of SASRec, which has since been fixed: fix a small bug for generating masks kang205/SASRec#15. Since we use PyTorch's official MultiheadAttention implementation, a similar problem should not exist here.
  • Please output logits and use torch.nn.BCEWithLogitsLoss rather than applying sigmoid and using torch.nn.BCELoss for training; depending on your PyTorch version, the second approach may trigger a bug: CUDA assertion error binary_cross_entropy loss NVIDIA/pix2pixHD#9 (see the loss sketch after this list).
  • The current version converges more slowly than the original TensorFlow implementation. I'm still checking the details to find the root cause; please help out if you happen to be interested and have the bandwidth for small fixes :)
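
For reference, here is a minimal sketch of how a causal mask can be passed to nn.MultiheadAttention on PyTorch 1.6+; the shapes and variable names below are illustrative, not taken from this repo:

```python
import torch

# Minimal sketch, PyTorch 1.6+: enforcing causality with a boolean attn_mask.
# Shapes and names are illustrative only.
seq_len, batch_size, hidden_dim, num_heads = 50, 8, 64, 2

attn = torch.nn.MultiheadAttention(hidden_dim, num_heads)
x = torch.rand(seq_len, batch_size, hidden_dim)  # (L, N, E) layout

# True marks positions that must NOT be attended to: each time step may
# attend to itself and earlier steps, never to future ones.
causal_mask = ~torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

out, _ = attn(x, x, x, attn_mask=causal_mask)
```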

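And a minimal sketch of the recommended loss setup: keep the model outputs as raw logits and let BCEWithLogitsLoss apply the sigmoid internally (again, names and shapes are illustrative):

```python
import torch

# Minimal sketch: raw logits + BCEWithLogitsLoss (the sigmoid is applied
# internally, in a numerically stable way). Names/shapes illustrative.
criterion = torch.nn.BCEWithLogitsLoss()

logits = torch.randn(8, 1)                    # model outputs, no sigmoid
labels = torch.randint(0, 2, (8, 1)).float()  # binary targets

loss = criterion(logits, labels)
loss.backward()

# Avoid this pattern; on some PyTorch versions it can hit the CUDA
# assertion error linked above:
# loss = torch.nn.BCELoss()(torch.sigmoid(logits), labels)
```
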
Stay Healthy,
Zan
