Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from lightllm.utils.infer_utils import post_empty_cache
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend
from lightllm.utils.profiler import GlobalPerfContext

logger = init_logger(__name__)

Expand Down Expand Up @@ -255,6 +256,7 @@ def _init_att_backend1(self):
self.decode_att_backend1: BaseAttBackend = None
return

@GlobalPerfContext.disable()
def _init_cudagraph(self):
self.graph = (
None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
Expand Down Expand Up @@ -552,6 +554,7 @@ def _decode(

@final
def _context_forward(self, infer_state: InferStateInfo):
GlobalPerfContext.begin_with_sample_rate(sample_rate=1)
run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
input_ids = infer_state.input_ids
cuda_input_ids = input_ids
Expand Down Expand Up @@ -602,10 +605,13 @@ def prefill_func(input_tensors, infer_state):
# When deepep is enabled, clear_deepep_buffer must be called to release resources; when it is not
# enabled, the call has no practical effect.
dist_group_manager.clear_deepep_buffer()
rank = torch.cuda.current_device()
GlobalPerfContext.finalize_async(marker=f"_context_forward_{rank}", save_jsonl=True, save_log=True)
return model_output

@final
def _token_forward(self, infer_state: InferStateInfo):
GlobalPerfContext.begin_with_sample_rate(sample_rate=0.05)
run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
input_ids = infer_state.input_ids
cuda_input_ids = input_ids
Expand All @@ -632,6 +638,13 @@ def _token_forward(self, infer_state: InferStateInfo):
if infer_state.is_cuda_graph:
model_output.to_no_ref_tensor()

rank = torch.cuda.current_device()
GlobalPerfContext.finalize_async(marker=f"_token_forward_{rank}", save_jsonl=True, save_log=True)

# import time
# ts = time.time()
# PerfCounterContext.finalize_print(marker=f"_token_forward_{rank}", save=f"_token_forward_{rank}.jsonl")
# print(f"PerfCounterContext.finalize_print took {(time.time() - ts) * 1000:.3f} ms")
return model_output

@torch.no_grad()
Expand Down
4 changes: 3 additions & 1 deletion lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.utils.profiler import GlobalPerfContext
from .infer_struct import InferStateInfo


Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

# Reports whether batch_size and max_len_in_batch fit within self.max_batch_size and
# self.graph_max_len_in_batch.
# NOTE(review): this span is a flattened diff hunk ("3 additions & 1 deletion"), so both sides of the
# change appear back to back: the first `return` is the removed line, the two lines after it replace it.
def can_run(self, batch_size, max_len_in_batch):
return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
# Replacement lines: the result is additionally forced to False whenever
# GlobalPerfContext.cudagraph_helper(sample_rate=0.05) returns a truthy value.
# Presumably this makes a sampled step bypass the captured graph so it can be profiled — confirm
# against lightllm.utils.profiler. The helper is called on every can_run() invocation, before the
# size checks are evaluated.
do_profile = GlobalPerfContext.cudagraph_helper(sample_rate=0.05)
return not do_profile and batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

def need_capture(self, batch_size):
find_batch_size = self.find_closest_graph_batch_size(batch_size)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import torch
import torch.distributed as dist

from lightllm.utils.profiler import PerfCounter
from ..transformer_layer_infer import TransformerLayerInfer
from ...infer_struct import InferStateInfo
from lightllm.distributed import all_reduce
Expand Down Expand Up @@ -62,6 +64,7 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
# Stub: unconditionally raises Exception("need to impl").
# Presumably meant to be overridden by concrete layer implementations — confirm against subclasses.
def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")

@PerfCounter(type="LAYER")
def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
Expand All @@ -74,6 +77,7 @@ def context_attention_forward(self, input_embdings, infer_state: InferStateInfo,
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

@PerfCounter(type="LAYER")
def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.context_attention_forward(input1, infer_state, layer_weight)
Expand All @@ -88,6 +92,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

@PerfCounter(type="LAYER")
def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
Expand All @@ -98,6 +103,7 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

@PerfCounter(type="LAYER")
def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.token_attention_forward(input1, infer_state, layer_weight)
Expand All @@ -111,6 +117,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

@PerfCounter(type="LAYER")
def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
Expand All @@ -121,6 +128,7 @@ def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_sta
o = self._tpsp_get_o(o, infer_state, layer_weight)
return o

@PerfCounter(type="LAYER")
def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
Expand All @@ -133,6 +141,7 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

@PerfCounter(type="LAYER")
def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
Expand All @@ -141,6 +150,7 @@ def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state
o = self._tpsp_get_o(o, infer_state, layer_weight)
return o

@PerfCounter(type="LAYER")
def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lightllm.common.quantization.no_quant import NoQuantization
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.log_utils import init_logger
from lightllm.utils.profiler import PerfCounter
from .mm_slicer import SliceMixinTpl

logger = init_logger(__name__)
Expand Down Expand Up @@ -53,9 +54,11 @@ def __init__(
self._create_weight()
self.gen_weight_quant_param_names()

# Passes input_tensor and self.mm_param to self.quant_method.apply, forwarding `out`,
# `use_custom_tensor_mananger` and self.bias, and returns that call's result.
# The method is wrapped by PerfCounter(type="GEMM_OP"). record_shape is reached through the wrapped
# attribute self.mm — presumably supplied by that decorator; confirm in lightllm.utils.profiler.
@PerfCounter(type="GEMM_OP")
def mm(
self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True
) -> torch.Tensor:
# Records m / k / n as input_tensor.shape[0], input_tensor.shape[1] and self.mm_param.weight.shape[1].
# Assumes a 2-D input_tensor and a (k, n) weight layout — TODO confirm for every quant method.
self.mm.record_shape(m=input_tensor.shape[0], k=input_tensor.shape[1], n=self.mm_param.weight.shape[1])
return self.quant_method.apply(
input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias
)
Expand Down Expand Up @@ -215,6 +218,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
# Returns the load_ok attribute of self.weight.
def verify_load(self):
return self.weight.load_ok

@PerfCounter(type="GEMM_OP")
def bmm(
self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True
) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import triton.language as tl
from typing import Dict

from lightllm.utils.profiler import PerfCounter


@triton.jit
def _fwd_kernel_ep_scatter_1(
Expand Down Expand Up @@ -87,6 +89,7 @@ def _fwd_kernel_ep_scatter_2(
tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)


@PerfCounter(type="COMM_OP")
@torch.no_grad()
def ep_scatter(
recv_x: torch.Tensor,
Expand All @@ -104,6 +107,7 @@ def ep_scatter(
num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0] # 获取num_recv_tokens_per_expert的元素个数
hidden_size = recv_x.shape[1]
ep_scatter.record_shape(size=recv_x.element_size() * recv_x.numel(), hidden_size=hidden_size, num_experts=num_experts, recv_x_shape=recv_x.shape, output_tensor_shape=output_tensor.shape)
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts

Expand Down Expand Up @@ -194,6 +198,7 @@ def _fwd_kernel_ep_gather(
)


@PerfCounter(type="COMM_OP")
@torch.no_grad()
def ep_gather(
input_tensor: torch.Tensor,
Expand All @@ -206,6 +211,7 @@ def ep_gather(
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
ep_gather.record_shape(size=input_tensor.element_size() * input_tensor.numel(), hidden_size=hidden_size, num_tokens=num_tokens, input_tensor_shape=input_tensor.shape, output_tensor_shape=output_tensor.shape)
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import triton.language as tl
from typing import Any, Callable, Dict, Optional, Tuple
from lightllm.utils.log_utils import init_logger
from lightllm.utils.profiler import PerfCounter
from lightllm.utils.vllm_utils import vllm_ops
from lightllm.utils.device_utils import triton_support_tensor_descriptor
from .moe_silu_and_mul import silu_and_mul_fwd
Expand Down Expand Up @@ -267,6 +268,7 @@ def _get_moe_align_fused_configs():
]


@PerfCounter("moe_align_fused", type="OTHER_OP")
@autotune(
kernel_name="moe_align_fused:v1",
configs_gen_func=_get_moe_align_fused_configs,
Expand Down Expand Up @@ -670,6 +672,7 @@ def _get_grouped_matmul_configs():
]


@PerfCounter("grouped_matmul", type="GEMM_OP")
@autotune(
kernel_name="grouped_matmul:v1",
configs_gen_func=_get_grouped_matmul_configs,
Expand Down Expand Up @@ -716,6 +719,9 @@ def grouped_matmul(
assert expert_to_weights.is_contiguous()
assert expert_weights.is_contiguous()

m_total = int(expert_to_token_num.sum().item())
grouped_matmul.record_shape(m=m_total, k=k, n=n)
Comment on lines +722 to +723
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The `expert_to_token_num.sum().item()` call is synchronous: `.item()` forces a GPU-to-CPU synchronization. This will significantly degrade performance, especially since it runs unconditionally, even when profiling is disabled (the `record_shape` call becomes a no-op, but the `sum()` and the `.item()` transfer still occur). It should be guarded by a check of `is_perf_counter_active()`.

Suggested change
m_total = int(expert_to_token_num.sum().item())
grouped_matmul.record_shape(m=m_total, k=k, n=n)
if grouped_matmul.is_perf_counter_active():
m_total = int(expert_to_token_num.sum().item())
grouped_matmul.record_shape(m=m_total, k=k, n=n)


# for deepseek_v3 block-wise quant
block_size_n = 0
block_size_k = 0
Expand Down Expand Up @@ -1157,6 +1163,7 @@ def outplace_fused_experts_impl_fake(
)


@PerfCounter(type="BLOCK")
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
Expand Down
Loading
Loading