Commit e13e2a1

Autotuner (#1020)
1 parent a771412 commit e13e2a1

39 files changed, +2044 -17 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 77 additions & 2 deletions
@@ -1,6 +1,7 @@
 import os
 
 # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+import gc
 import copy
 import json
 import torch
@@ -24,8 +25,8 @@
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status
-
+from lightllm.utils.envs_utils import set_model_init_status, is_triton_autotune_enabled, disable_triton_autotune
+from lightllm.utils.infer_utils import post_empty_cache
 
 logger = init_logger(__name__)
 
@@ -100,6 +101,7 @@ def __init__(self, kvargs):
         self._init_some_value()
         self._init_custom()
         self._init_inferstate_cls()
+        self._autotune_warmup()
         self._init_padded_req()
         self._init_cudagraph()
         self._check_max_len_infer()
@@ -721,6 +723,79 @@ def _check_max_len_infer(self):
             raise Exception(exception_str)
         return
 
+    def autotune_layers(self):
+        # Controls how many layers take part in autotuning, so the warmup can adapt to different models.
+        return self.config.get("first_k_dense_replace", 0) + 1
+
+    @final
+    @torch.no_grad()
+    @post_empty_cache
+    def _autotune_warmup(self):
+        if not is_triton_autotune_enabled():
+            return
+
+        torch.distributed.barrier()
+
+        warmup_lengths = [1, 8, 16, 64, 128, 256, 1024, 2048, 4096]
+
+        if self.batch_max_tokens not in warmup_lengths:
+            warmup_lengths.append(self.batch_max_tokens)
+
+        warmup_lengths = [e for e in warmup_lengths if e <= self.batch_max_tokens]
+
+        warmup_lengths.sort(reverse=True)
+
+        layer_num_bak = self.layers_num
+        self.layers_num = self.autotune_layers()
+        for input_len in warmup_lengths:
+            try:
+                logger.info(f"autotune warmup for length {input_len}")
+                rand_gen = torch.Generator(device="cuda")
+                rand_gen.manual_seed(input_len)
+                dummy_input_ids = torch.randint(
+                    0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
+                )
+                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
+                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
+                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+                b_seq_len[:] = input_len
+                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+                total_token_num = input_len
+                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+                model_input = ModelInput(
+                    batch_size=1,
+                    total_token_num=total_token_num,
+                    max_len_in_batch=input_len,
+                    input_ids=dummy_input_ids,
+                    mem_indexes=mem_indexes,
+                    b_req_idx=b_req_idx,
+                    b_seq_len=b_seq_len,
+                    b_mtp_index=b_mtp_index,
+                    is_prefill=True,
+                    b_ready_cache_len=b_ready_cache_len,
+                    multimodal_params=[],
+                    **self._gen_special_model_input(total_token_num),
+                )
+                model_output = self.forward(
+                    model_input,
+                )
+                del model_output
+                self.req_manager.free_all()
+                self.mem_manager.free_all()
+                gc.collect()
+                torch.cuda.empty_cache()
+                logger.info(f"autotune warmup for length {input_len} ok")
+            except Exception as e:
+                logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
+                logger.exception(str(e))
+                self.req_manager.free_all()
+                self.mem_manager.free_all()
+                gc.collect()
+                torch.cuda.empty_cache()
+        self.layers_num = layer_num_bak
+        torch.distributed.barrier()
+        disable_triton_autotune()
+
     @final
     @torch.no_grad()
     def _init_padded_req(self):
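Note: the helpers is_triton_autotune_enabled() and disable_triton_autotune() live in lightllm/utils/envs_utils.py and their bodies are not part of this diff. A minimal, hypothetical sketch of how such a process-wide toggle could work, assuming an environment-variable switch (the variable name LIGHTLLM_TRITON_AUTOTUNE below is an assumption, not the real flag):

import os

# Hypothetical sketch only; the real helpers in lightllm.utils.envs_utils may use a
# different variable name or an in-process flag.
def is_triton_autotune_enabled() -> bool:
    return os.getenv("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1"

def disable_triton_autotune() -> None:
    # Called once the warmup finishes so later forward passes skip tuning.
    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"

Whatever the exact mechanism, the call at the end of _autotune_warmup has the same effect: tuning happens once, up front, and normal inference then reuses whatever configurations were selected.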

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 10 additions & 0 deletions
@@ -17,6 +17,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
+from lightllm.utils.envs_utils import is_triton_autotune_enabled
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -353,6 +354,15 @@ def prefilled_group_gemm(
             )
             # gather and local reduce
             ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+        else:
+            ######################################## warning ##################################################
+            # This branch exists for the autotune feature: every rank must launch the same triton kernels.
+            # In some special cases a rank receives 0 tokens, so feed one dummy token through the kernel.
+            if is_triton_autotune_enabled():
+                _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
+                _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
+                silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
+                _gemm_out_a, _silu_out = None, None
 
         return gather_out
 

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 46 additions & 0 deletions
@@ -35,6 +35,7 @@
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.utils.torch_ops_utils import direct_register_custom_op
+from lightllm.common.triton_utils.autotuner import autotune
 
 FFN_MOE_CHUNK_SIZE = 32 * 1024
 
@@ -449,6 +450,51 @@ def grouped_matmul_kernel(
     return
 
 
+def _get_grouped_matmul_static_key(
+    expert_weights: torch.Tensor,
+    topk_num: int,
+    out: torch.Tensor,
+    mul_routed_weight: bool,
+    use_fp8_w8a8: bool,
+) -> dict:
+    expert_num, n, k = expert_weights.shape
+    return {
+        "N": n,
+        "K": k,
+        "topk_num": topk_num,
+        "expert_num": expert_num,
+        "mul_routed_weight": mul_routed_weight,
+        "use_fp8_w8a8": use_fp8_w8a8,
+        "out_dtype": str(out.dtype),
+    }
+
+
+def _get_grouped_matmul_configs():
+    return [
+        {
+            "BLOCK_SIZE_M": bm,
+            "BLOCK_SIZE_N": bn,
+            "BLOCK_SIZE_K": bk,
+            "GROUP_SIZE_M": gm,
+            "num_warps": nw,
+            "num_stages": ns,
+        }
+        for ns in [1, 2, 3, 4, 5]
+        for gm in [1, 2, 4, 8]
+        for nw in [2, 4, 8]
+        for bm in [16, 32, 64, 128]
+        for bn in [16, 32, 64, 128]
+        for bk in [16, 32, 64, 128]
+    ]
+
+
+@autotune(
+    kernel_name="grouped_matmul:v1",
+    configs_gen_func=_get_grouped_matmul_configs,
+    static_key_func=_get_grouped_matmul_static_key,
+    run_key_func=lambda token_inputs: token_inputs.shape[0],
+    mutates_args=["out"],
+)
 def grouped_matmul(
     token_num_mul_topk_num: int,
     token_inputs: torch.Tensor,
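The autotuner itself lives in lightllm/common/triton_utils/autotuner.py, which this excerpt does not show (only the package __init__.py appears below). Judging purely from how @autotune is invoked here, a decorator of this shape would cache a best run_config per (static key, run key) pair and benchmark the candidate configs the first time a new key is seen. The sketch below is a hypothetical reading under those assumptions, not the commit's actual implementation:

import functools
import inspect
import time

import torch


def autotune(kernel_name, configs_gen_func, static_key_func, run_key_func, mutates_args=()):
    # Hypothetical sketch: the real decorator may persist tuned configs to disk, share them
    # across ranks, and use mutates_args to restore output buffers between benchmark runs.
    def decorator(fn):
        fn_sig = inspect.signature(fn)
        best_configs = {}  # (static_key, run_key) -> best run_config

        def _select_args(key_fn, bound_args):
            # Key functions name the subset of kernel arguments they need.
            wanted = inspect.signature(key_fn).parameters
            return key_fn(**{name: bound_args[name] for name in wanted})

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = fn_sig.bind(*args, **kwargs)
            bound.apply_defaults()
            static_key = tuple(sorted(_select_args(static_key_func, bound.arguments).items()))
            run_key = _select_args(run_key_func, bound.arguments)
            key = (static_key, run_key)

            if key not in best_configs:
                best_cfg, best_ms = None, float("inf")
                for cfg in configs_gen_func():
                    try:
                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        fn(*args, **{**kwargs, "run_config": dict(cfg)})
                        torch.cuda.synchronize()
                        elapsed_ms = (time.perf_counter() - start) * 1e3
                    except Exception:
                        continue  # configs that fail to compile or launch are skipped
                    if elapsed_ms < best_ms:
                        best_cfg, best_ms = cfg, elapsed_ms
                best_configs[key] = best_cfg  # may stay None, letting the kernel use its default lookup

            return fn(*args, **{**kwargs, "run_config": best_configs[key]})

        return wrapper

    return decorator

Under this reading, the _autotune_warmup pass in basemodel.py exists to drive these decorated kernels with realistic prefill shapes while tuning is enabled, and disable_triton_autotune() afterwards freezes the cached choices for serving.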

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
+from lightllm.utils.envs_utils import is_triton_autotune_enabled
 import numpy as np
 
 logger = init_logger(__name__)
@@ -186,6 +187,16 @@ def fused_experts_impl(
 
             # gather and local reduce
             ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+        else:
+            ######################################## warning ##################################################
+            # This branch exists for the autotune feature: every rank must launch the same triton kernels.
+            # In some special cases a rank receives 0 tokens, so feed one dummy token through the kernel.
+            if is_triton_autotune_enabled():
+                _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
+                _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+                silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
+                _gemm_out_a, _silu_out = None, None
+
         # normal combine
         combined_x, _, event = buffer.combine(
             gather_out,

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 23 additions & 1 deletion
@@ -3,6 +3,7 @@
 import triton
 import triton.language as tl
 from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
+from lightllm.common.triton_utils.autotuner import autotune
 
 
 @triton.jit
@@ -62,7 +63,28 @@ def _silu_and_mul_kernel_fast(
     )
 
 
-def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
+def _get_silu_and_mul_configs():
+    return [
+        {"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
+        for ns in [1, 2, 4]
+        for nw in [1, 4, 8]
+        for bm in [32, 64, 128, 256]
+        for bn in [32, 64, 128, 256]
+    ]
+
+
+def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor):
+    return {"N": input.shape[-1] // 2, "out_dtype": str(output.dtype)}
+
+
+@autotune(
+    kernel_name="silu_and_mul_fwd:v1",
+    configs_gen_func=_get_silu_and_mul_configs,
+    static_key_func=_get_silu_and_mul_static_key,
+    run_key_func=lambda input: input.shape[0],
+    mutates_args=["output"],
+)
+def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, run_config=None):
     assert input.is_contiguous()
     assert output.is_contiguous()
 
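After this change callers no longer pass tuning parameters via **run_config; the decorator is expected to supply run_config itself. A usage sketch, with shapes and dtype chosen purely for illustration:

import torch
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd

# input packs [gate, up] along the last dim: (token_num, 2 * N); output is (token_num, N).
x = torch.randn(64, 2 * 4096, device="cuda", dtype=torch.bfloat16).contiguous()
out = torch.empty(64, 4096, device="cuda", dtype=torch.bfloat16)
silu_and_mul_fwd(x, out)  # with autotune enabled, the first call per shape benchmarks configs; later calls reuse the cache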

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 24 additions & 1 deletion
@@ -3,6 +3,8 @@
 import triton
 import triton.language as tl
 from .moe_sum_recude_config import MoeSumReduceKernelConfig
+from typing import Any, Callable, Dict, Optional, Tuple
+from lightllm.common.triton_utils.autotuner import autotune
 
 
 @triton.jit
@@ -46,7 +48,28 @@ def _moe_sum_reduce_kernel(
     tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end)
 
 
-def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, **run_config):
+def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor):
+    return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)}
+
+
+def _get_moe_sum_reduce_configs():
+    return [
+        {"BLOCK_M": bm, "BLOCK_DIM": bd, "NUM_STAGE": ns, "num_warps": nw}
+        for ns in [1, 2, 4]
+        for nw in [1, 2, 4, 8, 16]
+        for bm in [1, 2, 4, 8, 16, 32]
+        for bd in [64, 128, 256, 512, 1024]
+    ]
+
+
+@autotune(
+    kernel_name="moe_sum_reduce:v1",
+    configs_gen_func=_get_moe_sum_reduce_configs,
+    static_key_func=_get_moe_sum_reduce_static_key,
+    run_key_func=lambda input: input.shape[0],
+    mutates_args=["output"],
+)
+def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None):
    assert input.is_contiguous()
    assert output.is_contiguous()
 
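For a sense of how large the generated search spaces above are, plain arithmetic over the list-comprehension loops (nothing here comes from the commit itself beyond the lists):

# Candidate counts implied by the config generators in this commit:
grouped_matmul_cfgs = 5 * 4 * 3 * 4 * 4 * 4  # num_stages x GROUP_SIZE_M x num_warps x BLOCK_M x BLOCK_N x BLOCK_K = 3840
silu_and_mul_cfgs = 3 * 3 * 4 * 4            # = 144
moe_sum_reduce_cfgs = 3 * 5 * 6 * 5          # = 450
print(grouped_matmul_cfgs, silu_and_mul_cfgs, moe_sum_reduce_cfgs)

Spaces this large are presumably why the warmup in basemodel.py restricts itself to autotune_layers() layers and a handful of input lengths rather than tuning during the full model's normal forward passes.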

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 45 additions & 2 deletions
@@ -7,6 +7,7 @@
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 from triton import Config
+from lightllm.common.triton_utils.autotuner import autotune
 
 
 class Fp8BlockMMKernelConfig(KernelConfigs):
@@ -142,6 +143,46 @@ def _block_scaled_block_gemm(
     tl.store(c_ptrs, acc, mask=mask)
 
 
+def get_test_configs():
+    fp8_gemm_configs = [
+        {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 3, "num_warps": 8},
+        {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2},
+        {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2},
+        {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 3, "num_warps": 8},
+        {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 3, "num_warps": 8},
+        {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+        {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4},
+    ]
+    return fp8_gemm_configs
+
+
+def _get_static_key(A, B, block_size, dtype):
+    M, K = A.shape
+    _, N = B.shape
+    return {
+        "N": N,
+        "K": K,
+        "block_size": block_size,
+        "out_dtype": str(dtype),
+    }
+
+
+@autotune(
+    kernel_name="w8a8_block_fp8_matmul:v1",
+    configs_gen_func=get_test_configs,
+    static_key_func=_get_static_key,
+    run_key_func=lambda A: A.shape[0],
+    mutates_args=["C"],
+)
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -150,7 +191,7 @@ def w8a8_block_fp8_matmul(
     C: torch.Tensor,
     block_size: List[int],
     dtype: torch.dtype = torch.bfloat16,
-    **run_config,
+    run_config=None,
 ) -> torch.Tensor:
     """w8a8fp8 block-wise quantization mm.
 
@@ -174,7 +215,9 @@
     assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0]
     assert triton.cdiv(N, block_n) == Bscale.shape[1]
     if not run_config:
-        run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(M, N, K, block_size, dtype)
+        run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(
+            M=M, N=N, K=K, block_size=block_size, out_dtype=dtype
+        )
     grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),)
     _block_scaled_block_gemm[grid](
         A,
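The asserts in the last hunk pin down the scale layouts this kernel expects. A small sketch of the implied shapes; the concrete M, N, K and block sizes below are made up for illustration, and Ascale's leading dimension (per-token) is an assumption since the shown asserts only constrain its last dimension:

import triton

# Illustrative sizes only; block_n / block_k come from the block_size argument.
M, N, K = 256, 4096, 7168
block_n, block_k = 128, 128

a_scale_shape = (M, triton.cdiv(K, block_k))                        # one scale per token per K-block (leading dim assumed)
b_scale_shape = (triton.cdiv(K, block_k), triton.cdiv(N, block_n))  # one scale per (K-block, N-block) weight tile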

lightllm/common/triton_utils/__init__.py

Whitespace-only changes.
