
Does SpinQuant implement R3 when using a quantized KV cache? #9705

Description

@WeiMa01

When we run SpinQuant through ExecuTorch, we observe that only R4 is applied as an online rotation, while R3 is not. We would like to confirm whether ExecuTorch supports R3 for SpinQuant when the KV cache is quantized.
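
For context on what "online rotation" means here: R3 and R4 are orthogonal Hadamard rotations that have to be applied at runtime rather than folded into the weights. Just to illustrate what an op like torch.ops.llama.fast_hadamard_transform computes, here is a plain-PyTorch sketch for reference (not ExecuTorch's kernel; it assumes the transformed dimension is a power of two):

    import torch

    def fwht(x: torch.Tensor) -> torch.Tensor:
        """Orthonormal fast Walsh-Hadamard transform over the last dimension."""
        n = x.shape[-1]
        assert n & (n - 1) == 0, "last dim must be a power of two"
        y = x.reshape(-1, n).clone()
        h = 1
        while h < n:
            # Butterfly step: combine elements that are h apart.
            y = y.reshape(-1, n // (2 * h), 2, h)
            a, b = y[:, :, 0, :], y[:, :, 1, :]
            y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
            h *= 2
        # The 1/sqrt(n) factor makes the transform orthonormal, i.e. a rotation.
        return (y / n ** 0.5).reshape(x.shape)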

  1. Convert to a .pte, with --quantize_kv_cache already enabled:
    python -m examples.models.llama.export_llama \
    --model "llama3_2" \
    --checkpoint "/home/zhuan.zhang/llama_models/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8/consolidated.00.pth" \
    --params "/home/zhuan.zhang/llama_models/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8/params.json" \
    --use_sdpa_with_kv_cache \
    -X \
    --xnnpack-extended-ops \
    --preq_mode 8da4w_output_8da8w \
    --preq_group_size 32 \
    --max_seq_length 2048 \
    -kv \
    -d fp32 \
    --preq_embedding_quantize 8,0 \
    --quantize_kv_cache \
    --output_name 'llama3_2_spinquant_qkv.pte' \
    --use_spin_quant native \
    --generate_etrecord

  2. At runtime, the delegation report shows "llama_fast_hadamard_transform_default" being called 16 times (once per decoder layer), which is R4; there is no corresponding op for R3 (a sketch of how this table can be dumped is included after step 3):

    | op_type                       | occurrences_in_delegated_graphs | occurrences_in_non_delegated_graphs |
    |-------------------------------|---------------------------------|-------------------------------------|
    | llama_fast_hadamard_transform | 0                               | 16                                  |

  3. The source code shows that, for SpinQuant, the export only replaces FeedForward with FeedForwardNativeCustom via inject_fast_hadamard_transform_native_for_spin_quant, i.e. only R4 is injected (a hypothetical sketch of the missing R3 hook follows after the code):

def _get_source_transforms(  # noqa
    modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    transforms = []
    if args.use_spin_quant:
        if args.use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )
            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
        elif args.use_spin_quant == "native":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_native_for_spin_quant,
            )
            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed forward layer.
    R3 needs to be injected as well when KV cache quantization is enabled.
    """
    class FeedForwardNativeCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3
        def forward(self, x):
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )
    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_native_for_spin_quant(child)
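
For comparison, below is a hypothetical sketch of the kind of hook an R3 injection would need. This is an assumption for illustration, not ExecuTorch source: the rotate_qk_r3 helper and the assumption that per-head queries and keys (with a power-of-two head_dim, e.g. 64 for Llama-3.2-1B) are accessible after RoPE and before the KV cache update are made up; only torch.ops.llama.fast_hadamard_transform is taken from the code above.

    import torch

    def rotate_qk_r3(q: torch.Tensor, k: torch.Tensor):
        # R3: apply the same orthogonal Hadamard rotation to the per-head
        # queries and keys. q @ k^T is unchanged, but the k values written
        # into the (quantized) KV cache get a flatter distribution, which is
        # what R3 is for in SpinQuant.
        q = torch.ops.llama.fast_hadamard_transform(q)
        k = torch.ops.llama.fast_hadamard_transform(k)
        return q, k

If this is the intended behavior, R3 would seem to need its own source transformation on the Attention module (analogous to FeedForwardNativeCustom above), which is what we could not find.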

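For reference on step 2: the per-op table can be dumped from the lowered graph with the devtools delegation helper, roughly as sketched below (assuming edge_manager is the EdgeProgramManager held during export; the exact variable name depends on where in the export flow this is called):

    from executorch.devtools.backend_debug import get_delegation_info
    from tabulate import tabulate

    graph_module = edge_manager.exported_program().graph_module
    delegation_info = get_delegation_info(graph_module)
    print(delegation_info.get_summary())

    # One row per op_type, with occurrences_in_delegated_graphs /
    # occurrences_in_non_delegated_graphs columns as in the table in step 2.
    df = delegation_info.get_operator_delegation_dataframe()
    print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
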
cc @kimishpatel @jerryzh168 @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng

Labels: module: llm, module: quantization