In order to support sub-byte dtypes for quantization, I (and many others) believe it is better to pack these smaller dtypes into existing PyTorch dtypes, trading a bit of extra computation for reduced memory bandwidth pressure. Here is a preliminary algorithm in PyTorch for doing this. It supports many types of conversions, as shown in the tests.
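For readers skimming the thread, here is a minimal sketch of the kind of packing being described, assuming 2-bit values stored four-per-byte in torch.uint8 (the helper names are illustrative, not the actual implementation in this issue). The shift amounts (6, 4, 2, 0) mirror the generated Triton kernel shown further down.

import torch

# Illustrative sketch only: pack 2-bit values four-per-byte into torch.uint8
# and unpack them again.
def pack_2bit(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor with values in [0, 3]; last dim must be divisible by 4
    x = x.reshape(*x.shape[:-1], x.shape[-1] // 4, 4)
    shifts = torch.tensor([6, 4, 2, 0], dtype=torch.uint8, device=x.device)
    return (x << shifts).sum(dim=-1).to(torch.uint8)

def unpack_2bit(packed: torch.Tensor) -> torch.Tensor:
    shifts = torch.tensor([6, 4, 2, 0], dtype=torch.uint8, device=packed.device)
    out = (packed.unsqueeze(-1) >> shifts) & 3
    return out.reshape(*packed.shape[:-1], packed.shape[-1] * 4)

x = torch.randint(0, 4, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_2bit(pack_2bit(x)), x)

Packing four 2-bit values per byte cuts memory traffic 4x at the cost of a few shifts and masks at load time, which is the tradeoff described above.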
Inspecting the compiled Triton code looks promising, because it launches only one kernel and allocates only one buffer. Here is a snippet:
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4
    x1 = (xindex // 4)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.full([1], 6, tl.uint8)
    tmp7 = tmp5 >> tmp6
    tmp8 = tl.full([1], 3, tl.uint8)
    tmp9 = tmp7 & tmp8
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp4, tmp9, tmp10)
    tmp12 = tmp0 >= tmp3
    tmp13 = tl.full([1], 2, tl.int64)
    tmp14 = tmp0 < tmp13
    tmp15 = tmp12 & tmp14
    tmp16 = tl.load(in_ptr0 + (x1), tmp15 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tl.full([1], 4, tl.uint8)
    tmp18 = tmp16 >> tmp17
    tmp19 = tmp18 & tmp8
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tmp0 >= tmp13
    tmp23 = tl.full([1], 3, tl.int64)
    tmp24 = tmp0 < tmp23
    tmp25 = tmp22 & tmp24
    tmp26 = tl.load(in_ptr0 + (x1), tmp25 & xmask, eviction_policy='evict_last', other=0.0)
    tmp27 = tl.full([1], 2, tl.uint8)
    tmp28 = tmp26 >> tmp27
    tmp29 = tmp28 & tmp8
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp25, tmp29, tmp30)
    tmp32 = tmp0 >= tmp23
    tmp33 = tl.full([1], 4, tl.int64)
    tmp34 = tmp0 < tmp33
    tmp35 = tl.load(in_ptr0 + (x1), tmp32 & xmask, eviction_policy='evict_last', other=0.0)
    tmp36 = tl.full([1], 0, tl.uint8)
    tmp37 = tmp35 >> tmp36
    tmp38 = tmp37 & tmp8
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp32, tmp38, tmp39)
    tmp41 = tl.where(tmp25, tmp31, tmp40)
    tmp42 = tl.where(tmp15, tmp21, tmp41)
    tmp43 = tl.where(tmp4, tmp11, tmp42)
    tl.store(out_ptr0 + (x2), tmp43, xmask)
''', device_str='cuda')
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2), (s1*s2, s2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1, s2, 4), (4*s1*s2, 4*s2, 4, 1), torch.uint8)
        # Source Nodes: [stack], Original ATen: [aten.stack]
        triton_poi_fused_stack_0_xnumel = 4*s0*s1*s2
        stream0 = get_raw_stream(0)
        triton_poi_fused_stack_0.run(arg3_1, buf0, triton_poi_fused_stack_0_xnumel, grid=grid(triton_poi_fused_stack_0_xnumel), stream=stream0)
        del arg3_1
    return (reinterpret_tensor(buf0, (s0, s1, 4*s2), (4*s1*s2, 4*s2, 1), 0), )
msaroufim commented on May 26, 2024
This is quite cool and I've been thinking along similar lines.
I think what we could do to ship this is:
- Merge the pack and unpack functions into quantization/ and then have tests to ensure the codegen is efficient. In practice you can test that a single kernel is launched by doing torch.compile(..., fullgraph=True) in your tests (a rough sketch follows below). I'm not sure how we can validate that a single buffer is used, but perhaps @eellison does.
- This can be a baseline for smaller dtypes. I'd be specific somewhere in the function names or docs that this is padding-based, because conceptually I can imagine another alternative where, instead of wasting space, you pack 8 uint3 values into 3 uint8 as a more general algorithm. But that's finicky enough that we don't have to worry about it right now.
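A rough sketch of what such a test could look like, reusing the hypothetical pack_2bit/unpack_2bit helpers sketched at the top of the issue (not the actual test in the repo):

import torch
# pack_2bit / unpack_2bit refer to the illustrative helpers sketched in the
# issue description above; substitute the real API once it lands.

def test_pack_unpack_compiles_fullgraph():
    # fullgraph=True makes torch.compile raise if it hits a graph break, so a
    # passing test means the round trip stays in a single graph. Whether
    # inductor fuses it into one kernel can be checked by inspecting the
    # generated code (e.g. with TORCH_LOGS=output_code). Requires a GPU.
    x = torch.randint(0, 4, (8, 16), dtype=torch.uint8, device="cuda")
    packed = torch.compile(pack_2bit, fullgraph=True)(x)
    unpacked = torch.compile(unpack_2bit, fullgraph=True)(packed)
    assert torch.equal(unpacked, x)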
msaroufim commented on May 27, 2024
Also, @mobicham has been asking us to standardize the bitpacking logic, so I'm curious about his thoughts too.
mobicham commented on May 27, 2024
Thanks @vayuda, very interesting, thanks for sharing!
Normally, bit-unpacking is almost never used in isolation; it is either fused into a dequant kernel or into a low-bit matmul kernel. There are two main things to consider when designing the bitpacking logic:
1. The axis along which quantization is performed: if you quantize along axis=0 and you bitpack along the same axis, the scale/zero can be accessed only once per group. However, if you quantize along axis=1 and you bitpack along axis=0, you'll have to access the scale/zero more than once, which makes dequantization slower.
Here are two Triton dequant kernels I wrote for both cases; you can see that in the second one I had to access the zero/scale twice for 4-bit with axis=1, and it would be even worse for lower bit widths:
axis=0: https://github.com/mobiusml/hqq/blob/triton/hqq/kernels/triton/dequant.py#L33-L39
axis=1: https://github.com/mobiusml/hqq/blob/triton/hqq/kernels/triton/dequant.py#L65-L71
Since most methods quantize along axis=1, it would make sense to have bitpacking logic that is optimized for that case (see the sketch after this list).
2. The memory access pattern should be taken into account: if someone writes an optimized fused CUDA or Triton kernel, the bitpacking should be structured in a way that can fully take advantage of tensor cores.
@jeromeku suggested using interleaved access. Here's a 4-bit bitpacking example using that logic: https://github.com/mobiusml/hqq/blob/triton/hqq/kernels/triton/benchmark.py#L28-L35
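To make the first point concrete, here is a rough PyTorch sketch (assumed names and group size, not code from hqq or torchao) of group-wise 4-bit quantization along axis=1 where packing is done along the same axis, so the two nibbles in each byte always belong to the same group and dequantization only needs to touch the scale/zero once per group:

import torch

# Illustrative only: group-wise 4-bit quantization along the last axis, with
# packing along that same axis so each byte stays inside one group.
def quantize_pack_4bit(w: torch.Tensor, group_size: int = 64):
    rows, cols = w.shape
    g = w.reshape(rows, cols // group_size, group_size)
    w_min, w_max = g.amin(dim=-1, keepdim=True), g.amax(dim=-1, keepdim=True)
    scale = ((w_max - w_min) / 15).clamp(min=1e-8)
    q = ((g - w_min) / scale).round().clamp(0, 15).to(torch.uint8)
    packed = (q[..., 0::2] << 4) | q[..., 1::2]  # neighbours in a group share a byte
    return packed.reshape(rows, cols // 2), scale, w_min

def dequantize_4bit(packed: torch.Tensor, scale, zero, group_size: int = 64):
    rows = packed.shape[0]
    p = packed.reshape(rows, -1, group_size // 2)
    q = torch.empty(rows, p.shape[1], group_size, dtype=torch.uint8, device=packed.device)
    q[..., 0::2], q[..., 1::2] = p >> 4, p & 0xF
    # scale/zero are indexed once per group, not once per packed element
    return (q.to(scale.dtype) * scale + zero).reshape(rows, -1)

w = torch.randn(4, 128)
packed, scale, zero = quantize_pack_4bit(w)
assert dequantize_4bit(packed, scale, zero).shape == w.shape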
@msaroufim do you know by any chance what kind of bitpacking logic is used in tiny_gemm?
vayuda commented on May 27, 2024
@mobicham Thanks for the input. The interleaved access is interesting, though I'm not really sure what it means to fully take advantage of tensor cores. I think this is something we can iterate on. For now, I can create a version that does row-wise pack/unpack.
As per @msaroufim's suggestions, I will place these functions in the API file and write appropriate tests.
vadimkantorov commented on Jun 27, 2024
Even in relative isolation (without op support), bit packing/unpacking is still useful for reducing the memory footprint when storing bool tensors / masks / bitsets.
But of course, more op support is needed for compressed bool tensors / bit tensors / bitsets as well...
(Similarly, for some other use cases it is still useful even when packing/unpacking is not fused into ops: where the bottleneck is actually memory efficiency, the speed overhead can be tolerated.)
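As a side note, here is a minimal sketch of what packing a bool mask into a uint8 bitset can look like in plain PyTorch (there is no built-in torch.packbits today, so these helper names are made up):

import torch

# Illustrative helpers: pack a flat bool mask into a uint8 bitset (8 bools per
# byte, LSB-first) and unpack it again.
def pack_bool(mask: torch.Tensor) -> torch.Tensor:
    bits = mask.to(torch.uint8).reshape(-1, 8)  # length must be a multiple of 8
    shifts = torch.arange(8, dtype=torch.uint8, device=mask.device)
    return (bits << shifts).sum(dim=-1).to(torch.uint8)

def unpack_bool(packed: torch.Tensor) -> torch.Tensor:
    shifts = torch.arange(8, dtype=torch.uint8, device=packed.device)
    return ((packed.unsqueeze(-1) >> shifts) & 1).bool().reshape(-1)

m = torch.rand(64) > 0.5
assert torch.equal(unpack_bool(pack_bool(m)), m)

This already gives the 8x footprint reduction for masks, even before any op support exists for the packed representation.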