
Commit 71c6049

[Kernel] Build flash-attn from source (vllm-project#8245)
1 parent 0faab90 commit 71c6049

File tree: 9 files changed (+124 −41)

.github/workflows/scripts/build.sh
.gitignore
CMakeLists.txt
Dockerfile
cmake/utils.cmake
requirements-cuda.txt
setup.py
vllm/attention/backends/flash_attn.py
vllm/attention/selector.py

.github/workflows/scripts/build.sh (+1)

```diff
@@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt
 export MAX_JOBS=1
 # Make sure release wheels are built for the following architectures
 export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real"
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist
```
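The release script now pins the vendored flash-attn build to real architectures 80 and 90 while keeping the broader TORCH_CUDA_ARCH_LIST for vLLM's own kernels. For a quick local wheel, the same two variables can be narrowed further; a minimal sketch, assuming a single compute-capability-8.0 machine (the single-arch values are illustrative, not part of this commit):

```bash
# Local wheel build targeting only one GPU generation (assumed: SM 8.0).
export TORCH_CUDA_ARCH_LIST="8.0"          # vLLM's own CUDA kernels
export VLLM_FA_CMAKE_GPU_ARCHES="80-real"  # vendored vllm-flash-attn build
python3 setup.py bdist_wheel --dist-dir=dist
```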

.gitignore (+5)

```diff
@@ -1,6 +1,9 @@
 # vllm commit id, generated by setup.py
 vllm/commit_id.py
 
+# vllm-flash-attn built from source
+vllm/vllm_flash_attn/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -12,6 +15,8 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+cmake-build-*/
+CMakeUserPresets.json
 develop-eggs/
 dist/
 downloads/
```

CMakeLists.txt (+73 −25)

```diff
@@ -1,5 +1,16 @@
 cmake_minimum_required(VERSION 3.26)
 
+# When building directly using CMake, make sure you run the install step
+# (it places the .so files in the correct location).
+#
+# Example:
+# mkdir build && cd build
+# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
+# cmake --build . --target install
+#
+# If you want to only build one target, make sure to install it manually:
+# cmake --build . --target _C
+# cmake --install . --component _C
 project(vllm_extensions LANGUAGES CXX)
 
 # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
@@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 # Suppress potential warnings about unused manually-specified variables
 set(ignoreMe "${VLLM_PYTHON_PATH}")
 
+# Prevent installation of dependencies (cutlass) by default.
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
+
 #
 # Supported python versions. These versions will be searched in order, the
 # first match will be selected. These should be kept in sync with setup.py.
@@ -70,19 +84,6 @@ endif()
 find_package(Torch REQUIRED)
 
 #
-# Add the `default` target which detects which extensions should be
-# built based on platform/architecture. This is the same logic that
-# setup.py uses to select which extensions should be built and should
-# be kept in sync.
-#
-# The `default` target makes direct use of cmake easier since knowledge
-# of which extensions are supported has been factored in, e.g.
-#
-# mkdir build && cd build
-# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
-# cmake --build . --target default
-#
-add_custom_target(default)
 message(STATUS "Enabling core extension.")
 
 # Define _core_C extension
@@ -100,8 +101,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-add_dependencies(default _core_C)
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()
 
+include(FetchContent)
+
 #
 # Define other extension targets
 #
@@ -190,7 +191,6 @@ set(VLLM_EXT_SRC
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
-  include(FetchContent)
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
   FetchContent_Declare(
         cutlass
@@ -283,6 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     csrc/quantization/machete/machete_pytorch.cu)
 endif()
 
+message(STATUS "Enabling C extension.")
 define_gpu_extension_target(
   _C
   DESTINATION vllm
@@ -313,6 +314,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/moe/marlin_moe_ops.cu")
 endif()
 
+message(STATUS "Enabling moe extension.")
 define_gpu_extension_target(
   _moe_C
   DESTINATION vllm
@@ -323,7 +325,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-
 if(VLLM_GPU_LANG STREQUAL "HIP")
   #
   # _rocm_C extension
@@ -343,16 +344,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
     WITH_SOABI)
 endif()
 
+# vllm-flash-attn currently only supported on CUDA
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
+  return()
+endif ()
 
-if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
-  message(STATUS "Enabling C extension.")
-  add_dependencies(default _C)
+#
+# Build vLLM flash attention from source
+#
+# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
+# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
+# They should be identical but if they aren't, this is a massive footgun.
+#
+# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
+# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
+# If no component is specified, vllm-flash-attn is still installed.
 
-  message(STATUS "Enabling moe extension.")
-  add_dependencies(default _moe_C)
+# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
+# This is to enable local development of vllm-flash-attn within vLLM.
+# It can be set as an environment variable or passed as a cmake argument.
+# The environment variable takes precedence.
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
+  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
 endif()
 
-if(VLLM_GPU_LANG STREQUAL "HIP")
-  message(STATUS "Enabling rocm extension.")
-  add_dependencies(default _rocm_C)
+if(VLLM_FLASH_ATTN_SRC_DIR)
+  FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
+else()
+  FetchContent_Declare(
+        vllm-flash-attn
+        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
+        GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
+        GIT_PROGRESS TRUE
+  )
 endif()
+
+# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
+set(VLLM_PARENT_BUILD ON)
+
+# Make sure vllm-flash-attn install rules are nested under vllm/
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
+install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
+install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
+
+# Fetch the vllm-flash-attn library
+FetchContent_MakeAvailable(vllm-flash-attn)
+message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
+
+# Restore the install prefix
+install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
+
+# Copy over the vllm-flash-attn python files
+install(
+        DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+        DESTINATION vllm/vllm_flash_attn
+        COMPONENT vllm_flash_attn_c
+        FILES_MATCHING PATTERN "*.py"
+)
+
+# Nothing after vllm-flash-attn, see comment about macros above
```
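The VLLM_FLASH_ATTN_SRC_DIR override and the vllm_flash_attn_c install component together allow iterating on a local flash-attention checkout instead of the pinned Git tag. A minimal sketch of that loop, assuming a sibling checkout and the Ninja generator (the paths are illustrative):

```bash
# Point the vLLM build at a local flash-attention checkout.
export VLLM_FLASH_ATTN_SRC_DIR=$HOME/src/flash-attention

mkdir -p build && cd build
cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=$(which python3) -DCMAKE_INSTALL_PREFIX=.. ..
cmake --build .
# Install only the flash-attn component; it lands under vllm/vllm_flash_attn/.
cmake --install . --component vllm_flash_attn_c
```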

Dockerfile (+3)

```diff
@@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 # see https://github.com/pytorch/pytorch/pull/123243
 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
 ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
+# Override the arch list for flash-attn to reduce the binary size
+ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
+ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
 #################### BASE BUILD IMAGE ####################
 
 #################### WHEEL BUILD IMAGE ####################
```
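Because both architecture lists are exposed as build args, they can be overridden at image build time without editing the Dockerfile. A sketch (the image tag is arbitrary; the --build-arg names come from the lines above):

```bash
# Build an image with a narrower kernel arch list for vLLM and flash-attn alike.
docker build . \
  --build-arg torch_cuda_arch_list='8.0 9.0+PTX' \
  --build-arg vllm_fa_cmake_gpu_arches='80-real;90-real' \
  -t vllm:custom-arches
```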

cmake/utils.cmake (+1 −1)

```diff
@@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
   endif()
 
-  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
+  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
 endfunction()
```
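With the install rule tagged as COMPONENT ${GPU_MOD_NAME}, each extension's component name matches its target name, so a single module can be rebuilt and reinstalled in isolation. For example, from an already configured build directory (a sketch; _moe_C is one of the targets defined in CMakeLists.txt above):

```bash
# Rebuild and reinstall only the MoE extension; other components stay untouched.
cmake --build . --target _moe_C
cmake --install . --component _moe_C
```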

requirements-cuda.txt (−1)

```diff
@@ -8,4 +8,3 @@ torch == 2.4.0
 # These must be updated alongside torch
 torchvision == 0.19  # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch 2.4.0
```

setup.py (+30 −8)

```diff
@@ -6,6 +6,7 @@
 import subprocess
 import sys
 import warnings
+from pathlib import Path
 from shutil import which
 from typing import Dict, List
 
@@ -152,15 +153,8 @@ def configure(self, ext: CMakeExtension) -> None:
         default_cfg = "Debug" if self.debug else "RelWithDebInfo"
         cfg = envs.CMAKE_BUILD_TYPE or default_cfg
 
-        # where .so files will be written, should be the same for all extensions
-        # that use the same CMakeLists.txt.
-        outdir = os.path.abspath(
-            os.path.dirname(self.get_ext_fullpath(ext.name)))
-
         cmake_args = [
             '-DCMAKE_BUILD_TYPE={}'.format(cfg),
-            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
-            '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
             '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
         ]
 
@@ -224,10 +218,12 @@ def build_extensions(self) -> None:
             os.makedirs(self.build_temp)
 
         targets = []
+        target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
+                                              "vllm_flash_attn.")
         # Build all the extensions
         for ext in self.extensions:
             self.configure(ext)
-            targets.append(remove_prefix(ext.name, "vllm."))
+            targets.append(target_name(ext.name))
 
         num_jobs, _ = self.compute_num_jobs()
 
@@ -240,6 +236,28 @@ def build_extensions(self) -> None:
 
         subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
 
+        # Install the libraries
+        for ext in self.extensions:
+            # Install the extension into the proper location
+            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
+
+            # Skip if the install directory is the same as the build directory
+            if outdir == self.build_temp:
+                continue
+
+            # CMake appends the extension prefix to the install path,
+            # and outdir already contains that prefix, so we need to remove it.
+            prefix = outdir
+            for i in range(ext.name.count('.')):
+                prefix = prefix.parent
+
+            # prefix here should actually be the same for all components
+            install_args = [
+                "cmake", "--install", ".", "--prefix", prefix, "--component",
+                target_name(ext.name)
+            ]
+            subprocess.check_call(install_args, cwd=self.build_temp)
+
 
 def _no_device() -> bool:
     return VLLM_TARGET_DEVICE == "empty"
@@ -467,6 +485,10 @@ def _read_requirements(filename: str) -> List[str]:
 if _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
 
+if _is_cuda():
+    ext_modules.append(
+        CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
+
 if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="vllm._C"))
 
```
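The new install loop in build_extensions shells out to CMake once per extension, passing the component name derived from the extension name. Roughly, the commands issued from the temporary build directory look like the following sketch (the prefix path is illustrative; in practice it is derived from where setuptools wants the extension files placed):

```bash
# One invocation per configured extension, for example:
cmake --install . --prefix /path/to/output-tree --component _C
cmake --install . --prefix /path/to/output-tree --component _moe_C
cmake --install . --prefix /path/to/output-tree --component vllm_flash_attn_c
```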
vllm/attention/backends/flash_attn.py (+7 −2)

```diff
@@ -19,8 +19,13 @@
 from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                       ModelInputForGPUWithSamplingMetadata)
 
-from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
-from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
+# yapf: disable
+from vllm.vllm_flash_attn import (
+    flash_attn_varlen_func as _flash_attn_varlen_func)
+from vllm.vllm_flash_attn import (
+    flash_attn_with_kvcache as _flash_attn_with_kvcache)
+
+# yapf: enable
 
 
 @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
```
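After a build, a quick way to confirm that the kernels now resolve under the in-tree package path (purely a sanity check, not part of the commit):

```bash
python3 -c "from vllm.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache; print('vllm.vllm_flash_attn OK')"
```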

vllm/attention/selector.py (+4 −4)

```diff
@@ -244,8 +244,7 @@ def which_attn_to_use(
     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
         try:
-            import vllm_flash_attn  # noqa: F401
-
+            import vllm.vllm_flash_attn  # noqa: F401
             from vllm.attention.backends.flash_attn import (  # noqa: F401
                 FlashAttentionBackend)
 
@@ -258,8 +257,9 @@ which_attn_to_use(
         except ImportError:
             logger.info(
                 "Cannot use FlashAttention-2 backend because the "
-                "vllm_flash_attn package is not found. "
-                "`pip install vllm-flash-attn` for better performance.")
+                "vllm.vllm_flash_attn package is not found. "
+                "Make sure that vllm_flash_attn was built and installed "
+                "(on by default).")
             selected_backend = _Backend.XFORMERS
 
     return selected_backend
```
