-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
99 lines (90 loc) · 2.32 KB
/
setup.py
File metadata and controls
99 lines (90 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Setup script for CUDA LLM Kernel Optimization package.
Version is read from pyproject.toml (single source of truth).
"""
import os
import platform
import re
from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def _read_version() -> str:
"""Read version from pyproject.toml."""
text = Path(__file__).with_name("pyproject.toml").read_text()
match = re.search(r'^version\s*=\s*"([^"]+)"', text, re.MULTILINE)
if not match:
raise RuntimeError("Cannot find version in pyproject.toml")
return match.group(1)
# CUDA architectures to compile for (semicolon-separated SM versions).
# Override via the environment, e.g. CUDA_ARCHS="80;90", to shorten builds.
CUDA_ARCHS = os.environ.get('CUDA_ARCHS', '70;75;80;86;89;90')

# CUDA kernel sources plus the C++ binding layer for the extension.
cuda_sources = [
    'src/naive_attention.cu',
    'src/tiled_attention.cu',
    'src/flash_attention.cu',
    'src/tensor_core_gemm.cu',
    'src/hgemm_kernel.cu',
    'python/bindings.cpp',
]

# Include directories
include_dirs = [
    'include',
]

# Compiler flags (platform-aware): MSVC takes /-style host-compiler flags,
# and -fPIC is only meaningful for the POSIX host compiler.
if platform.system() == 'Windows':
    extra_compile_args = {
        'cxx': ['/O2', '/std:c++17'],
        'nvcc': [
            '-O3',
            '--use_fast_math',
            '-std=c++17',
        ]
    }
else:
    extra_compile_args = {
        'cxx': ['-O3', '-std=c++17'],
        'nvcc': [
            '-O3',
            '--use_fast_math',
            '-std=c++17',
            '-Xcompiler', '-fPIC',
        ]
    }

# Emit one -gencode flag per requested architecture.  Tokens are stripped
# and empty entries skipped so user-supplied values like "70; 75;" cannot
# produce malformed flags such as '-gencode=arch=compute_,code=sm_'.
extra_compile_args['nvcc'].extend(
    f'-gencode=arch=compute_{arch},code=sm_{arch}'
    for arch in (token.strip() for token in CUDA_ARCHS.split(';'))
    if arch
)
# The single compiled extension: all CUDA kernels plus the binding layer
# are linked into one module named `cuda_llm_ops`.
_extension = CUDAExtension(
    name='cuda_llm_ops',
    sources=cuda_sources,
    include_dirs=include_dirs,
    extra_compile_args=extra_compile_args,
)

# Optional dependency groups: `pip install .[test]` / `.[benchmark]`.
_extras = {
    'test': [
        'pytest',
        'hypothesis',
    ],
    'benchmark': [
        'matplotlib',
        'pandas',
    ],
}

setup(
    name='cuda_llm_ops',
    # Version is read from pyproject.toml so it is defined in one place.
    version=_read_version(),
    description='High-performance CUDA kernels for LLM inference',
    author='CUDA LLM Kernel Optimization',
    packages=find_packages(),
    ext_modules=[_extension],
    # BuildExtension wires nvcc into the build and applies the per-compiler
    # flag dictionaries above.
    cmdclass={'build_ext': BuildExtension},
    install_requires=[
        'torch>=2.0.0',
        'numpy',
    ],
    extras_require=_extras,
    python_requires='>=3.8',
)