@@ -195,6 +195,10 @@ def scaled_dot_product_attention(
         return (attn_output, attn_weights) if output_attentions else attn_output


+colwise_placements = [dist.Replicate(), dist.Shard(1)]
+rowise_placement = [dist.Replicate(), dist.Shard(0)]
+
+
 class LlamaRMSNormAuto(nn.Layer):
     def __init__(self, config, ipp):
         super().__init__()
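Note: this hunk hoists the column-wise and row-wise placement lists to module scope and always shards along the tensor-parallel mesh axis, dropping the previous per-instance fallback to fully replicated placements (removed in the hunks below). As a hedged sketch of how such placement lists are typically consumed with Paddle's auto-parallel API; the mesh shape, dim names, and layer sizes here are illustrative assumptions, not taken from this diff:

```python
# Sketch only: the mesh layout, dim names, and sizes are assumptions for illustration.
import paddle
import paddle.distributed as dist
from paddle import nn

# A 2D mesh: first axis replicated (e.g. data/pipeline), second axis for tensor parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

colwise_placements = [dist.Replicate(), dist.Shard(1)]  # split weight columns over "mp"
rowise_placement = [dist.Replicate(), dist.Shard(0)]  # split weight rows over "mp"

layer = nn.Linear(1024, 4096, bias_attr=False)
# Mark the weight as column-sharded; auto parallel inserts the needed collectives
# when the layer runs inside a distributed program.
layer.weight = dist.shard_tensor(layer.weight, mesh, colwise_placements)
```

Run such a sketch under a multi-rank launch (e.g. `python -m paddle.distributed.launch`) so the four ranks in the mesh actually exist.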
@@ -237,16 +241,6 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.fuse_attention_ffn = config.fuse_attention_ffn
         self.ipp = ipp
         self.config = config
-        colwise_placements = (
-            [dist.Replicate(), dist.Shard(1)]
-            if self.config.tensor_parallel_degree > 1
-            else [dist.Replicate(), dist.Replicate()]
-        )
-        rowise_placement = (
-            [dist.Replicate(), dist.Shard(0)]
-            if self.config.tensor_parallel_degree > 1
-            else [dist.Replicate(), dist.Replicate()]
-        )

         if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
@@ -316,17 +310,6 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.recompute_granularity = config.recompute_granularity
         self.ipp = ipp

-        colwise_placements = (
-            [dist.Replicate(), dist.Shard(1)]
-            if self.config.tensor_parallel_degree > 1
-            else [dist.Replicate(), dist.Replicate()]
-        )
-        rowise_placement = (
-            [dist.Replicate(), dist.Shard(0)]
-            if self.config.tensor_parallel_degree > 1
-            else [dist.Replicate(), dist.Replicate()]
-        )
-
         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
@@ -1201,10 +1184,23 @@ def forward(self, prediction_scores, masked_lm_labels):
                 masked_lm_labels.unsqueeze(2),
             )

-            # Hack for XPU that doesn't support Allgather yet.
+            # XPU does not support allgather on masks with bool dtype, so we use LocalLayer here.
             if get_env_device() == "xpu":
-                # masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-                loss = paddle.mean(masked_lm_loss, axis=-1)
+
+                class LocalLossLayer(paddle.distributed.LocalLayer):
+                    def __init__(self, out_dist_attrs):
+                        super().__init__(out_dist_attrs)
+
+                    def forward(self, x, mask):
+                        masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                        loss = paddle.mean(masked_lm_loss)
+                        return loss
+
+                out_dist_attrs = [
+                    (masked_lm_loss.process_mesh, [dist.Partial(dist.ReduceType.kRedSum), dist.Replicate()]),
+                ]
+                loss_func = LocalLossLayer(out_dist_attrs)
+                loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
             else:
                 masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
                 loss = paddle.mean(masked_lm_loss)
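Note: because XPU cannot all-gather the bool selection mask, the hunk above wraps `masked_select` + `mean` in a `paddle.distributed.LocalLayer`: the body runs on each rank's local shard of the loss, and `out_dist_attrs` marks the scalar output as `Partial(kRedSum)` over the mesh so the per-rank results are reduced by the framework. A minimal standalone sketch of the same pattern, assuming a hypothetical 2x2 mesh (names and shapes are illustrative, not from this diff):

```python
# Sketch only: the mesh and class name are assumptions; the pattern mirrors the diff above.
import paddle
import paddle.distributed as dist


class LocalMaskedMean(dist.LocalLayer):
    """Compute masked_select + mean on each rank's local shard of the loss tensor."""

    def forward(self, x, mask):
        return paddle.mean(paddle.masked_select(x, mask).astype("float32"))


mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
# The local scalar is declared partial-sum over the mesh, so it is summed across ranks
# the next time a replicated value is needed.
out_dist_attrs = [(mesh, [dist.Partial(dist.ReduceType.kRedSum), dist.Replicate()])]
loss_fn = LocalMaskedMean(out_dist_attrs)
# loss = loss_fn(local_loss, local_loss > 0)
```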
@@ -1216,11 +1212,7 @@ class LlamaLMHeadAuto(nn.Layer):
     def __init__(self, config: LlamaConfig):
         super(LlamaLMHeadAuto, self).__init__()
         self.config = config
-        colwise_placements = (
-            [dist.Replicate(), dist.Shard(1)]
-            if self.config.tensor_parallel_degree > 1
-            else [dist.Replicate(), dist.Replicate()]
-        )
+
         vocab_size = config.vocab_size
         self.weight = self.create_parameter(
             shape=[config.hidden_size, vocab_size],