Merged
51 commits
77fe3a4  disable navi (micmelesse, Jan 3, 2025)
92fb040  start test (micmelesse, Jan 3, 2025)
957b0e6  test fp16 against fp8 (micmelesse, Jan 3, 2025)
a290a6d  save scaling code so far (micmelesse, Jan 7, 2025)
3542300  global scaling (micmelesse, Jan 7, 2025)
0121712  add per_head_scaling (micmelesse, Jan 7, 2025)
1cef817  dump qk (micmelesse, Jan 7, 2025)
390e990  save dumping q, k and qk to fp32 tensor (micmelesse, Jan 8, 2025)
e834dd2  fix pointer bug (micmelesse, Jan 9, 2025)
2a4899a  save reproducer (micmelesse, Jan 9, 2025)
65ad5f2  dump p and acc (micmelesse, Jan 9, 2025)
d080d33  fp8 working with my debug input (micmelesse, Jan 10, 2025)
daa7532  save (micmelesse, Jan 10, 2025)
b4c3dc3  change api for dequant (micmelesse, Jan 11, 2025)
dd5002d  pass descale_p (micmelesse, Jan 12, 2025)
4dfe187  clean up (micmelesse, Jan 12, 2025)
7afec5c  most working (micmelesse, Jan 12, 2025)
ba01300  save (micmelesse, Jan 13, 2025)
d1c3e46  save (micmelesse, Jan 13, 2025)
1fd2219  varlen half way (micmelesse, Jan 13, 2025)
277fac6  some varlen examples work (micmelesse, Jan 14, 2025)
554bee9  improve varlen debug input (micmelesse, Jan 15, 2025)
93fab91  varlen mostly working (micmelesse, Jan 15, 2025)
06982bf  push working cases (micmelesse, Jan 16, 2025)
4c110bd  fix ref bug (micmelesse, Jan 16, 2025)
ce51aac  fix backward bug (micmelesse, Jan 16, 2025)
6e2dcbf  fix varlen backward bug (micmelesse, Jan 16, 2025)
db4a331  use descale to set fp8 (micmelesse, Jan 16, 2025)
9a5b607  check arch fp8 support (micmelesse, Jan 17, 2025)
a811071  cache arch (micmelesse, Jan 17, 2025)
5037533  try again (micmelesse, Jan 17, 2025)
fb5c01e  skip bad config on MI200 (micmelesse, Jan 17, 2025)
f38f6df  skip decode nan config on MI200 (micmelesse, Jan 17, 2025)
3058fef  fix mistake (micmelesse, Jan 17, 2025)
35ac3ef  skip more (micmelesse, Jan 17, 2025)
323d8dc  run full suit (micmelesse, Jan 17, 2025)
8a6fa25  Update amd_tests.yml (micmelesse, Jan 17, 2025)
d51ee78  address comments (micmelesse, Jan 20, 2025)
fd49369  navi ci is broken (micmelesse, Jan 20, 2025)
e728cab  raise error tolerance to 2.5e-1 (micmelesse, Jan 20, 2025)
d1b6fd9  target MI300 directly (micmelesse, Jan 21, 2025)
3d9e0dd  show gfx (micmelesse, Jan 22, 2025)
8cb52c7  try again (micmelesse, Jan 22, 2025)
e50bc0c  don't fail matrix if one path fails (micmelesse, Jan 22, 2025)
15cb129  try upstream triton (micmelesse, Jan 22, 2025)
4dd92bd  just get MI300 working (micmelesse, Jan 22, 2025)
ae8d4cf  Fix install bug (micmelesse, Jan 23, 2025)
b715392  run ref on cpu (micmelesse, Jan 23, 2025)
0c54226  move ref test to navi machines (micmelesse, Jan 23, 2025)
ef7e107  pin triton (micmelesse, Jan 23, 2025)
924cad4  add bench deps (micmelesse, Jan 24, 2025)
73 changes: 39 additions & 34 deletions .github/workflows/amd_tests.yml
@@ -8,7 +8,7 @@ on:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf, micmelesse/upstream_pr]
branches: [main_perf]

concurrency:
group: ${{ github.ref }}
@@ -17,70 +17,75 @@ concurrency:
permissions: read-all

jobs:
Runner-Preparation-AMD:
runs-on: ubuntu-latest
timeout-minutes: 30
outputs:
matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi

Integration-Tests-AMD:
needs: Runner-Preparation-AMD
if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
runner: [linux-mi300-gpu-1]
fail-fast: false # disables failing the entire job when one matrix entry fails
container:
image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
image: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Triton
- name: Show Device Info
run: |
rocminfo | grep gfx
- name: Uninstall Triton
run : |
pip uninstall -y triton
pip install matplotlib pandas pytest
rm -rf ~/.triton
rm -rf ./triton/python/build
- name: Install Triton
run: |
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install --verbose -e python
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install ninja cmake wheel pybind11 # build-time dependencies
pip install matplotlib pandas pytest # triton bench dependencies
pip install --verbose --no-build-isolation ./python
cd ..
- name: Show Triton version
run: |
pip show triton
- name: Build
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
- name: Flash Attention Tests Using Reference Impl
if: matrix.runner[1] == 'gfx90a'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
pytest tests/test_flash_attn_triton_amd.py
- name: Flash Attention Tests

# CDNA Tests
- name: Flash Attention CDNA Tests
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py
- name: AMD Tests
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Bench
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
- name: AMD Bench with Autotune
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
python flash_attn/flash_attn_triton_amd/bench.py

# RDNA Tests
- name: Flash Attention Tests Using Reference Impl
if: matrix.runner == 'gfx1100'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
pytest tests/test_flash_attn_triton_amd.py
- name: Flash Attention RDNA Tests
if: matrix.runner == 'gfx1100'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache
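
The workflow above drives the new backend entirely through environment variables (FLASH_ATTENTION_TRITON_AMD_ENABLE, FLASH_ATTENTION_TRITON_AMD_REF, FLASH_ATTENTION_TRITON_AMD_AUTOTUNE). As a reference point, here is a minimal Python-side sketch of the same toggles, assuming (as the CI's export-before-pytest ordering suggests) that they are read when flash_attn is imported:

    import os

    # Select the Triton AMD backend; the CI exports this before every pytest/bench step.
    os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
    # Optional modes exercised by the jobs above (commented out by default):
    # os.environ["FLASH_ATTENTION_TRITON_AMD_REF"] = "1"       # reference implementation
    # os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "1"  # autotune during bench.py

    import flash_attn  # imported after the variables are set
    print(flash_attn.__version__)

Keeping the switch in environment variables lets the same test file run against the reference implementation on one runner class and the Triton kernels on another, as the CDNA and RDNA steps above do.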
54 changes: 51 additions & 3 deletions flash_attn/flash_attn_interface.py
@@ -90,7 +90,11 @@ def _flash_attn_forward(
window_size_right: int,
softcap: float,
alibi_slopes: Optional[torch.Tensor],
return_softmax: bool
return_softmax: bool,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
@@ -107,6 +111,10 @@ softcap,
softcap,
return_softmax,
None,
descale_q,
descale_k,
descale_v,
descale_p
)
return out, softmax_lse, S_dmask, rng_state

@@ -164,6 +172,10 @@ def _flash_attn_varlen_forward(
block_table: Optional[torch.Tensor] = None,
leftpad_k: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
@@ -188,6 +200,10 @@ softcap,
softcap,
return_softmax,
None,
descale_q,
descale_k,
descale_v,
descale_p
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -804,6 +820,10 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
descale_q,
descale_k,
descale_v,
descale_p
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -824,6 +844,10 @@ softcap=softcap,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
@@ -867,7 +891,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
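
The longer runs of None returned from backward in this diff (15 values for FlashAttnFunc, 20 for FlashAttnVarlenFunc) follow from the torch.autograd.Function contract: backward must return one gradient, or None, per argument of forward, and the four new descale arguments are non-differentiable. A minimal standalone illustration of that rule, not the flash-attention kernels themselves:

    import torch

    class ScaledMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, b, descale):  # three forward inputs -> backward returns three values
            ctx.save_for_backward(a, b)
            ctx.descale = descale
            return (a @ b) * descale

        @staticmethod
        def backward(ctx, grad_out):
            a, b = ctx.saved_tensors
            grad_a = grad_out @ b.t() * ctx.descale
            grad_b = a.t() @ grad_out * ctx.descale
            return grad_a, grad_b, None  # None for the non-differentiable descale

    a = torch.randn(4, 8, requires_grad=True)
    b = torch.randn(8, 3, requires_grad=True)
    ScaledMatmul.apply(a, b, 0.5).sum().backward()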
@@ -890,6 +914,10 @@ def forward(
deterministic,
return_softmax,
block_table,
descale_q,
descale_k,
descale_v,
descale_p
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -915,6 +943,10 @@ alibi_slopes=alibi_slopes,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -966,7 +998,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
@@ -1116,6 +1148,10 @@ def flash_attn_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1177,6 +1213,10 @@
alibi_slopes,
deterministic,
return_attn_probs,
descale_q,
descale_k,
descale_v,
descale_p
)


@@ -1353,6 +1393,10 @@ def flash_attn_varlen_func(
deterministic=False,
return_attn_probs=False,
block_table=None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1426,6 +1470,10 @@
deterministic,
return_attn_probs,
block_table,
descale_q,
descale_k,
descale_v,
descale_p
)
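
End to end, the new keyword arguments let callers hand pre-quantized fp8 q/k/v to flash_attn_func together with the factors needed to undo the quantization inside the kernel. The sketch below is a hedged usage example, not code from this PR: the fp8 dtype (torch.float8_e4m3fnuz), per-tensor scaling, and the assumption that each descale factor is a 1-element float32 tensor are illustrative choices the diff itself does not pin down.

    # Assumes FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" was exported before launching
    # Python (see the workflow above) and that the hardware supports fp8.
    import torch
    from flash_attn import flash_attn_func

    def per_tensor_fp8(x, fp8_dtype=torch.float8_e4m3fnuz):
        # Map the tensor's max magnitude onto the fp8 range; keep the inverse
        # ("descale") so the kernel can recover the original magnitudes.
        fp8_max = torch.finfo(fp8_dtype).max
        amax = x.abs().amax().float().clamp(min=1e-12)
        x_fp8 = (x.float() * (fp8_max / amax)).to(fp8_dtype)
        return x_fp8, (amax / fp8_max).reshape(1)

    # (batch, seqlen, nheads, headdim); fp16 inputs quantized before the call
    q = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
    q_fp8, descale_q = per_tensor_fp8(q)
    k_fp8, descale_k = per_tensor_fp8(k)
    v_fp8, descale_v = per_tensor_fp8(v)

    out = flash_attn_func(
        q_fp8, k_fp8, v_fp8,
        causal=True,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
    )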

