
Commit b69692a

[Kernel] LoRA - Refactor sgmv kernels (#13110)
1 parent a64a844 commit b69692a

File tree

3 files changed: +327 −129 lines changed
vllm/lora/ops/triton_ops/kernel_utils.py (new file)

+243
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
"""
Utilities for Punica kernel construction.
"""
import triton
import triton.language as tl


@triton.jit
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
         b_dtype: tl.constexpr):
    """
    Given a_ptr and b_ptr that identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the
    partial/complete matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.
    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: If True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: Datatype of the B matrix
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(b_dtype)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
    return accumulator

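To make the split-K contract concrete, here is a minimal NumPy sketch (illustrative only, not part of the commit): worker pid_sk consumes every SPLIT_K-th K-tile starting at its own offset, and the per-worker partials sum to the full product, a reduction the Triton callers realize with tl.atomic_add.

import numpy as np

def mm_k_reference(A, B, BLOCK_K=4, SPLIT_K=2):
    # NumPy analogue of mm_k's tiling: each split-K worker accumulates
    # every SPLIT_K-th K-tile; summing the partials reproduces A @ B.
    M, K = A.shape
    _, N = B.shape
    partials = np.zeros((SPLIT_K, M, N), dtype=np.float32)
    for pid_sk in range(SPLIT_K):  # one partial result per worker
        for k0 in range(pid_sk * BLOCK_K, K, BLOCK_K * SPLIT_K):
            tile = slice(k0, min(k0 + BLOCK_K, K))  # slicing stands in
            partials[pid_sk] += A[:, tile] @ B[tile, :]  # for the EVEN_K mask
    return partials.sum(axis=0)  # the kernel does this via tl.atomic_add

A = np.random.rand(8, 10).astype(np.float32)
B = np.random.rand(10, 6).astype(np.float32)
assert np.allclose(mm_k_reference(A, B), A @ B, atol=1e-4)
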
@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
):
    """
    Given ram, an array of integers that identifies the rows of A;
    lora_index, which identifies the LoRA to use from lora_ptr; and
    slice_id, which identifies the input/output slice, compute the matrix
    product and store it in the appropriate output location.
    Since this is an expand kernel, no split-K reduction is performed,
    as the K dimension is assumed to be small.
    """

    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        # integer
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
             offset_k[None, :] * input_d2_stride)
    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
             offset_k[:, None] * cur_lora_d2_stride +
             rbn[None, :] * cur_lora_d1_stride)

    # Compute the block matrix product.
    SPLIT_K = 1
    accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K,
                       CAST_TYPE, cur_lora_ptr.dtype.element_ty)

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
             offset_cn[None, :] * output_d1_stride)
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
                                             < (cur_slice_start + N))

    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)

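Functionally, for one slice the expand kernel computes out[row, slice_start:slice_start+N] (+)= x[row] @ B[lora].T for every row that has a LoRA assigned. A hedged NumPy sketch of that contract (helper name and shapes are assumptions for illustration, not from the commit):

import numpy as np

def expand_reference(x, lora_b, token_lora_ids, out, slice_start, add_inputs):
    # x:              (num_tokens, K) low-rank intermediates (A side)
    # lora_b:         (num_loras, N, K) stacked LoRA B weights
    # token_lora_ids: (num_tokens,) LoRA id per token; -1 means "no LoRA"
    # out:            (num_tokens, hidden), written in place
    N = lora_b.shape[1]
    for m, idx in enumerate(token_lora_ids):
        if idx == -1:  # mirrors the kernel's early return for lora_index == -1
            continue
        y = x[m] @ lora_b[idx].T  # mm_k over the small K (rank) dimension
        cols = slice(slice_start, slice_start + N)
        out[m, cols] = out[m, cols] + y if add_inputs else y  # ADD_INPUTS
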
@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
):
    """
    Given ram, an array of integers that identifies the rows of A;
    lora_index, which identifies the LoRA to use from lora_ptr; and
    slice_id, which identifies the input/output slice, compute the matrix
    product and store it in the appropriate output location.
    """

    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_lora_ptr = lora_ptr
    else:
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
             offset_k[None, :] * input_d1_stride)
    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
             rbn[None, :] * lora_d1_stride +
             offset_k[:, None] * lora_d2_stride)

    # Compute partial/complete block matrix product.
    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
                       K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
                       cur_lora_ptr.dtype.element_ty)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
                   slice_id * output_d0_stride)
    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
        None, :] * output_d2_stride
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)

    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
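
do_shrink_kernel is the down-projection counterpart: it multiplies full-width activations by the selected LoRA A weights and applies the scaling factor. A rough NumPy equivalent (illustrative only; the real kernel splits K across SPLIT_K thread blocks and combines the partials with tl.atomic_add):

import numpy as np

def shrink_reference(x, lora_a, token_lora_ids, scaling):
    # x:       (num_tokens, hidden) input activations
    # lora_a:  (num_loras, rank, hidden) stacked LoRA A weights
    # returns: (num_tokens, rank) scaled low-rank projections
    num_tokens = x.shape[0]
    rank = lora_a.shape[1]
    out = np.zeros((num_tokens, rank), dtype=np.float32)
    for m, idx in enumerate(token_lora_ids):
        if idx == -1:  # token without a LoRA contributes nothing
            continue
        out[m] = scaling * (x[m] @ lora_a[idx].T)
    return out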

vllm/lora/ops/triton_ops/sgmv_expand.py

+44 −73
@@ -14,6 +14,7 @@

 from vllm.utils import direct_register_custom_op

+from .kernel_utils import do_expand_kernel
 from .utils import _get_lora_b_ptr

@@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
     curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
     pid_m = pid // cta_n_num
     pid_n = pid % cta_n_num
+
     M = tl.load(seq_lens + cur_batch)
-    if pid_m * BLOCK_M > M:
+    if pid_m * BLOCK_M >= M:
         return
-    if pid_n * BLOCK_N > curr_N:
+    if pid_n * BLOCK_N >= curr_N:
         return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return

-    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
-    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    offset_k = tl.arange(0, BLOCK_K)
-    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
-                            BLOCK_N)
-    # ls_d*_ptr can be either an integer or a pointer
-    if SAME_STRIDE:
-        # integer
-        cur_lora_d0_stride = ls_d0_ptr
-        cur_lora_d1_stride = ls_d1_ptr
-        cur_lora_d2_stride = ls_d2_ptr
-    else:
-        # pointer
-        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
-        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
-        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
-    if SLICE_NUM == 1:
-        cur_input_ptr = input_ptr
-        cur_lora_ptr = lora_ptr
-
-    else:
-        cur_input_ptr = input_ptr + slice_id * input_d0_stride
-        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
-            tl.pointer_type(out_ptr.dtype.element_ty))
-
-    a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
-             ram[:, None] * input_d1_stride +
-             offset_k[None, :] * input_d2_stride, )
-    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
-             offset_k[:, None] * cur_lora_d2_stride +
-             rbn[None, :] * cur_lora_d1_stride)
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K)):
-        if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
-        else:
-            tiled_a = tl.load(a_ptr,
-                              mask=offset_k[None, :] < K - k * BLOCK_K,
-                              other=0)
-            tiled_b = tl.load(b_ptr,
-                              mask=offset_k[:, None] < K - k * BLOCK_K,
-                              other=0)
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * input_d2_stride
-        b_ptr += BLOCK_K * cur_lora_d2_stride
-
-    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
-    if SLICE_NUM == 1:
-        cur_slice_start = slice_start_loc
-    else:
-        cur_slice_start = tl.load(slice_start_loc + slice_id)
-
-    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
-    c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
-             offset_cn[None, :] * output_d1_stride)
-    M = tl.load(seq_lens + cur_batch)
-    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
-        offset_cn[None, :] < (cur_slice_start + curr_N))
-    if ADD_INPUTS:
-        tiled_out = tl.load(c_ptr, mask=c_mask)
-        tiled_c += tiled_out
-    tl.store(c_ptr, tiled_c, mask=c_mask)
+    m_offset = tl.load(b_seq_start_loc + cur_batch)
+
+    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
+    cta_m_offset = m_offset + (pid_m * BLOCK_M)
+    offset_m = tl.arange(0, BLOCK_M)
+    ram = cta_m_offset + tl.max_contiguous(
+        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
+    do_expand_kernel(
+        pid_n,
+        lora_index,
+        slice_id,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        curr_N,
+        K,
+        cta_m_len,
+        ram,  # array identifying the rows of Input ptr to operate on
+        slice_start_loc,
+        # input ptr strides
+        input_d0_stride,
+        input_d1_stride,
+        input_d2_stride,
+        # lora ptr strides
+        ls_d0_ptr,
+        ls_d1_ptr,
+        ls_d2_ptr,
+        # out ptr strides
+        output_d0_stride,
+        output_d1_stride,
+        # constants
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        SAME_STRIDE,
+        SLICE_NUM,
+        EVEN_K,
+        CAST_TYPE,
+        ADD_INPUTS,
+    )


 @torch.inference_mode()
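
The core of the refactor is that _sgmv_expand_kernel now resolves its rows up front and delegates the GEMM to the shared do_expand_kernel: each sequence occupies a contiguous run of rows starting at b_seq_start_loc[cur_batch], and each CTA owns at most BLOCK_M of them. A plain-Python sketch of the new row mapping (hypothetical helper, not in the diff):

def cta_rows(b_seq_start_loc, seq_lens, cur_batch, pid_m, BLOCK_M):
    # Which rows of the flat token buffer does this CTA operate on?
    M = seq_lens[cur_batch]
    if pid_m * BLOCK_M >= M:  # matches the kernel's early return
        return None
    m_offset = b_seq_start_loc[cur_batch]          # sequence start row
    cta_m_len = min(BLOCK_M, M - pid_m * BLOCK_M)  # valid rows in this CTA
    cta_m_offset = m_offset + pid_m * BLOCK_M
    # offset_m % cta_m_len folds out-of-range lanes back onto valid rows,
    # like the tl.max_contiguous(tl.multiple_of(...)) expression for ram;
    # cta_m_len is then passed as M_LEN so the store mask drops duplicates.
    return [cta_m_offset + (i % cta_m_len) for i in range(BLOCK_M)]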
