@@ -28,12 +28,25 @@ class QEFFGrok1CustomRMSNormAIC(nn.Module):
     """
 
     def forward(self, hidden_states):
+        """
+        Forward pass of the RMSNorm module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor to be normalized.
+
+        Returns:
+            torch.Tensor: Normalized tensor.
+        """
         return CustomRMSNormFunc.apply(
             hidden_states, self.scale, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps
         )
 
 
 class QEffGrok1MultiHeadAttention(nn.Module):
+    """
+    Multi-head attention module.
+    """
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -46,6 +59,22 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Forward pass of the multi-head attention module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+            layer_idx (int): Layer index.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_value (Optional[Tuple[torch.Tensor]], optional): Past key value. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            output_attentions (bool, optional): Whether to output attentions. Defaults to False.
+            use_cache (bool, optional): Whether to use cache. Defaults to False.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: Attention output, attention weights, and past key value.
+        """
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -101,7 +130,20 @@ def forward(
 
 
 class QEffGrok1MoeBlock(nn.Module):
+    """
+    Mixture of experts (MoE) block.
+    """
+
     def forward(self, hidden_states: torch.Tensor):
+        """
+        Forward pass of the MoE block.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: MoE output.
+        """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
@@ -116,8 +158,8 @@ def forward(self, hidden_states: torch.Tensor):
             torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.num_experts).bool().T.unsqueeze(-1)
         )
 
-        gateupout1 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
-        gateupout2 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        gateupout1 = torch.zeros(hidden_states.shape[0], self.ffn_dim)  # T, hs
+        gateupout2 = torch.zeros(hidden_states.shape[0], self.ffn_dim)  # T, hs
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
             current_expert_output = expert_layer.act_fn(expert_layer.linear(hidden_states)) * expert_layer.linear_v(
@@ -150,6 +192,16 @@ def forward(self, hidden_states: torch.Tensor):
 
 
 class QEffGrok1DecoderLayer(nn.Module):
+    """
+    Decoder block of the Grok1 model.
+    """
+
+    def __qeff_init__(self):
+        """
+        Assign extra arguments to the MoE block of the decoder.
+        """
+        self.moe_block.ffn_dim = self.config.intermediate_size
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -162,6 +214,22 @@ def forward(
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Forward pass of the decoder layer.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_value (Optional[Tuple[torch.Tensor]], optional): Past key value. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to False.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to False.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
+
+        Returns:
+            Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Decoder output, attention weights, and past key value.
+        """
         residual = hidden_states
         hidden_states = self.pre_attn_norm(hidden_states)
         hidden_states, attention_weights, present_key_value = self.attn(
@@ -194,9 +262,17 @@ def forward(
 
 
 class QEffGrok1Model(nn.Module):
+    """
+    Grok1 model.
+    """
+
     def __qeff_init__(self):
+        """
+        Initialize extra arguments on the model.
+        """
         for idx, layer in enumerate(self.layers):
             layer.layer_idx = idx
+            layer.config = self.config
 
     def forward(
         self,
@@ -212,6 +288,24 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
+        """
+        Forward pass of the Grok1 model.
+        Args:
+            input_ids (torch.LongTensor, optional): Input ids. Defaults to None.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_values (Optional[List[torch.FloatTensor]], optional): Past key values. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            inputs_embeds (Optional[torch.FloatTensor], optional): Input embeddings. Defaults to None.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to None.
+            output_hidden_states (Optional[bool], optional): Whether to output hidden states. Defaults to None.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to None.
+            return_dict (Optional[bool], optional): Whether to return a dictionary. Defaults to None.
+
+        Returns:
+            Union[Tuple, MoeModelOutputWithPast]: Model output.
+        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -294,6 +388,10 @@ def forward(
 
 
 class QEffGrok1ModelForCausalLM(nn.Module):
+    """
+    Grok1 model for causal language modeling.
+    """
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -310,6 +408,26 @@ def forward(
         return_dict: Optional[bool] = None,
         **kwargs,
     ):
+        """
+        Forward pass of the Grok1 model for causal language modeling.
+
+        Args:
+            input_ids (torch.LongTensor, optional): Input ids. Defaults to None.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            position_ids (Optional[torch.LongTensor], optional): Position ids. Defaults to None.
+            past_key_values (Optional[List[torch.FloatTensor]], optional): Past key values. Defaults to None.
+            batch_index (Optional[torch.LongTensor], optional): Batch index. Defaults to None.
+            inputs_embeds (Optional[torch.FloatTensor], optional): Input embeddings. Defaults to None.
+            labels (Optional[torch.LongTensor], optional): Labels. Defaults to None.
+            use_cache (Optional[bool], optional): Whether to use cache. Defaults to None.
+            output_attentions (Optional[bool], optional): Whether to output attentions. Defaults to None.
+            output_hidden_states (Optional[bool], optional): Whether to output hidden states. Defaults to None.
+            output_router_logits (Optional[bool], optional): Whether to output router logits. Defaults to None.
+            return_dict (Optional[bool], optional): Whether to return a dictionary. Defaults to None.
+
+        Returns:
+            MoeCausalLMOutputWithPast: Model output.
+        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
             output_router_logits if output_router_logits is not None else self.config.output_router_logits
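
For context, here is a minimal, self-contained sketch (not part of the patch) of how the pieces above fit together: `QEffGrok1DecoderLayer.__qeff_init__` copies `config.intermediate_size` onto the MoE block as `ffn_dim`, which is what allows the expert buffers in `QEffGrok1MoeBlock.forward` to be sized from the config instead of the previously hardcoded 32768, and the per-expert masks use the same `one_hot(...).bool().T.unsqueeze(-1)` pattern shown in the diff. The softmax/top-k routing step and all sizes below are illustrative assumptions, not taken from the diff.

```python
import torch
import torch.nn.functional as F

# Toy sizes -- stand-ins for the real config values.
num_experts, top_k = 8, 2
intermediate_size = 256          # plays the role of config.intermediate_size / ffn_dim
tokens, hidden_dim = 4, 64       # T, hs

hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)   # stand-in for self.gate(hidden_states)

# Assumed standard top-2 routing; the diff only shows selected_experts being consumed.
routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# Boolean masks for each token's first and second expert, shaped (num_experts, T, 1),
# built like the one_hot(...).bool().T.unsqueeze(-1) lines in the MoE block.
expert_mask1 = F.one_hot(selected_experts[:, 0], num_classes=num_experts).bool().T.unsqueeze(-1)
expert_mask2 = F.one_hot(selected_experts[:, 1], num_classes=num_experts).bool().T.unsqueeze(-1)

# With the fix in this diff, these accumulators track the model's FFN width
# (ffn_dim = config.intermediate_size) rather than a hardcoded 32768.
gateupout1 = torch.zeros(hidden_states.shape[0], intermediate_size)  # (T, ffn_dim)
gateupout2 = torch.zeros(hidden_states.shape[0], intermediate_size)  # (T, ffn_dim)

print(expert_mask1.shape, gateupout1.shape)  # torch.Size([8, 4, 1]) torch.Size([4, 256])
```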
0 commit comments