@@ -1077,7 +1077,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
10771077 # get dispatch backward event
10781078 dispatch_backward_event = deep_ep .get_event_from_comm_stream (self .backward_node .moe_group .id )
10791079
1080- paddle .base .core .nvprof_nvtx_push ("dispatch_backward_dw " )
1080+ paddle .base .core .nvprof_nvtx_push ("mlp_backward_dw " )
10811081 WeightGradStore .pop ()
10821082 assert WeightGradStore .funcs_queue .empty ()
10831083 paddle .base .core .nvprof_nvtx_pop ()
@@ -1108,11 +1108,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
11081108
11091109 if pp_stream is not None :
11101110 send_recv_stream = paddle .device .Stream (stream_base = pp_stream )
1111-
1112- # combine_forward_event.custom_stream_wait( pp_stream)
1113- # final_out_event.custom_stream_wait(pp_stream)
1114-
1115- paddle .base .core .nvprof_nvtx_push ("pp stream add" )
1111+ paddle .base .core .nvprof_nvtx_push ("pp_stream_add" )
11161112
11171113 with paddle .device .stream_guard (send_recv_stream ):
11181114 combine_forward_event .current_stream_wait ()
@@ -1127,20 +1123,27 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
11271123
11281124 dispatch_backward_event .calc_stream_wait (self .backward_node .moe_group .id )
11291125
1130- paddle .base .core .nvprof_nvtx_push ("attn_backward " )
1126+ paddle .base .core .nvprof_nvtx_push ("attn_backward_dx " )
11311127 assert WeightGradStore .funcs_queue .empty ()
11321128 WeightGradStore .enabled = True
11331129 output_grad = self .backward_node .attn_backward (output_grad )
11341130 event_to_wait = deep_ep .get_event_from_calc_stream (self .backward_node .moe_group .id )
1131+ paddle .base .core .nvprof_nvtx_pop ()
11351132
11361133 if EventStore is not None :
11371134 EventStore .set (event_to_wait )
11381135
1136+ if pp_stream is not None :
1137+ # TODO(liangshuhao): this wait may be unnecessary, but we see a slower
1138+ # convergence rate without it, so we temporarily keep the wait here.
1139+ with paddle .device .stream_guard (send_recv_stream ):
1140+ event_to_wait .current_stream_wait ()
1141+
1142+ paddle .base .core .nvprof_nvtx_push ("attn_backward_dw" )
11391143 WeightGradStore .enabled = False
11401144 WeightGradStore .flush ()
11411145 WeightGradStore .pop ()
11421146 assert WeightGradStore .funcs_queue .empty ()
1143-
11441147 paddle .base .core .nvprof_nvtx_pop ()
11451148
11461149 # residual add
0 commit comments