Commit a33fb48

Addressed comments and added support for the external model using Infer
Signed-off-by: Amit Raj <[email protected]>
1 parent a74b9a7 commit a33fb48

7 files changed (+136, -8 lines)

QEfficient/base/common.py

Lines changed: 8 additions & 3 deletions
@@ -18,7 +18,7 @@
 from transformers import AutoConfig

 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
+from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING, MODEL_CLASS_MAPPING
 from QEfficient.utils import login_and_download_hf_lm


@@ -40,9 +40,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         """
         Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
         """
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

-        class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
+        # class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None) OR MODEL_EXTERNAL_CLASS_MAPPING(config.__class__.__name__)
+        class_name = (
+            MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
+            or EXTERNAL_MODEL_CLASS_MAPPING[config.__class__.__name__]
+        )
         if class_name:
             module = __import__("QEfficient.transformers.models.modeling_auto")
             model_class = getattr(module, class_name)
@@ -61,6 +65,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
             pretrained_model_name_or_path=(local_model_dir if local_model_dir else pretrained_model_name_or_path),
             token=hf_token,
             continuous_batching=continuous_batching,
+            trust_remote_code=True,
             **kwargs,
         )
         return qeff_model
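Taken together, these hunks make `from_pretrained` fall back to the external mapping when the architecture is not registered in-tree. A minimal sketch of that lookup, with each mapping reduced to one illustrative entry (the `LlamaConfig` entry is hypothetical here): note that the fallback indexes with `[...]` rather than `.get(...)`, so an architecture missing from both mappings raises `KeyError` before reaching the `if class_name:` guard.

# Minimal sketch of the two-step lookup above; mapping contents are
# illustrative, not the full tables from modeling_utils.py.
MODEL_CLASS_MAPPING = {"LlamaConfig": "QEFFAutoModelForCausalLM"}
EXTERNAL_MODEL_CLASS_MAPPING = {"Grok1Config": "QEFFAutoModelForCausalLM"}


def resolve_class_name(config_class_name: str) -> str:
    # In-tree architectures win; external (trust_remote_code) models fall
    # back to the external mapping. The [] indexing raises KeyError for an
    # architecture found in neither mapping, rather than returning None.
    return (
        MODEL_CLASS_MAPPING.get(config_class_name, None)
        or EXTERNAL_MODEL_CLASS_MAPPING[config_class_name]
    )


assert resolve_class_name("Grok1Config") == "QEFFAutoModelForCausalLM"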

QEfficient/base/modeling_qeff.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def __init__(self, model: torch.nn.Module) -> None:
         self.onnx_path: Optional[str] = None
         self.qpc_path: Optional[str] = None
         self.qpc_session: Optional[QAICInferenceSession] = None
-
+        model = model.to(torch.float32)
         # Apply the transformations
         any_transformed = False
         for transform in self._pytorch_transforms:
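The one functional change here upcasts the wrapped model to float32 before the PyTorch transforms run. A minimal illustration of the effect; the rationale suggested here (external checkpoints such as Grok-1 may load in bf16/fp16, and a uniform fp32 starting point keeps the transform and export path consistent) is an inference, not stated by the commit.

import torch

# Hypothetical stand-in module; external checkpoints often load in bf16/fp16.
model = torch.nn.Linear(8, 8).to(torch.bfloat16)

# The line added in __init__ above: cast every parameter and buffer to fp32.
model = model.to(torch.float32)

assert all(p.dtype == torch.float32 for p in model.parameters())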

QEfficient/transformers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -283,6 +283,7 @@ def build_model_class_mapping(auto_model_class, qeff_class_name):
 }


+EXTERNAL_MODEL_CLASS_MAPPING = {"Grok1Config": "QEFFAutoModelForCausalLM"}
 MODEL_CLASS_MAPPING = {
     **build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM"),
     **build_model_class_mapping(mapping.AutoModelForImageTextToText, "QEFFAutoModelForImageTextToText"),

QEfficient/transformers/models/grok_1/modeling_grok1.py

Lines changed: 120 additions & 2 deletions
@@ -28,12 +28,25 @@ class QEFFGrok1CustomRMSNormAIC(nn.Module):
     """

     def forward(self, hidden_states):
+        """
+        Forward pass of the RMSNorm module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor to be normalized.
+
+        Returns:
+            torch.Tensor: Normalized tensor.
+        """
         return CustomRMSNormFunc.apply(
             hidden_states, self.scale, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps
         )


 class QEffGrok1MultiHeadAttention(nn.Module):
+    """
+    Multi-head attention module.
+    """
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -46,6 +59,22 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Forward pass of the multi-head attention module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+            layer_idx (int): Layer index.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_value (Optional[Tuple[torch.Tensor]], optional): Past key value. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            output_attentions (bool, optional): Whether to output attentions. Defaults to False.
+            use_cache (bool, optional): Whether to use cache. Defaults to False.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: Attention output, attention weights, and past key value.
+        """
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -101,7 +130,20 @@ def forward(


 class QEffGrok1MoeBlock(nn.Module):
+    """
+    Mixture-of-experts (MoE) block.
+    """
+
     def forward(self, hidden_states: torch.Tensor):
+        """
+        Forward pass of the MoE block.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: MoE output.
+        """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
@@ -116,8 +158,8 @@ def forward(self, hidden_states: torch.Tensor):
             torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.num_experts).bool().T.unsqueeze(-1)
         )

-        gateupout1 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
-        gateupout2 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        gateupout1 = torch.zeros(hidden_states.shape[0], self.ffn_dim)  # T, hs
+        gateupout2 = torch.zeros(hidden_states.shape[0], self.ffn_dim)  # T, hs
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
             current_expert_output = expert_layer.act_fn(expert_layer.linear(hidden_states)) * expert_layer.linear_v(
@@ -150,6 +192,16 @@ def forward(self, hidden_states: torch.Tensor):


 class QEffGrok1DecoderLayer(nn.Module):
+    """
+    Decoder block of the Grok1 model.
+    """
+
+    def __qeff_init__(self):
+        """
+        Assign extra args to the decoder's MoE block.
+        """
+        self.moe_block.ffn_dim = self.config.intermediate_size
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -162,6 +214,22 @@
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Forward pass of the decoder layer.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_value (Optional[Tuple[torch.Tensor]], optional): Past key value. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to False.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to False.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
+
+        Returns:
+            Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Decoder output, attention weights, and past key value.
+        """
         residual = hidden_states
         hidden_states = self.pre_attn_norm(hidden_states)
         hidden_states, attention_weights, present_key_value = self.attn(
@@ -194,9 +262,17 @@ def forward(


 class QEffGrok1Model(nn.Module):
+    """
+    Grok1 model.
+    """
+
     def __qeff_init__(self):
+        """
+        Initialize extra args on the model.
+        """
         for idx, layer in enumerate(self.layers):
             layer.layer_idx = idx
+            layer.config = self.config

     def forward(
         self,
@@ -212,6 +288,24 @@
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
+        """
+        Forward pass of the Grok1 model.
+
+        Args:
+            input_ids (torch.LongTensor, optional): Input ids. Defaults to None.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_values (Optional[List[torch.FloatTensor]], optional): Past key values. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            inputs_embeds (Optional[torch.FloatTensor], optional): Input embeddings. Defaults to None.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to None.
+            output_hidden_states (Optional[bool], optional): Whether to output hidden states. Defaults to None.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to None.
+            return_dict (Optional[bool], optional): Whether to return a dictionary. Defaults to None.
+
+        Returns:
+            Union[Tuple, MoeModelOutputWithPast]: Model output.
+        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -294,6 +388,10 @@ def forward(


 class QEffGrok1ModelForCausalLM(nn.Module):
+    """
+    Grok1 model for causal language modeling.
+    """
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -310,6 +408,26 @@
         return_dict: Optional[bool] = None,
         **kwargs,
     ):
+        """
+        Forward pass of the Grok1 model for causal language modeling.
+
+        Args:
+            input_ids (torch.LongTensor, optional): Input ids. Defaults to None.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_values (Optional[List[torch.FloatTensor]], optional): Past key values. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            inputs_embeds (Optional[torch.FloatTensor], optional): Input embeddings. Defaults to None.
+            labels (Optional[torch.LongTensor], optional): Labels. Defaults to None.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to None.
+            output_hidden_states (Optional[bool], optional): Whether to output hidden states. Defaults to None.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to None.
+            return_dict (Optional[bool], optional): Whether to return a dictionary. Defaults to None.
+
+        Returns:
+            MoeCausalLMOutputWithPast: Model output.
+        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
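Two of these hunks cooperate: `QEffGrok1Model.__qeff_init__` hands each decoder layer its `config`, and `QEffGrok1DecoderLayer.__qeff_init__` copies `config.intermediate_size` onto the MoE block as `ffn_dim`, replacing the hard-coded 32768 in the expert accumulators. A minimal sketch of that wiring, using hypothetical stand-in classes:

import torch


class MoeBlockSketch:
    """Stand-in for the MoE block; ffn_dim is attached after load."""


class DecoderLayerSketch:
    """Stand-in for a decoder layer patched with __qeff_init__."""

    def __init__(self):
        self.moe_block = MoeBlockSketch()

    def __qeff_init__(self):
        # Mirrors QEffGrok1DecoderLayer.__qeff_init__ above; the model-level
        # hook has already assigned self.config to each layer.
        self.moe_block.ffn_dim = self.config.intermediate_size


layer = DecoderLayerSketch()
layer.config = type("Cfg", (), {"intermediate_size": 32768})()  # done by the model hook
layer.__qeff_init__()

# The expert accumulators are now sized from the config instead of a constant.
gateupout1 = torch.zeros(4, layer.moe_block.ffn_dim)
assert gateupout1.shape[1] == 32768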

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 4 additions & 1 deletion
@@ -499,7 +499,10 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
             "forward": QEffGrok1Model.forward,
             "__qeff_init__": QEffGrok1Model.__qeff_init__,
         },
-        "DecoderLayer": {"forward": QEffGrok1DecoderLayer.forward},
+        "DecoderLayer": {
+            "forward": QEffGrok1DecoderLayer.forward,
+            "__qeff_init__": QEffGrok1DecoderLayer.__qeff_init__,
+        },
         "MoeBlock": {"forward": QEffGrok1MoeBlock.forward},
         "MultiHeadAttention": {
             "forward": QEffGrok1MultiHeadAttention.forward,

README.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 ---

 *Latest news* :fire: <br>
+- [06/2025] Added support for the model [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)
 - [03/2025] Added support for swiftkv model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)
 - [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
 - [01/2025] [FP8 models support](https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127) Added support for inference of FP8 models.

docs/source/validate.md

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
 | **QwenForCausalLM** | DeepSeek-R1-Distill-Qwen | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | ✔️ |
 | | Qwen2, Qwen2.5 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✔️ |
 | **LlamaSwiftKVForCausalLM** | swiftkv | [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | ✔️ |
-
+| **Grok1ModelForCausalLM** | grok-1 | [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1) | ✔️ |
 ## Embedding Models

 ### Text Embedding Task
