Skip to content

Commit 5a7999f

Browse files
Meet Patelquic-meetkuma
Meet Patel
authored andcommitted
Added torchmetrics as dependency and fixed loss computation for ddp case.
Signed-off-by: Meet Patel <[email protected]>
1 parent bd01ca6 commit 5a7999f

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

QEfficient/finetune/utils/train_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,18 @@ def train(
300300
else:
301301
train_epoch_loss = total_loss / len(train_dataloader)
302302

303+
if train_config.enable_ddp:
304+
# Get the correct train loss from all the nodes.
305+
dist.barrier()
306+
dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
307+
train_epoch_loss /= dist.get_world_size()
308+
303309
if train_config.task_type == "seq_classification":
304-
train_perplexity = acc_helper.compute()
310+
accuracy = acc_helper.compute()
311+
if train_config.enable_ddp:
312+
dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
313+
accuracy /= dist.get_world_size()
314+
train_perplexity = accuracy
305315
else:
306316
train_perplexity = torch.exp(train_epoch_loss)
307317

@@ -319,6 +329,7 @@ def train(
319329
)
320330
dist.barrier()
321331
dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
332+
eval_epoch_loss /= dist.get_world_size()
322333
if local_rank == 0:
323334
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
324335

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license = { file = "LICENSE" }
99
authors = [{ name = "Qualcomm Cloud AI ML Team" }]
1010
keywords = ["transformers", "Cloud AI 100", "Inference"]
1111
classifiers = [
12-
"Programming Language :: Python :: 3",
12+
"Programming Language :: Python :: 3",
1313
"Development Status :: 5 - Development/Unstable",
1414
"Intended Audience :: Developers",
1515
"Intended Audience :: Education",
@@ -38,6 +38,7 @@ dependencies = [
3838
"tensorboard",
3939
"fire",
4040
"py7zr",
41+
"torchmetrics==1.7.0",
4142
"torch==2.4.1; platform_machine=='aarch64'",
4243
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
4344
"torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
@@ -60,7 +61,7 @@ namespaces = false
6061

6162
[tool.setuptools.dynamic.version]
6263
attr = "QEfficient.__version__"
63-
64+
6465
[tool.ruff]
6566
line-length = 120
6667
# Enable the isort rules.

0 commit comments

Comments
 (0)