Commit a5751d0

SwiftKV backup PR (#367)
* SwiftKV support added for continuous batching (CB) as well as non-CB execution
1 parent bd465a2 commit a5751d0

File tree

10 files changed: +656 -5 lines changed

QEfficient/__init__.py

+12 -1
@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
@@ -12,8 +12,19 @@
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

+from transformers import AutoConfig
+
+from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
 from QEfficient.utils.logging_utils import logger

+# loop over all the model types which are not present in transformers and register them
+for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
+    # Register the model config class based on the model type. This will be first element in the tuple
+    AutoConfig.register(model_type, model_cls[0])
+
+    # Register the non transformer library Class and config class using AutoModelClass
+    model_cls[2].register(model_cls[0], model_cls[1])
+

 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""

QEfficient/transformers/cache_utils.py

+77
@@ -36,6 +36,83 @@ class QEffDynamicCache(DynamicCache):

     """

+    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
+        """
+        Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+        """
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            batch_index = cache_kwargs.get("batch_index", None)
+
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+    def read_only(self, layer_idx, cache_kwargs):
+        """
+        Reads the `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        ctx_len = k_out.shape[2]
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
+
     def update(
         self,
         key_states: torch.Tensor,
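
`read_only` gathers the cached keys and values only up to the largest current position and neutralises everything past it; the CtxGather*/CtxScatter* custom ops exist so the same pattern exports cleanly to ONNX. A rough eager-mode sketch of the masking and gather in plain torch, with made-up shapes (illustrative only, not the committed implementation):

import torch

# Toy shapes: (batch, heads, ctx_len, head_dim)
batch, heads, ctx_len, head_dim = 1, 2, 8, 4
k_cache = torch.randn(batch, heads, ctx_len, head_dim)
v_cache = torch.randn(batch, heads, ctx_len, head_dim)
position_ids = torch.tensor([[0, 1, 2, 3]])  # positions of the current query tokens

ctx_indices = torch.arange(ctx_len)[None, None, :]                    # (1, 1, ctx_len)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # (batch, 1, 1)
invalid_mask = ctx_indices > gather_limit                             # True for slots past the filled prefix

# Clamp out-of-range indices to 0 (the ONNX export path uses INT32_MAX instead), then gather.
safe_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)
idx = safe_indices.unsqueeze(-1).expand(batch, heads, ctx_len, head_dim)
k_out = torch.gather(k_cache, 2, idx)
v_out = torch.gather(v_cache, 2, idx)

# Zero the value rows for invalid slots so they cannot contribute to the attention output.
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros((), dtype=v_out.dtype), v_out)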

QEfficient/transformers/modeling_utils.py

+12
@@ -10,6 +10,7 @@

 import torch
 import torch.nn as nn
+from transformers import AutoModelForCausalLM
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
@@ -88,6 +89,12 @@

 from QEfficient.customop import CustomRMSNormAIC

+# Placeholder for all non-transformer models
+from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
+    QEffLlamaSwiftKVConfig,
+    QEffLlamaSwiftKVForCausalLM,
+)
+
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -271,6 +278,11 @@
     WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
 }

+# Map of model type to config class, Modelling class and transformer model architecture class
+MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
+    "llama_swiftkv": [QEffLlamaSwiftKVConfig, QEffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
+}
+

 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
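
Each map entry carries the config class, the QEfficient modeling class, and the transformers auto class to register them with. A short sketch of how the registration loop in QEfficient/__init__.py consumes one entry, unrolled for clarity (same names as in this diff):

from transformers import AutoConfig

from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS

config_cls, modeling_cls, arch_cls = MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS["llama_swiftkv"]

# model_cls[0]: register the config class under the new model type string.
AutoConfig.register("llama_swiftkv", config_cls)

# model_cls[2].register(model_cls[0], model_cls[1]): teach AutoModelForCausalLM to
# build the QEfficient modeling class whenever it sees that config class.
arch_cls.register(config_cls, modeling_cls)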
(new file)

+6

@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
