[release/2.6] Integrate CK prebuilt library into PyTorch build #2031
Conversation
…eate link target. Enable USE_CK_FLASH_ATTENTION based on USE_FLASH_ATTENTION option.
… using lintrunner.
… and code cleanup.
Jenkins build for 128a89689e52ed9d963503d3369fc78abd8575ea commit finished as FAILURE
cmake/External/ck.cmake (outdated):

  add_dependencies(__ck_lib ck_kernels_external)
  message(STATUS "Using CK_kernels from pre-compiled binary ${ck_kernels_package_full_url}; installed at ${ck_kernels_install_dir}")
else()
  message(FATAL_ERROR "Unable to find an existing CK_kernels installation or to install CK_kernels library")
@akashveramd: @pruthvistony suggested we should replace this with steps to build CK from source, so that anyone trying to build this PyTorch branch from source outside a CI environment would also be able to set USE_CK_FLASH_ATTENTION and build with the CK backend successfully.
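Until a source-build fallback is added, the prebuilt-library selection order implied by the test matrix below can be sketched in shell. This is a hypothetical rendering of the cmake logic, not the actual ck.cmake code: only the CK_KERNELS_INSTALLED_PREFIX and CK_KERNELS_PACKAGE_BASE_URL variable names come from this PR; the function name, output format, and the "default" fallback are invented for illustration.

```shell
#!/bin/sh
# Sketch of the CK_kernels resolution order (assumption, not the real cmake):
#   1. an explicitly preinstalled tree wins,
#   2. else an explicit package base URL is used for download,
#   3. else fall back to a default package source (or FATAL_ERROR, per above).
resolve_ck_kernels() {
    if [ -n "${CK_KERNELS_INSTALLED_PREFIX}" ]; then
        echo "preinstalled:${CK_KERNELS_INSTALLED_PREFIX}"
    elif [ -n "${CK_KERNELS_PACKAGE_BASE_URL}" ]; then
        echo "download:${CK_KERNELS_PACKAGE_BASE_URL}"
    else
        echo "download:default"
    fi
}

# Demonstrate the three configurations in subshells (paths/URLs are examples):
(CK_KERNELS_INSTALLED_PREFIX=/opt/ck-kernels; resolve_ck_kernels)          # preinstalled:/opt/ck-kernels
(CK_KERNELS_PACKAGE_BASE_URL=https://example.invalid/prebuilt; resolve_ck_kernels)
(unset CK_KERNELS_INSTALLED_PREFIX CK_KERNELS_PACKAGE_BASE_URL; resolve_ck_kernels)  # download:default
```

The subshells keep each demonstrated configuration from leaking into the next, mirroring how each build case below is run in a fresh environment.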
Implemented this feature and tested it for the following 4 cases:

1) Don't set USE_CK_FLASH_ATTENTION when building PyTorch.
   Build PyTorch using:
   PYTORCH_ROCM_ARCH=gfx90a python setup.py develop |& tee build_1.log
   1.a) Run test (using aotriton backend)
   1.b) Run test (using CK backend)

2) Set USE_CK_FLASH_ATTENTION, but don't set the CK_KERNELS_INSTALLED_PREFIX or CK_KERNELS_PACKAGE_BASE_URL env vars.
   Build PyTorch using:
   USE_CK_FLASH_ATTENTION=1 PYTORCH_ROCM_ARCH=gfx90a python setup.py develop |& tee build_2.log
   2.a) Run test (using aotriton backend)
   2.b) Run test (using CK backend)

3) Set USE_CK_FLASH_ATTENTION, and set only the CK_KERNELS_PACKAGE_BASE_URL env var, to https://compute-artifactory.amd.com/artifactory/rocm-generic-local.
   Build PyTorch using:
   export CK_KERNELS_PACKAGE_BASE_URL=https://compute-artifactory.amd.com/artifactory/rocm-generic-local
   USE_CK_FLASH_ATTENTION=1 PYTORCH_ROCM_ARCH=gfx90a python setup.py develop |& tee build_3.log
   3.a) Run test (using aotriton backend)
   3.b) Run test (using CK backend)

4) Set USE_CK_FLASH_ATTENTION, and set only the CK_KERNELS_INSTALLED_PREFIX env var; CK_kernels should be picked up from the preinstalled location.
   Build PyTorch using:
   export CK_KERNELS_INSTALLED_PREFIX=
   USE_CK_FLASH_ATTENTION=1 PYTORCH_ROCM_ARCH=gfx90a python setup.py develop |& tee build_4.log
   4.a) Run test (using aotriton backend)
   4.b) Run test (using CK backend)

The tests that were used are:
- PYTORCH_TEST_WITH_ROCM=1 CI=1 python test_transformers.py -k test_mem_efficient_attention_vs_math_ref_grads_batch_size_1_seq_len_q_103_seq_len_k_103_head_dim_128_is_causal_True_dropout_p_0_22_float32_scale0_cuda_float32 2>&1
  (should fail with the CK backend, pass with the AOTriton backend)
- PYTORCH_TEST_WITH_ROCM=1 CI=1 python test/test_transformers.py -k test_flash_attention_vs_math_ref_grads_batch_size_1_seq_len_q_143_seq_len_k_127_head_dim_8_is_causal_False_dropout_p_0_0_bfloat16_scale0_enable_gqa_False_n_heads0_cuda_bfloat16 2>&1
  (should pass with both the CK and AOTriton backends)
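The four build configurations above can be driven from one script; the sketch below only prints each command line (a dry run) instead of invoking setup.py. The commands mirror the comment above; `<preinstalled-path>` is a placeholder because the original leaves the CK_KERNELS_INSTALLED_PREFIX value elided, and the `build_cmd` helper name is invented.

```shell
#!/bin/sh
# Dry-run driver: print the build command for each of the 4 test cases
# above without actually building PyTorch.
build_cmd() {
    case "$1" in
        1) prefix="" ;;  # case 1: USE_CK_FLASH_ATTENTION not set
        2) prefix="USE_CK_FLASH_ATTENTION=1 " ;;
        3) prefix="CK_KERNELS_PACKAGE_BASE_URL=https://compute-artifactory.amd.com/artifactory/rocm-generic-local USE_CK_FLASH_ATTENTION=1 " ;;
        *) prefix="CK_KERNELS_INSTALLED_PREFIX=<preinstalled-path> USE_CK_FLASH_ATTENTION=1 " ;;
    esac
    # |& is part of the printed bash command line, not executed here.
    printf '%s\n' "${prefix}PYTORCH_ROCM_ARCH=gfx90a python setup.py develop |& tee build_$1.log"
}

for n in 1 2 3 4; do build_cmd "$n"; done
```

Each printed line matches the corresponding "Build PyTorch using" command in the test matrix, so the script doubles as a checklist when re-running the validation by hand.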
Jenkins build for 82fb8f3ed28189eb30513c26d55a0430538c1f9e commit finished as FAILURE
This prints a warning message to stderr.
…rs out due to libck_kernels.so dependency in torch/lib path not being available, as the copying via command happens later
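Since the failure above comes from libck_kernels.so not yet being present in torch/lib when the tests run (the copy step happens later), a quick presence check before launching the CK-backend tests can save a confusing error. This is a sketch: the `check_ck_lib` helper name is invented, the default `torch/lib` location is assumed from the comment above, and only the libck_kernels.so file name comes from this PR.

```shell
#!/bin/sh
# Sanity check (sketch): confirm libck_kernels.so has been copied into
# the torch/lib directory before running the CK-backend tests.
check_ck_lib() {
    dir="${1:-torch/lib}"   # default location assumed from the comment above
    if [ -f "${dir}/libck_kernels.so" ]; then
        echo "found: ${dir}/libck_kernels.so"
    else
        # Likely means the copy step has not run yet.
        echo "missing: ${dir}/libck_kernels.so"
        return 1
    fi
}

check_ck_lib || true   # report, but don't abort the surrounding script
```

Running this right before the 1.b/2.b/3.b/4.b test steps distinguishes "library not copied yet" from a genuine CK-backend test failure.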
Jenkins build for 2a426492b96ccc66c5ec78a3e5570677d12ce613 commit finished as FAILURE
Jenkins build for e00962cf61a9fc0cab2c934acf3e5846110ef278 commit finished as FAILURE
Jenkins build for e00962cf61a9fc0cab2c934acf3e5846110ef278 commit is in progress
Includes changes from #2016
Validation steps: #2031 (comment)