[ENH] Efficient Attention Backend for TimeXer #1997
Conversation
Some more notes:
If your team's OK with me addressing these, I could address them in this PR or open a new one, whatever works for your team's code review process. P.S. if anyone wants to benchmark the speed and memory consumption of the old and new attention backends, I can give you a script for that. I was not exactly sure whether a script like that made sense to include as part of the package itself.
Yes, the issue is known, although we still don't know the exact source (see #1998, and the discussion from the Discord thread here).
I'd prefer a new PR (stacked on this PR, or maybe after this PR is merged) to keep the "responsibilities" separate for both PRs.
hm, that feels extremely useful! Could you put that into
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@          Coverage Diff          @@
##            main    #1997   +/-  ##
=====================================
  Coverage       ?   86.99%
=====================================
  Files          ?      160
  Lines          ?     9494
  Branches       ?        0
=====================================
  Hits           ?     8259
  Misses         ?     1235
  Partials       ?        0
Thanks a lot for this PR @anasashb ! This is really great
I have added some comments:
- Related to the docstrings: I think this is from the older part of the code, which is why it is still not using numpydoc style. I would really appreciate it if you could use numpydoc-style docstrings here. We still need to update the docstrings across the whole codebase :)
- Can you also add some fixtures in `_timexer_pkg` and `_timexer_pkg_v2`? We are moving from the standalone tests to a unified test framework, and now just adding test fixtures and some configs is enough to test the whole model. You can see some examples in the fixtures already present in the above files, and you can also look at any other model; all models have this `pkg` class now, which is used to test them. All you need to do is update `get_base_test_params` (in the case of `_timexer_pkg`, v1) and `get_test_train_params` (in the case of `_timexer_pkg_v2`, v2) to include fixtures that test the new attention mechanism, as in the rough sketch below.
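Hedged illustration only: assuming the pkg classes expose `get_test_train_params` as a classmethod returning a list of parameter dicts (as described above), the new fixture could be added roughly like this; the surrounding parameter names are placeholders.

```python
class TimeXer_pkg_v2:  # stand-in for the real pkg class; only the idea matters here
    @classmethod
    def get_test_train_params(cls):
        # each dict is one model configuration exercised by the unified test framework
        return [
            dict(hidden_size=16, n_heads=2),  # existing default-attention config
            dict(hidden_size=16, n_heads=2, use_efficient_attention=True),  # new SDPA backend
        ]
```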
    attention_dropout (float): Dropout rate for attention scores.
    output_attention (bool): Whether to output attention weights."""
    output_attention (bool): Whether to output attention weights.
    efficient_attention (bool): Whether to use torch's native efficient
Should it be `use_efficient_attention` here?
Also, please explain what "efficient attention" means here and how it is different from `einsum_attention`.
Also, please use numpydoc-style docstrings. I think this part is from the older part of the code, which is why it's still not updated. Updating the style to numpydoc would be greatly appreciated!
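For reference, a hedged sketch of what the requested numpydoc-style section could look like (parameter names taken from the quoted diff above; descriptions and defaults are illustrative, not the actual code):

```python
def __init__(self, attention_dropout=0.05, output_attention=False, use_efficient_attention=False):
    """Full attention layer (illustrative docstring only).

    Parameters
    ----------
    attention_dropout : float, optional (default=0.05)
        Dropout rate applied to the attention scores.
    output_attention : bool, optional (default=False)
        Whether to also return the attention weights.
    use_efficient_attention : bool, optional (default=False)
        Whether to use torch's native ``scaled_dot_product_attention``
        instead of the explicit ``torch.einsum`` formulation.
    """
```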
    attention_dropout (float): Dropout rate for attention scores.
    output_attention (bool): Whether to output attention weights."""
    output_attention (bool): Whether to output attention weights.
    efficient_attention (bool): Whether to use torch's native efficient
same comment as above!
Reference Issues/PRs
Fixes #1990.
What does this implement/fix? Explain your changes.
This Pull Request adds a `use_efficient_attention` boolean argument to the TimeXer model (both the v1 and v2 versions), which, if set to `True`, switches to a more memory-efficient and faster attention implementation using `torch.nn.functional.scaled_dot_product_attention()` instead of the `torch.einsum()` solution inside the `FullAttention` class (v1 and v2 versions both). The newly introduced argument defaults to `False` to keep the new feature completely backwards compatible.
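For illustration, a hedged usage sketch (the import path, dataset, and the other constructor arguments here are placeholders; only `use_efficient_attention` is the flag this PR introduces):

```python
from pytorch_forecasting.models import TimeXer  # import path may differ by version

# `training_dataset` is assumed to be a TimeSeriesDataSet defined elsewhere;
# hidden_size / n_heads are placeholder hyperparameters.
model = TimeXer.from_dataset(
    training_dataset,
    hidden_size=64,
    n_heads=4,
    use_efficient_attention=True,  # switch FullAttention to the SDPA backend
)
```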
Additionally, there's a very minor bugfix in the `PositionalEmbedding` class (v1 and v2 versions both), where a bug carried over from `tslib` set the tensor attribute `.require_grad`. In torch, the correct attribute for whether a tensor requires grad is called `.requires_grad`; this has also been fixed.
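An illustrative snippet of the relevant part (a sketch, not the exact file; the carried-over buggy line presumably set the non-existent `.require_grad` attribute, as the description above implies):

```python
import math
import torch


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal positional embedding (sketch of the relevant part only)."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        # pe.require_grad = False   # carried-over tslib bug: creates an unused attribute
        pe.requires_grad = False    # correct torch attribute controlling autograd tracking
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return self.pe[:, : x.size(1)]
```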
What should a reviewer concentrate their feedback on?
Reviewers should focus on the implementation of `_einsum_attention()` and `_efficient_attention()`, which are new private methods that `forward()` of the `FullAttention` class calls to handle the attention computation (a rough, illustrative sketch of this dispatch is included after the list below).

I did not make any other changes to the code, but if it works for you I could also:

- address the `tau`, `delta`, and `factor` arguments scattered across the tslib code carried over here

Or, if you'd be OK with those changes too, I can also open a separate PR.
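For reviewers who want a quick mental model, here is a hedged, illustrative sketch of the dispatch between the two backends; names and tensor shapes follow the tslib-style (batch, length, heads, dim) convention, and this is not the literal PR code:

```python
import torch
import torch.nn.functional as F


class FullAttentionSketch(torch.nn.Module):
    """Illustrative only; not the actual FullAttention implementation."""

    def __init__(self, attention_dropout: float = 0.05, use_efficient_attention: bool = False):
        super().__init__()
        self.attention_dropout = attention_dropout
        self.dropout = torch.nn.Dropout(attention_dropout)
        self.use_efficient_attention = use_efficient_attention

    def _einsum_attention(self, queries, keys, values):
        # explicit formulation: materializes the full (B, H, L, S) score matrix
        scale = 1.0 / (queries.shape[-1] ** 0.5)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        attn = self.dropout(torch.softmax(scale * scores, dim=-1))
        return torch.einsum("bhls,bshd->blhd", attn, values)

    def _efficient_attention(self, queries, keys, values):
        # fused kernel: torch selects a flash / memory-efficient / math backend
        out = F.scaled_dot_product_attention(
            queries.transpose(1, 2),
            keys.transpose(1, 2),
            values.transpose(1, 2),
            dropout_p=self.attention_dropout if self.training else 0.0,
        )
        return out.transpose(1, 2)

    def forward(self, queries, keys, values):
        if self.use_efficient_attention:
            return self._efficient_attention(queries, keys, values)
        return self._einsum_attention(queries, keys, values)
```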
Did you add any tests for the change?
Yes, in both `tests/test_models/test_timexer.py` and `tests/test_models/test_timexer_v2.py`. These include new assertions in the initialization tests, as well as parameterization of `use_efficient_attention` for the integration tests; a hedged sketch of that parameterization is shown below.
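A minimal sketch of the kind of parameterization described above (`_make_timexer` is a hypothetical helper that builds a small model; the real tests differ):

```python
import pytest


@pytest.mark.parametrize("use_efficient_attention", [False, True])
def test_timexer_attention_backend(use_efficient_attention):
    # _make_timexer is a hypothetical helper standing in for the actual test setup
    model = _make_timexer(use_efficient_attention=use_efficient_attention)
    assert model.hparams.use_efficient_attention is use_efficient_attention
```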
Any other comments?

PR checklist
`pre-commit install`. To run hooks independent of commit, execute `pre-commit run --all-files`.