@@ -212,6 +212,34 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list


+class LMHeadFunction(paddle.autograd.PyLayer):
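+    """matmul in forward; backward fuses the weight gradient into weight.main_grad and returns only dx."""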
+    @staticmethod
+    def forward(ctx, x, weight, transpose_y):
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)
+
+        ctx.save_for_backward(x, weight, transpose_y)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
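+        # an fp32 upstream gradient is downcast to bf16 before the gradient matmuls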
+        if dout.dtype == paddle.float32:
+            dout = dout.cast(paddle.bfloat16)
+
+        x, weight, transpose_y = ctx.saved_tensor()
+
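+        # input gradient: matmul against the weight with the transpose flag flipped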
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
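+        # the fused op accumulates the weight gradient directly into weight.main_grad in place,
+        # so no weight gradient is returned below; the operand order depends on transpose_y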
+        if transpose_y:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    dout.reshape([-1, dout.shape[-1]]), x.reshape([-1, x.shape[-1]]), weight.main_grad, None, True, False
+                )
+        else:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
+                )
+        return dx, None
+
 def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
@@ -238,10 +266,9 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_out
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

     else:
-        logits = paddle.matmul(x, y, transpose_y=transpose_y)
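+        # non tensor-parallel path: the custom PyLayer fuses the weight grad into main_grad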
+        logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
         return logits

-
 def scaled_dot_product_attention(
     query_states,
     config,
@@ -2469,8 +2496,9 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
-
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+
+        concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
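+        # project with eh_proj's weight through the same fused-grad PyLayer as the lm_head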
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)

         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,