File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed
applications/ColossalChat/coati/distributed Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -442,7 +442,6 @@ def _criterion(outputs, inputs):
442442 self .plugin .pp_size > 1
443443 and self .booster .plugin .stage_manager .is_last_stage ()
444444 and self .tp_rank == 0
445- and self .dp_rank == 0
446445 ):
447446 reward = all_reduce_mean (reward .mean (), self .plugin )
448447 format_acc = all_reduce_mean (format_acc .mean (), self .plugin )
@@ -469,7 +468,7 @@ def _criterion(outputs, inputs):
469468 self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
470469 ):
471470 if (not self .plugin .pp_size > 1 and self .rank == 0 ) or (
472- self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
471+ self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0 and self . dp_rank == 0
473472 ):
474473 raw_batch_reward_mean = torch .cat (self .raw_train_batch_reward , dim = 0 ).mean ().cpu ().item ()
475474 raw_batch_format_acc_mean = torch .cat (self .raw_train_batch_format_acc , dim = 0 ).mean ().cpu ().item ()
You can’t perform that action at this time.
0 commit comments