Skip to content

Commit e8c32c1

Browse files
committed
Added torchmetrics as a dependency and fixed loss computation for the DDP case.
Signed-off-by: Meet Patel <[email protected]>
1 parent c0b9b8c commit e8c32c1

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

QEfficient/finetune/utils/train_utils.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -301,8 +301,18 @@ def train(
301301
else:
302302
train_epoch_loss = total_loss / len(train_dataloader)
303303

304+
if train_config.enable_ddp:
305+
# Get the correct train loss from all the nodes.
306+
dist.barrier()
307+
dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
308+
train_epoch_loss /= dist.get_world_size()
309+
304310
if train_config.task_type == "seq_classification":
305-
train_perplexity = acc_helper.compute()
311+
accuracy = acc_helper.compute()
312+
if train_config.enable_ddp:
313+
dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
314+
accuracy /= dist.get_world_size()
315+
train_perplexity = accuracy
306316
else:
307317
train_perplexity = torch.exp(train_epoch_loss)
308318

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ license = { file = "LICENSE" }
99
authors = [{ name = "Qualcomm Cloud AI ML Team" }]
1010
keywords = ["transformers", "Cloud AI 100", "Inference"]
1111
classifiers = [
12-
"Programming Language :: Python :: 3",
12+
"Programming Language :: Python :: 3",
1313
"Development Status :: 5 - Development/Unstable",
1414
"Intended Audience :: Developers",
1515
"Intended Audience :: Education",
@@ -38,6 +38,7 @@ dependencies = [
3838
"tensorboard",
3939
"fire",
4040
"py7zr",
41+
"torchmetrics==1.7.0",
4142
"torch==2.4.1; platform_machine=='aarch64'",
4243
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
4344
"torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
@@ -60,7 +61,7 @@ namespaces = false
6061

6162
[tool.setuptools.dynamic.version]
6263
attr = "QEfficient.__version__"
63-
64+
6465
[tool.ruff]
6566
line-length = 120
6667
# Enable the isort rules.

0 commit comments

Comments
 (0)