@flyinglandlord
Collaborator

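  1. New launch argument: --enable-kernel-config-fallback (bool)
    Default: False (keeps the existing strict version-matching behavior). True: enables version fallback, so that when no config exists for the current version, the config from the closest version is loaded.
  2. The closest version is found by calling os.listdir on autotune_kernel_configs to get the Triton version directories, extracting the version number from each directory name, and sorting the directories by their difference from the target version.
  3. The sorted directories are scanned in order for the matching config file. The first one found is loaded; if none of them contains it, an error is reported.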
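
As a rough illustration of item 1, a boolean switch of this kind could be registered as follows. This is a minimal sketch using the standard `argparse` module; `build_parser` and the help text are assumptions made for the example, and only the flag name and its default come from the description above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: the project's real argument setup is not shown in this PR.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable-kernel-config-fallback",
        action="store_true",  # flag absent -> False, i.e. strict version matching
        help="Load the closest-version kernel config when no exact match exists.",
    )
    return parser


# Without the flag, strict matching stays in effect.
strict_args = build_parser().parse_args([])
# With the flag, the fallback path is enabled.
fallback_args = build_parser().parse_args(["--enable-kernel-config-fallback"])
```

`argparse` maps the dashes in the flag name to underscores, which is consistent with the attribute name `enable_kernel_config_fallback` read in the diff.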

@gemini-code-assist
Contributor

Summary of Changes

Hello @flyinglandlord, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

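This pull request introduces a feature that makes Triton kernel config loading more robust. It adds a command-line argument that enables a fallback mechanism: when no kernel config exactly matches the current Triton version, the system tries to load the config from the closest compatible Triton version. This helps avoid errors caused by a Triton version that does not exactly match the pre-tuned configs, improving stability and compatibility.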

Highlights

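  • New launch argument: Adds the launch argument --enable-kernel-config-fallback, which lets users choose whether to enable the fallback mechanism for Triton kernel configs.
  • Fallback logic: Implements fallback loading for Triton kernel configs. When no config exactly matches the current Triton version, the system automatically looks for and loads the config from the closest version.
  • Version matching: The fallback parses the Triton version tags under the autotune_kernel_configs directory, computes each tag's "distance" from the current Triton version, and sorts by that distance so that the closest candidate is tried first.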
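
The ranking described in the last highlight can be sketched in isolation as follows. The parsing regex and the distance formula mirror the ones in the diff; `rank_tags` and the sample tag list are hypothetical names and data introduced only for this example.

```python
import re
from typing import List, Optional, Tuple

Version = Tuple[int, int, int]


def parse_tag(tag: str) -> Optional[Version]:
    # Accepts "triton_X.Y.Z" or "triton_X.Y"; a missing patch number becomes 0.
    m = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
    if m is None:
        return None
    x, y, z = m.groups()
    return (int(x), int(y), int(z) if z is not None else 0)


def distance(a: Version, b: Version) -> int:
    # Weighted difference: major * 1e6 + minor * 1e3 + patch, taken as an absolute value.
    return abs((a[0] - b[0]) * 1_000_000 + (a[1] - b[1]) * 1_000 + (a[2] - b[2]))


def rank_tags(tags: List[str], current: str) -> List[str]:
    # Hypothetical helper: order valid tags by distance to the current version.
    cur = parse_tag(current)
    if cur is None:
        return sorted(t for t in tags if t.startswith("triton_"))
    scored = []
    for tag in tags:
        parsed = parse_tag(tag)
        if parsed is not None:  # unparsable names are skipped
            scored.append((distance(parsed, cur), tag))
    scored.sort(key=lambda item: item[0])  # stable: ties keep their input order
    return [tag for _, tag in scored]


sample = ["triton_3.2.0", "triton_3.4.0", "triton_3.3.1", "not_a_version", "triton_2.1"]
ranked = rank_tags(sample, "triton_3.3.0")
# ranked == ["triton_3.3.1", "triton_3.2.0", "triton_3.4.0", "triton_2.1"]
```

Here triton_3.2.0 and triton_3.4.0 are both at distance 1000, so their relative order is simply the input order. Because `os.listdir` returns entries in arbitrary order, equally distant versions may be tried in a different order from one filesystem to another.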

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

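This PR implements a useful feature: when no kernel config is found for the current Triton version, it falls back to the config of the closest version. The code logic is correct overall, and gating the behavior behind a command-line argument is a good design. I found a few places in autotuner.py where readability could be improved and duplication reduced, and I have provided a code suggestion. Please see the specific review comments.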

Comment on lines +222 to +291
        elif get_env_start_args().enable_kernel_config_fallback:

            def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
                """
                Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
                Returns None if invalid.
                """
                match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
                if not match:
                    return None
                x, y, z = match.groups()
                return (int(x), int(y), int(z) if z is not None else 0)

            def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
                """
                Compute weighted distance: major * 1e6 + minor * 1e3 + patch
                Ensures lexicographic ordering.
                """
                return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))

            current_triton_version = get_triton_version()
            current_parsed = parse_triton_version_tag(current_triton_version)
            if current_parsed is None:
                logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
                possible_dirs = [
                    d
                    for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
                    if d.startswith("triton_")
                ]
                possible_dirs.sort()
            else:
                config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")
                possible_dirs = []
                for d in os.listdir(config_dir):
                    if not d.startswith("triton_"):
                        continue
                    parsed = parse_triton_version_tag(d)
                    if parsed is not None:
                        dist = version_distance(parsed, current_parsed)
                        possible_dirs.append((dist, d, parsed))
                    else:
                        logger.debug(f"Skipping invalid version directory: {d}")
                possible_dirs.sort(key=lambda x: x[0])
                possible_dirs = [d for _, d, _ in possible_dirs]

            loaded = False
            for triton_version in possible_dirs:
                fallback_cache_file = os.path.join(
                    Path(__file__).parent,
                    "autotune_kernel_configs",
                    triton_version,
                    get_current_device_name(),
                    self.kernel_name,
                    KernelConfigs.get_config_file_name(static_key),
                )
                if os.path.exists(fallback_cache_file):
                    try:
                        logger.warning(
                            f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
                            f"from triton version {triton_version} (current: {current_triton_version})"
                        )
                        with open(fallback_cache_file, "rb") as f:
                            self.cached_configs[static_key] = orjson.loads(f.read())
                        loaded = True
                        break
                    except Exception as e:
                        logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")

            if not loaded:
                logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")

medium

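The fallback logic here is well written, but a few things could be improved for readability and maintainability:

  1. The path to the autotune_kernel_configs directory is joined several times in the code; it can be extracted into a variable to avoid the repetition.
  2. When looking for the closest version, the possible_dirs list stores (dist, d, parsed), but parsed is never used afterwards, which is slightly redundant.

I have folded these improvements into the code suggestion below, which refactors this logic to make it more concise.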

        elif get_env_start_args().enable_kernel_config_fallback:
            autotune_configs_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")

            def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
                """
                Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
                Returns None if invalid.
                """
                match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
                if not match:
                    return None
                x, y, z = match.groups()
                return (int(x), int(y), int(z) if z is not None else 0)

            def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
                """
                Compute weighted distance: major * 1e6 + minor * 1e3 + patch
                Ensures lexicographic ordering.
                """
                return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))

            current_triton_version = get_triton_version()
            current_parsed = parse_triton_version_tag(current_triton_version)
            if current_parsed is None:
                logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
                possible_dirs = [
                    d
                    for d in os.listdir(autotune_configs_dir)
                    if d.startswith("triton_")
                ]
                possible_dirs.sort()
            else:
                dirs_with_dist = []
                for d in os.listdir(autotune_configs_dir):
                    if not d.startswith("triton_"):
                        continue
                    parsed = parse_triton_version_tag(d)
                    if parsed is not None:
                        dist = version_distance(parsed, current_parsed)
                        dirs_with_dist.append((dist, d))
                    else:
                        logger.debug(f"Skipping invalid version directory: {d}")
                dirs_with_dist.sort(key=lambda x: x[0])
                possible_dirs = [d for _, d in dirs_with_dist]

            loaded = False
            for triton_version in possible_dirs:
                fallback_cache_file = os.path.join(
                    autotune_configs_dir,
                    triton_version,
                    get_current_device_name(),
                    self.kernel_name,
                    KernelConfigs.get_config_file_name(static_key),
                )
                if os.path.exists(fallback_cache_file):
                    try:
                        logger.warning(
                            f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
                            f"from triton version {triton_version} (current: {current_triton_version})"
                        )
                        with open(fallback_cache_file, "rb") as f:
                            self.cached_configs[static_key] = orjson.loads(f.read())
                        loaded = True
                        break
                    except Exception as e:
                        logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")

            if not loaded:
                logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")
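
The scan-and-load loop at the end of the suggestion can also be exercised on its own. The following is a minimal sketch under stated assumptions, not the project's implementation: `load_first_available` and the demo directory, device, kernel, and file names are hypothetical, and the standard-library `json` module stands in for `orjson` so the example has no third-party dependency.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, List, Optional


def load_first_available(
    root: Path, ranked_versions: List[str], device: str, kernel: str, file_name: str
) -> Optional[Any]:
    # Hypothetical helper: walk the versions in ranked order, return the first readable config.
    for version in ranked_versions:
        candidate = root / version / device / kernel / file_name
        if not candidate.exists():
            continue
        try:
            return json.loads(candidate.read_bytes())
        except (OSError, ValueError):
            # Unreadable or malformed file: move on to the next closest version.
            continue
    return None


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Only triton_3.2.0 has a config; the closer triton_3.3.1 has none.
    target = root / "triton_3.2.0" / "demo_gpu" / "demo_kernel"
    target.mkdir(parents=True)
    (target / "demo.json").write_text(json.dumps({"BLOCK_M": 64}))

    ranked = ["triton_3.3.1", "triton_3.2.0"]
    found = load_first_available(root, ranked, "demo_gpu", "demo_kernel", "demo.json")
    missing = load_first_available(root, ["triton_3.3.1"], "demo_gpu", "demo_kernel", "demo.json")
```

In this run `found` holds the config from triton_3.2.0 because the closer version has no file, and `missing` is `None`, which corresponds to the "No fallback config found" branch in the suggestion.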
