-
Notifications
You must be signed in to change notification settings - Fork 288
[Feat] Add triton autotuner fallback logic #1148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Feat] Add triton autotuner fallback logic #1148
Conversation
Summary of ChangesHello @flyinglandlord, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! 本次拉取请求引入了一项新功能,旨在增强 Triton 内核配置加载的鲁棒性。它增加了一个命令行参数来启用回退机制,当系统找不到与当前 Triton 版本精确匹配的内核配置时,将尝试从最接近的兼容 Triton 版本加载配置。这有助于避免因 Triton 版本不完全匹配预调优配置而导致的错误,从而提高系统的稳定性和兼容性。 Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
这个PR实现了一个很好的功能,当找不到当前Triton版本的内核配置时,可以回退使用最接近版本的配置。代码逻辑整体上是正确的,并且通过命令行参数控制该行为的开关,设计得很好。我在 autotuner.py 中发现了一些可以优化代码可读性和减少重复的地方,并提供了一个代码建议。请查看具体的审查评论。
| elif get_env_start_args().enable_kernel_config_fallback: | ||
|
|
||
| def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]: | ||
| """ | ||
| Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0. | ||
| Returns None if invalid. | ||
| """ | ||
| match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag) | ||
| if not match: | ||
| return None | ||
| x, y, z = match.groups() | ||
| return (int(x), int(y), int(z) if z is not None else 0) | ||
|
|
||
| def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int: | ||
| """ | ||
| Compute weighted distance: major * 1e6 + minor * 1e3 + patch | ||
| Ensures lexicographic ordering. | ||
| """ | ||
| return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2])) | ||
|
|
||
| current_triton_version = get_triton_version() | ||
| current_parsed = parse_triton_version_tag(current_triton_version) | ||
| if current_parsed is None: | ||
| logger.error("Unable to parse current Triton version. Triton may not be installed properly.") | ||
| possible_dirs = [ | ||
| d | ||
| for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs")) | ||
| if d.startswith("triton_") | ||
| ] | ||
| possible_dirs.sort() | ||
| else: | ||
| config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs") | ||
| possible_dirs = [] | ||
| for d in os.listdir(config_dir): | ||
| if not d.startswith("triton_"): | ||
| continue | ||
| parsed = parse_triton_version_tag(d) | ||
| if parsed is not None: | ||
| dist = version_distance(parsed, current_parsed) | ||
| possible_dirs.append((dist, d, parsed)) | ||
| else: | ||
| logger.debug(f"Skipping invalid version directory: {d}") | ||
| possible_dirs.sort(key=lambda x: x[0]) | ||
| possible_dirs = [d for _, d, _ in possible_dirs] | ||
|
|
||
| loaded = False | ||
| for triton_version in possible_dirs: | ||
| fallback_cache_file = os.path.join( | ||
| Path(__file__).parent, | ||
| "autotune_kernel_configs", | ||
| triton_version, | ||
| get_current_device_name(), | ||
| self.kernel_name, | ||
| KernelConfigs.get_config_file_name(static_key), | ||
| ) | ||
| if os.path.exists(fallback_cache_file): | ||
| try: | ||
| logger.warning( | ||
| f"Fallback loading cached configs for {self.kernel_name} - {static_key} " | ||
| f"from triton version {triton_version} (current: {current_triton_version})" | ||
| ) | ||
| with open(fallback_cache_file, "rb") as f: | ||
| self.cached_configs[static_key] = orjson.loads(f.read()) | ||
| loaded = True | ||
| break | ||
| except Exception as e: | ||
| logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}") | ||
|
|
||
| if not loaded: | ||
| logger.info(f"No fallback config found for {self.kernel_name} - {static_key}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分的回退逻辑写得很好,但有几个地方可以优化以提高代码的可读性和可维护性:
autotune_kernel_configs目录的路径在代码中被多次拼接,可以将其提取到一个变量中以避免重复。- 在寻找最接近版本时,
possible_dirs列表中存储了(dist, d, parsed),但parsed变量在之后没有被使用,这造成了轻微的冗余。
我已经将这些优化合并到了下面的代码建议中,它重构了这部分逻辑,使其更简洁。
elif get_env_start_args().enable_kernel_config_fallback:
autotune_configs_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")
def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
"""
Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
Returns None if invalid.
"""
match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
if not match:
return None
x, y, z = match.groups()
return (int(x), int(y), int(z) if z is not None else 0)
def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
"""
Compute weighted distance: major * 1e6 + minor * 1e3 + patch
Ensures lexicographic ordering.
"""
return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))
current_triton_version = get_triton_version()
current_parsed = parse_triton_version_tag(current_triton_version)
if current_parsed is None:
logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
possible_dirs = [
d
for d in os.listdir(autotune_configs_dir)
if d.startswith("triton_")
]
possible_dirs.sort()
else:
dirs_with_dist = []
for d in os.listdir(autotune_configs_dir):
if not d.startswith("triton_"):
continue
parsed = parse_triton_version_tag(d)
if parsed is not None:
dist = version_distance(parsed, current_parsed)
dirs_with_dist.append((dist, d))
else:
logger.debug(f"Skipping invalid version directory: {d}")
dirs_with_dist.sort(key=lambda x: x[0])
possible_dirs = [d for _, d in dirs_with_dist]
loaded = False
for triton_version in possible_dirs:
fallback_cache_file = os.path.join(
autotune_configs_dir,
triton_version,
get_current_device_name(),
self.kernel_name,
KernelConfigs.get_config_file_name(static_key),
)
if os.path.exists(fallback_cache_file):
try:
logger.warning(
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
f"from triton version {triton_version} (current: {current_triton_version})"
)
with open(fallback_cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
loaded = True
break
except Exception as e:
logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")
if not loaded:
logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")
默认: False (保持原有严格版本匹配模式);True: 启用版本容错机制,当config不存在时加载最接近版本的config