
Support for FlashAttention in Llama2 #584

Merged (3 commits) on Dec 12, 2023
Conversation

@wszczurekhabana (Contributor) commented Dec 6, 2023

What does this PR do?

This PR introduces support for the FusedSDPA operator (Flash Attention) in Llama2.

Below are preliminary results from this change:

| Model | BS | Max input tokens | Max new tokens | Default performance [tokens/second] | Performance with FusedSDPA [tokens/second] | Throughput improvement over default | Default memory allocated [GB] | Memory allocated with FusedSDPA [GB] | Memory reduction over default [GB] |
|---|---|---|---|---|---|---|---|---|---|
| 7B 1x | 1 | 16 | 4096 | 124.12 | 124.130 | 1.000 | 14.71 | 14.69 | 0.02 |
| | 4 | 16 | 4096 | 354.27 | 354.500 | 1.001 | 20.81 | 20.71 | 0.10 |
| 13B 1x | 1 | 16 | 4096 | 68.81 | 68.840 | 1.000 | 27.57 | 27.54 | 0.03 |
| | 4 | 16 | 4096 | 203.63 | 203.790 | 1.001 | 37.05 | 36.95 | 0.10 |
| 70B 8x | 1 | 16 | 100 | 55.55 | 56.440 | 1.016 | 16.68 | 16.66 | 0.02 |
| | 40 | 16 | 100 | 1875.95 | 1893.266 | 1.009 | 18.55 | 17.21 | 1.34 |
| | 1 | 16 | 2048 | 60.23 | 59.617 | 0.990 | 16.76 | 16.74 | 0.02 |
| | 40 | 16 | 2048 | 1685.81 | 1686.918 | 1.001 | 21.25 | 20.27 | 0.98 |
| | 60 | 16 | 2048 | 2206.59 | 2208.290 | 1.001 | 24.67 | 22.09 | 2.58 |
| | 1 | 16 | 4096 | 59.77 | 59.044 | 0.988 | 16.86 | 16.82 | 0.04 |
| | 40 | 16 | 4096 | 1366.84 | 1368.552 | 1.001 | 24.99 | 23.50 | 1.49 |
| | 60 | 16 | 4096 | 1689.34 | 1689.680 | 1.000 | 30.09 | 26.92 | 3.17 |

This feature is turned off by default.
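
For context, the sketch below (not the actual diff of this PR) shows what switching between the default attention path and the fused kernel can look like. It assumes `FusedSDPA` from `habana_frameworks.torch.hpex.kernels` exposes an `apply()` interface mirroring `torch.nn.functional.scaled_dot_product_attention`; check the Habana documentation for the exact signature.

```python
import torch.nn.functional as F

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # HPU-only kernel
except ImportError:
    FusedSDPA = None


def sdpa(query, key, value, attn_mask=None, use_flash_attention=False):
    """Scaled dot-product attention with an optional fused (Flash Attention) path."""
    if use_flash_attention and FusedSDPA is not None:
        # Assumed interface: FusedSDPA.apply(q, k, v, attn_mask, dropout_p, is_causal),
        # mirroring torch.nn.functional.scaled_dot_product_attention. The fused kernel
        # computes softmax(QK^T / sqrt(d)) @ V without materializing the full
        # attention matrix.
        return FusedSDPA.apply(query, key, value, attn_mask, 0.0, False)
    # Default path, used when the flag is off (matches the PR's default behavior).
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
```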

@wszczurekhabana (Contributor, Author) commented:

Hi @regisss, could you take a look at this PR and trigger CI for it?

@puneeshkhanna (Contributor) commented Dec 6, 2023

@wszczurekhabana - Changes look good to me. I just left a minor comment about the help text and about passing the parameter in create_custom_forward() in modeling_llama.py (see the sketch below).
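
For readers unfamiliar with that pattern, here is a hypothetical sketch of how an extra flag such as use_flash_attention could be threaded through the create_custom_forward() closure used for gradient checkpointing in modeling_llama.py; the argument names and layer signature are assumptions, not this PR's actual code.

```python
from torch.utils.checkpoint import checkpoint


def run_layer_with_checkpointing(decoder_layer, hidden_states, attention_mask,
                                 use_flash_attention=False):
    # Hypothetical sketch: the decoder layer is assumed to accept a
    # `use_flash_attention` keyword argument.
    def create_custom_forward(module):
        def custom_forward(*inputs):
            # Non-tensor options are captured from the enclosing scope, since
            # torch.utils.checkpoint.checkpoint only forwards positional inputs.
            return module(*inputs, use_flash_attention=use_flash_attention)

        return custom_forward

    return checkpoint(create_custom_forward(decoder_layer),
                      hidden_states, attention_mask)
```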

@mandy-li added the run-test (Run CI for PRs from external contributors) label on Dec 6, 2023
@regisss (Collaborator) commented Dec 6, 2023

Is this the same as #583?
cc @mandy-li

@mandy-li (Collaborator) commented Dec 6, 2023

> Is this the same as #583? cc @mandy-li

@regisss, this PR focuses on inference, but there is one overlapping file. This PR can go first; once it is merged, I will update my PR to contain only the FT-related code changes. Thanks.

@mandy-li (Collaborator) commented Dec 7, 2023

LGTM

@regisss added and removed the run-test (Run CI for PRs from external contributors) label on Dec 8, 2023
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss (Collaborator) left a review:

LGTM!

The code style check failed. Could you run the following please?

pip install --upgrade ruff
make style

@regisss (Collaborator) commented Dec 11, 2023

@wszczurekhabana There is a merge conflict to resolve since #589 was merged.

@regisss added and removed the run-test (Run CI for PRs from external contributors) label on Dec 12, 2023
@regisss (Collaborator) left a review:

LGTM
