Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import triton
import orjson
import os
Expand All @@ -11,8 +12,8 @@
from frozendict import frozendict
from lightllm.utils.device_utils import get_current_device_name
from lightllm.utils.log_utils import init_logger
from typing import Callable, Optional, Union, List
from lightllm.utils.envs_utils import get_triton_autotune_level
from typing import Callable, Optional, Tuple, Union, List
from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level
from lightllm.common.kernel_config import KernelConfigs
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node

Expand Down Expand Up @@ -218,6 +219,76 @@ def _try_load_cache(self, static_key):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
elif get_env_start_args().enable_kernel_config_fallback:

def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
"""
Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
Returns None if invalid.
"""
match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
if not match:
return None
x, y, z = match.groups()
return (int(x), int(y), int(z) if z is not None else 0)

def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
"""
Compute weighted distance: major * 1e6 + minor * 1e3 + patch
Ensures lexicographic ordering.
"""
return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))

current_triton_version = get_triton_version()
current_parsed = parse_triton_version_tag(current_triton_version)
if current_parsed is None:
logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
possible_dirs = [
d
for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
if d.startswith("triton_")
]
possible_dirs.sort()
else:
config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")
possible_dirs = []
for d in os.listdir(config_dir):
if not d.startswith("triton_"):
continue
parsed = parse_triton_version_tag(d)
if parsed is not None:
dist = version_distance(parsed, current_parsed)
possible_dirs.append((dist, d, parsed))
else:
logger.debug(f"Skipping invalid version directory: {d}")
possible_dirs.sort(key=lambda x: x[0])
possible_dirs = [d for _, d, _ in possible_dirs]

loaded = False
for triton_version in possible_dirs:
fallback_cache_file = os.path.join(
Path(__file__).parent,
"autotune_kernel_configs",
triton_version,
get_current_device_name(),
self.kernel_name,
KernelConfigs.get_config_file_name(static_key),
)
if os.path.exists(fallback_cache_file):
try:
logger.warning(
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
f"from triton version {triton_version} (current: {current_triton_version})"
)
with open(fallback_cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
loaded = True
break
except Exception as e:
logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")

if not loaded:
logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")
Comment on lines +222 to +291
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这部分的回退逻辑写得很好,但有几个地方可以优化以提高代码的可读性和可维护性:

  1. autotune_kernel_configs 目录的路径在代码中被多次拼接,可以将其提取到一个变量中以避免重复。
  2. 在寻找最接近版本时,possible_dirs 列表中存储了 (dist, d, parsed),但 parsed 变量在之后没有被使用,这造成了轻微的冗余。

我已经将这些优化合并到了下面的代码建议中,它重构了这部分逻辑,使其更简洁。

        elif get_env_start_args().enable_kernel_config_fallback:
            autotune_configs_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")

            def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
                """
                Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
                Returns None if invalid.
                """
                match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag)
                if not match:
                    return None
                x, y, z = match.groups()
                return (int(x), int(y), int(z) if z is not None else 0)

            def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
                """
                Compute weighted distance: major * 1e6 + minor * 1e3 + patch
                Ensures lexicographic ordering.
                """
                return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2]))

            current_triton_version = get_triton_version()
            current_parsed = parse_triton_version_tag(current_triton_version)
            if current_parsed is None:
                logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
                possible_dirs = [
                    d
                    for d in os.listdir(autotune_configs_dir)
                    if d.startswith("triton_")
                ]
                possible_dirs.sort()
            else:
                dirs_with_dist = []
                for d in os.listdir(autotune_configs_dir):
                    if not d.startswith("triton_"):
                        continue
                    parsed = parse_triton_version_tag(d)
                    if parsed is not None:
                        dist = version_distance(parsed, current_parsed)
                        dirs_with_dist.append((dist, d))
                    else:
                        logger.debug(f"Skipping invalid version directory: {d}")
                dirs_with_dist.sort(key=lambda x: x[0])
                possible_dirs = [d for _, d in dirs_with_dist]

            loaded = False
            for triton_version in possible_dirs:
                fallback_cache_file = os.path.join(
                    autotune_configs_dir,
                    triton_version,
                    get_current_device_name(),
                    self.kernel_name,
                    KernelConfigs.get_config_file_name(static_key),
                )
                if os.path.exists(fallback_cache_file):
                    try:
                        logger.warning(
                            f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
                            f"from triton version {triton_version} (current: {current_triton_version})"
                        )
                        with open(fallback_cache_file, "rb") as f:
                            self.cached_configs[static_key] = orjson.loads(f.read())
                        loaded = True
                        break
                    except Exception as e:
                        logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")

            if not loaded:
                logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")

return True

def kernel_warmup(self, static_key, *args, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
action="store_true",
help="""inference backend will use the fa3 attention kernel for prefill and decode""",
)
parser.add_argument(
"--enable_kernel_config_fallback",
action="store_true",
help="""Whether to enable kernel config fallback when triton version is not compatible.""",
)
parser.add_argument(
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
)
Expand Down
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ class StartArgs:

# kernel setting
enable_fa3: bool = field(default=False)
enable_kernel_config_fallback: bool = field(default=False)