|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +""" |
| 3 | +Utilities for Punica kernel construction. |
| 4 | +""" |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | + |
| 9 | +@triton.jit |
| 10 | +def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, |
| 11 | + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, |
| 12 | + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, |
| 13 | + b_dtype: tl.constexpr): |
| 14 | + """ |
| 15 | + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of |
| 16 | + B (k x n), iterate, through the K dimension to compute the partial/complete |
| 17 | + matrix block product. |
| 18 | + If SPLIT_K == 1, the output m x n product is complete. |
| 19 | + If SPLIT_K > 1, the thread block computes partial outputs. The partial |
| 20 | + outputs are then atomically summed in the caller code. |
| 21 | + Args: |
| 22 | + a_ptr: Array of pointers, identifying rows of A |
| 23 | + b_ptr: Array of pointers, identifying columns of B |
| 24 | + ak_stride: K dimension stride of the A matrix |
| 25 | + bk_stride: K dimension stride of the B matrix |
| 26 | + K: Length of the K dimension |
| 27 | + BLOCK_M: M dimension of the output block m x n |
| 28 | + BLOCK_N: N dimension of the output block m x n |
| 29 | + BLOCK_K: K dimension atom |
| 30 | + EVEN_K: True if the blocks of A and B can be loaded without any |
| 31 | + masking. |
| 32 | + SPLIT_K: Parameter signifying parallelism in the K dimension. |
| 33 | + CAST_TYPE: if True, cast the values from the A matrix to the B |
| 34 | + matrix dtype. |
| 35 | + b_dtype: datatype of the B matrix |
| 36 | + """ |
| 37 | + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
| 38 | + for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)): |
| 39 | + if EVEN_K: |
| 40 | + tiled_a = tl.load(a_ptr) |
| 41 | + tiled_b = tl.load(b_ptr) |
| 42 | + else: |
| 43 | + tiled_a = tl.load(a_ptr, |
| 44 | + mask=offset_k[None, :] |
| 45 | + < K - k * (BLOCK_K * SPLIT_K), |
| 46 | + other=0) |
| 47 | + tiled_b = tl.load(b_ptr, |
| 48 | + mask=offset_k[:, None] |
| 49 | + < K - k * (BLOCK_K * SPLIT_K), |
| 50 | + other=0) |
| 51 | + if CAST_TYPE: |
| 52 | + tiled_a = tiled_a.to(b_dtype) |
| 53 | + accumulator += tl.dot( |
| 54 | + tiled_a, |
| 55 | + tiled_b, |
| 56 | + ) |
| 57 | + a_ptr += BLOCK_K * SPLIT_K * ak_stride |
| 58 | + b_ptr += BLOCK_K * SPLIT_K * bk_stride |
| 59 | + return accumulator |
| 60 | + |
| 61 | + |
| 62 | +@triton.jit |
| 63 | +def do_expand_kernel( |
| 64 | + pid_n, |
| 65 | + lora_index, |
| 66 | + slice_id, |
| 67 | + input_ptr, |
| 68 | + lora_ptr, |
| 69 | + out_ptr, |
| 70 | + N, |
| 71 | + K, |
| 72 | + M_LEN, |
| 73 | + ram, # array identifying the rows of Input ptr to operate on |
| 74 | + slice_start_loc, |
| 75 | + # input ptr strides |
| 76 | + input_d0_stride, |
| 77 | + input_d1_stride, |
| 78 | + input_d2_stride, |
| 79 | + # lora ptr strides |
| 80 | + ls_d0_ptr, |
| 81 | + ls_d1_ptr, |
| 82 | + ls_d2_ptr, |
| 83 | + # out ptr strides |
| 84 | + output_d0_stride, |
| 85 | + output_d1_stride, |
| 86 | + # constants |
| 87 | + BLOCK_M: tl.constexpr, |
| 88 | + BLOCK_N: tl.constexpr, |
| 89 | + BLOCK_K: tl.constexpr, |
| 90 | + SAME_STRIDE: tl.constexpr, |
| 91 | + SLICE_NUM: tl.constexpr, |
| 92 | + EVEN_K: tl.constexpr, |
| 93 | + CAST_TYPE: tl.constexpr, |
| 94 | + ADD_INPUTS: tl.constexpr, |
| 95 | +): |
| 96 | + """ |
| 97 | + Given an array of integers that identifies the rows of A, ram, |
| 98 | + a lora index that identifies which LoRA to use from lora_ptr, lora_index, |
| 99 | + a slice_id that identifies the input/output slice, |
| 100 | + compute the matrix product and store in the appropriate output location. |
| 101 | + Given that this is an expand kernel, we don't perform any split-K reduction |
| 102 | + as the K dimension is assumed to be small. |
| 103 | + """ |
| 104 | + |
| 105 | + # ls_d*_ptr can be either an integer or a pointer |
| 106 | + if SAME_STRIDE: |
| 107 | + # integer |
| 108 | + cur_lora_d0_stride = ls_d0_ptr |
| 109 | + cur_lora_d1_stride = ls_d1_ptr |
| 110 | + cur_lora_d2_stride = ls_d2_ptr |
| 111 | + else: |
| 112 | + # pointer |
| 113 | + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) |
| 114 | + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) |
| 115 | + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) |
| 116 | + |
| 117 | + # Identify the input_ptr and lora_ptr from slice_id. |
| 118 | + if SLICE_NUM == 1: |
| 119 | + cur_input_ptr = input_ptr |
| 120 | + cur_lora_ptr = lora_ptr |
| 121 | + else: |
| 122 | + cur_input_ptr = input_ptr + slice_id * input_d0_stride |
| 123 | + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( |
| 124 | + tl.pointer_type(out_ptr.dtype.element_ty)) |
| 125 | + |
| 126 | + # Identify the column indices of B to process. |
| 127 | + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N |
| 128 | + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) |
| 129 | + |
| 130 | + # Identify A and B block pointers |
| 131 | + offset_k = tl.arange(0, BLOCK_K) |
| 132 | + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + |
| 133 | + offset_k[None, :] * input_d2_stride, ) |
| 134 | + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + |
| 135 | + offset_k[:, None] * cur_lora_d2_stride + |
| 136 | + rbn[None, :] * cur_lora_d1_stride) |
| 137 | + |
| 138 | + # Compute the block matrix product. |
| 139 | + SPLIT_K = 1 |
| 140 | + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, |
| 141 | + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, |
| 142 | + CAST_TYPE, cur_lora_ptr.dtype.element_ty) |
| 143 | + |
| 144 | + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) |
| 145 | + if SLICE_NUM == 1: |
| 146 | + cur_slice_start = slice_start_loc |
| 147 | + else: |
| 148 | + cur_slice_start = tl.load(slice_start_loc + slice_id) |
| 149 | + |
| 150 | + # Identify the C output pointers to store the results of the accumulator. |
| 151 | + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start |
| 152 | + offset_cm = tl.arange(0, BLOCK_M) |
| 153 | + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + |
| 154 | + offset_cn[None, :] * output_d1_stride) |
| 155 | + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] |
| 156 | + < (cur_slice_start + N)) |
| 157 | + |
| 158 | + if ADD_INPUTS: |
| 159 | + tiled_out = tl.load(c_ptr, mask=c_mask) |
| 160 | + tiled_c += tiled_out |
| 161 | + tl.store(c_ptr, tiled_c, mask=c_mask) |
| 162 | + |
| 163 | + |
| 164 | +@triton.jit |
| 165 | +def do_shrink_kernel( |
| 166 | + pid_n, |
| 167 | + pid_sk, |
| 168 | + slice_id, |
| 169 | + lora_index, |
| 170 | + input_ptr, |
| 171 | + lora_ptr, |
| 172 | + out_ptr, |
| 173 | + N, |
| 174 | + K, |
| 175 | + M_LEN, |
| 176 | + ram, |
| 177 | + # input strides |
| 178 | + input_d0_stride, |
| 179 | + input_d1_stride, |
| 180 | + # lora strides |
| 181 | + lora_d0_stride, |
| 182 | + lora_d1_stride, |
| 183 | + lora_d2_stride, |
| 184 | + # output strides |
| 185 | + output_d0_stride, |
| 186 | + output_d1_stride, |
| 187 | + output_d2_stride, |
| 188 | + scaling, |
| 189 | + BLOCK_M: tl.constexpr, |
| 190 | + BLOCK_N: tl.constexpr, |
| 191 | + BLOCK_K: tl.constexpr, |
| 192 | + EVEN_K: tl.constexpr, |
| 193 | + SPLIT_K: tl.constexpr, |
| 194 | + SLICE_NUM: tl.constexpr, |
| 195 | +): |
| 196 | + """ |
| 197 | + Given an array of integers that identifies the rows of A, ram, |
| 198 | + a lora index that identifies which LoRA to use from lora_ptr, lora_index, |
| 199 | + a slice_id that identifies the input/output slice, compute the |
| 200 | + matrix product and store in the appropriate output location. |
| 201 | + """ |
| 202 | + |
| 203 | + # Identify the lora_ptr from slice_id. |
| 204 | + if SLICE_NUM == 1: |
| 205 | + # current lora ptr |
| 206 | + cur_lora_ptr = lora_ptr |
| 207 | + else: |
| 208 | + # current lora ptr |
| 209 | + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( |
| 210 | + tl.pointer_type(input_ptr.dtype.element_ty)) |
| 211 | + |
| 212 | + # Identify the column indices of B to process. |
| 213 | + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N |
| 214 | + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) |
| 215 | + |
| 216 | + # Identify A and B block pointers |
| 217 | + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) |
| 218 | + a_ptr = (input_ptr + ram[:, None] * input_d0_stride + |
| 219 | + offset_k[None, :] * input_d1_stride) |
| 220 | + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + |
| 221 | + rbn[None, :] * lora_d1_stride + |
| 222 | + offset_k[:, None] * lora_d2_stride) |
| 223 | + |
| 224 | + # Compute partial/complete block matrix product. |
| 225 | + accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, |
| 226 | + K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, |
| 227 | + cur_lora_ptr.dtype.element_ty) |
| 228 | + |
| 229 | + # Identify the C output pointers to store the results of the accumulator. |
| 230 | + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N |
| 231 | + offset_cm = tl.arange(0, BLOCK_M) |
| 232 | + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + |
| 233 | + slice_id * output_d0_stride) |
| 234 | + c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ |
| 235 | + None, :] * output_d2_stride |
| 236 | + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) |
| 237 | + |
| 238 | + accumulator *= scaling |
| 239 | + # handles write-back with reduction-splitting |
| 240 | + if SPLIT_K == 1: |
| 241 | + tl.store(c_ptr, accumulator, mask=c_mask) |
| 242 | + else: |
| 243 | + tl.atomic_add(c_ptr, accumulator, mask=c_mask) |
0 commit comments