@@ -334,10 +334,7 @@ def train(
334
334
eval_loss , eval_metric , step_loss , step_metric = evaluation_helper (
335
335
model , train_config , eval_dataloader , device
336
336
)
337
- # Print evaluation metrics
338
- logger .log_rank_zero (
339
- f"Epoch { epoch + 1 } : Eval Loss: { eval_loss .detach ().cpu ():.4f} , Eval metric: { eval_metric .detach ().cpu ():.4f} "
340
- )
337
+
341
338
if eval_loss < best_val_loss :
342
339
best_val_loss = eval_loss
343
340
logger .log_rank_zero (f"Best eval loss on epoch { epoch + 1 } is { best_val_loss :.4f} " )
@@ -350,6 +347,16 @@ def train(
350
347
val_loss .append (float (eval_loss ))
351
348
val_metric .append (float (eval_metric ))
352
349
350
+ if train_config .enable_ddp :
351
+ dist .all_reduce (eval_loss , op = dist .ReduceOp .SUM )
352
+ eval_loss /= get_num_ddp_devices ()
353
+ dist .all_reduce (eval_metric , op = dist .ReduceOp .SUM )
354
+ eval_metric /= get_num_ddp_devices ()
355
+
356
+ logger .log_rank_zero (
357
+ f"Epoch { epoch + 1 } : Eval Loss: { eval_loss .detach ().cpu ():.4f} , Eval metric: { eval_metric .detach ().cpu ():.4f} "
358
+ )
359
+
353
360
# saving the adapters after completion of each epoch
354
361
if train_config .save_model :
355
362
if train_config .enable_ddp :
@@ -469,12 +476,6 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
469
476
else :
470
477
eval_metric = torch .exp (eval_loss )
471
478
472
- if train_config .enable_ddp :
473
- dist .all_reduce (eval_loss , op = dist .ReduceOp .SUM )
474
- eval_loss /= get_num_ddp_devices ()
475
- dist .all_reduce (eval_metric , op = dist .ReduceOp .SUM )
476
- eval_metric /= get_num_ddp_devices ()
477
-
478
479
return eval_loss , eval_metric , val_step_loss , val_step_metric
479
480
480
481
0 commit comments