diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml
index 1a9ed80ec..608915133 100644
--- a/.github/workflows/amd_tests.yml
+++ b/.github/workflows/amd_tests.yml
@@ -21,7 +21,7 @@ jobs:
     runs-on: ${{ matrix.runner }}
     strategy:
       matrix:
-        runner: [ubuntu-22.04, linux-mi300-gpu-1]
+        runner: [linux-mi300-gpu-1]
       fail-fast: false # disables failing the entire job when one matrix entry fails
     container:
       image: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
@@ -52,14 +52,6 @@ jobs:
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           python setup.py install

-      # CPU Tests
-      - name: Flash Attention Tests Using Reference Impl
-        if: matrix.runner == 'ubuntu-22.04'
-        run: |
-          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
-          export FLASH_ATTENTION_TRITON_AMD_REF=1
-          pytest tests/test_flash_attn_triton_amd.py
-
       # CDNA Tests
       - name: Flash Attention CDNA Tests
         if: matrix.runner == 'linux-mi300-gpu-1'
@@ -84,6 +76,12 @@ jobs:
           python flash_attn/flash_attn_triton_amd/bench.py

       # RDNA Tests
+      - name: Flash Attention Tests Using Reference Impl
+        if: matrix.runner == 'gfx1100'
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          export FLASH_ATTENTION_TRITON_AMD_REF=1
+          pytest tests/test_flash_attn_triton_amd.py
       - name: Flash Attention RDNA Tests
         if: matrix.runner == 'gfx1100'
         run: |