Skip to content

Commit 7c4a7c5

Browse files
committed
Fixed few comments. Need to rebase first and then test
1 parent 2a5c217 commit 7c4a7c5

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

QEfficient/finetune/utils/train_utils.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,6 @@ def train(
301301
else:
302302
train_epoch_loss = total_loss / len(train_dataloader)
303303

304-
if train_config.enable_ddp:
305-
# Get the correct train loss from all the nodes.
306-
dist.barrier()
307-
dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
308-
train_epoch_loss /= dist.get_world_size()
309-
310304
if train_config.task_type == "seq_classification":
311305
accuracy = acc_helper.compute()
312306
acc_helper.reset()
@@ -479,10 +473,10 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
479473
Returns: eval_acc, eval_epoch_loss
480474
"""
481475
model.eval()
482-
if local_rank is None:
483-
num_classes = model.classifier.out_features
484-
else:
476+
if train_config.enable_ddp:
485477
num_classes = model.module.classifier.out_features
478+
else:
479+
num_classes = model.classifier.out_features
486480

487481
acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
488482

0 commit comments

Comments
 (0)