Skip to content

Commit e3df770

Browse files
committed
Moved eval all reduce code after including it in json list.
Signed-off-by: meetkuma <[email protected]>
1 parent 76ce094 commit e3df770

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

QEfficient/finetune/utils/train_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,9 @@ def train(
369369
eval_loss, eval_metric, step_loss, step_metric = evaluation_helper(
370370
model, train_config, eval_dataloader, device
371371
)
372-
# Print evaluation metrics
373-
print(
374-
f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
375-
)
376372
if eval_loss < best_val_loss:
377373
best_val_loss = eval_loss
378-
print(f"best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
374+
print(f"Best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
379375

380376
if is_rank_zero():
381377
tensorboard_updates.add_scalars("loss", {"eval": eval_loss}, total_train_steps)
@@ -385,6 +381,16 @@ def train(
385381
val_loss.append(float(eval_loss))
386382
val_metric.append(float(eval_metric))
387383

384+
if train_config.enable_ddp:
385+
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
386+
eval_loss /= get_num_ddp_devices()
387+
dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
388+
eval_metric /= get_num_ddp_devices()
389+
390+
print(
391+
f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
392+
)
393+
388394
# saving the adapters after completion of each epoch
389395
if train_config.save_model:
390396
if train_config.enable_ddp:
@@ -507,12 +513,6 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
507513
else:
508514
eval_metric = torch.exp(eval_loss)
509515

510-
if train_config.enable_ddp:
511-
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
512-
eval_loss /= get_num_ddp_devices()
513-
dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
514-
eval_metric /= get_num_ddp_devices()
515-
516516
return eval_loss, eval_metric, val_step_loss, val_step_metric
517517

518518

0 commit comments

Comments
 (0)