File tree Expand file tree Collapse file tree 2 files changed +14
-3
lines changed
QEfficient/finetune/utils Expand file tree Collapse file tree 2 files changed +14
-3
lines changed Original file line number Diff line number Diff line change @@ -301,8 +301,18 @@ def train(
301
301
else :
302
302
train_epoch_loss = total_loss / len (train_dataloader )
303
303
304
+ if train_config .enable_ddp :
305
+ # Average the epoch training loss across all DDP processes: sum it with all_reduce, then divide by the world size.
306
+ dist .barrier ()
307
+ dist .all_reduce (train_epoch_loss , op = dist .ReduceOp .SUM )
308
+ train_epoch_loss /= dist .get_world_size ()
309
+
304
310
if train_config .task_type == "seq_classification" :
305
- train_perplexity = acc_helper .compute ()
311
+ accuracy = acc_helper .compute ()
312
+ if train_config .enable_ddp :
313
+ dist .all_reduce (accuracy , op = dist .ReduceOp .SUM )
314
+ accuracy /= dist .get_world_size ()
315
+ train_perplexity = accuracy
306
316
else :
307
317
train_perplexity = torch .exp (train_epoch_loss )
308
318
Original file line number Diff line number Diff line change @@ -9,7 +9,7 @@ license = { file = "LICENSE" }
9
9
authors = [{ name = " Qualcomm Cloud AI ML Team" }]
10
10
keywords = [" transformers" , " Cloud AI 100" , " Inference" ]
11
11
classifiers = [
12
- " Programming Language :: Python :: 3" ,
12
+ " Programming Language :: Python :: 3" ,
13
13
" Development Status :: 5 - Development/Unstable" ,
14
14
" Intended Audience :: Developers" ,
15
15
" Intended Audience :: Education" ,
@@ -38,6 +38,7 @@ dependencies = [
38
38
" tensorboard" ,
39
39
" fire" ,
40
40
" py7zr" ,
41
+ " torchmetrics==1.7.0" ,
41
42
" torch==2.4.1; platform_machine=='aarch64'" ,
42
43
# Specify the torch CPU wheel URL per Python version; update this list once PyTorch releases wheels for Python > 3.11.
43
44
" torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'" ,
@@ -60,7 +61,7 @@ namespaces = false
60
61
61
62
[tool .setuptools .dynamic .version ]
62
63
attr = " QEfficient.__version__"
63
-
64
+
64
65
[tool .ruff ]
65
66
line-length = 120
66
67
# Enable the isort rules.
You can’t perform that action at this time.
0 commit comments