@sahirema
Updates the fake tensor wrapper function to make flash_attn_func and flash_attn_varlen_func compatible with torch.compile. Without the update, torch.compile(flash_attn_func) throws a "wrong number of dimensions" error because the backward pass returns a 3D tensor while the fake tensor wrapper expects a 2D tensor. The details of the errors are reported in SWDEV-559708, SWDEV-546369, and SWDEV-559718.
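The failure mode above can be illustrated with a toy sketch (hypothetical names, not the actual flash-attn or PyTorch code): torch.compile traces ops on "fake" tensors that carry only metadata such as shape, so if the registered fake implementation advertises a 2D shape while the real backward kernel produces a 3D tensor, tracing fails with a rank-mismatch error.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    # Carries only metadata (shape); no data is allocated, mirroring how
    # torch.compile traces ops on fake/meta tensors.
    shape: tuple

def real_bwd(batch, nheads, seqlen):
    # Stand-in for the real backward kernel: returns a 3D tensor of shape
    # (batch, nheads, seqlen).
    return FakeTensor((batch, nheads, seqlen))

def fake_bwd_before_patch(batch, nheads, seqlen):
    # Buggy fake wrapper: advertises a 2D shape.
    return FakeTensor((batch * nheads, seqlen))

def fake_bwd_after_patch(batch, nheads, seqlen):
    # Patched fake wrapper: rank matches the real kernel.
    return FakeTensor((batch, nheads, seqlen))

def check(real, fake):
    # Tracing-time consistency check: ranks must agree.
    if len(real.shape) != len(fake.shape):
        raise RuntimeError("wrong number of dimensions")

real = real_bwd(2, 8, 128)
try:
    check(real, fake_bwd_before_patch(2, 8, 128))
except RuntimeError as err:
    print("before patch:", err)   # rank 2 vs rank 3 -> error
check(real, fake_bwd_after_patch(2, 8, 128))
print("after patch: ranks match")
```

The fix is therefore confined to the metadata the fake wrapper reports; the real kernel's behavior is unchanged.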

test_flash_attn_output and test_flash_attn_varlen_output are extended to test both the eager and compiled versions of the functions via a Boolean parameter called compiled. Here are the test results:

  • Reference results in Eager mode
    • flash_attn_func : 42240 passed, 1 warning
    • flash_attn_varlen_func: 46464 passed, 1 warning
  • Compiled FA results without the patch
    • flash_attn_func : 10 failed, 42230 passed, 1 warning
    • flash_attn_varlen_func: 10 failed, 46454 passed, 1 warning
  • Compiled FA results with patch
    • flash_attn_func : 42240 passed, 1 warning
    • flash_attn_varlen_func: 46464 passed, 1 warning
