File tree 2 files changed +15
-3
lines changed
QEfficient/finetune/utils
2 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -300,8 +300,18 @@ def train(
300
300
else :
301
301
train_epoch_loss = total_loss / len (train_dataloader )
302
302
303
+ if train_config .enable_ddp :
304
+ # Get the correct train loss from all the nodes.
305
+ dist .barrier ()
306
+ dist .all_reduce (train_epoch_loss , op = dist .ReduceOp .SUM )
307
+ train_epoch_loss /= dist .get_world_size ()
308
+
303
309
if train_config .task_type == "seq_classification" :
304
- train_perplexity = acc_helper .compute ()
310
+ accuracy = acc_helper .compute ()
311
+ if train_config .enable_ddp :
312
+ dist .all_reduce (accuracy , op = dist .ReduceOp .SUM )
313
+ accuracy /= dist .get_world_size ()
314
+ train_perplexity = accuracy
305
315
else :
306
316
train_perplexity = torch .exp (train_epoch_loss )
307
317
@@ -319,6 +329,7 @@ def train(
319
329
)
320
330
dist .barrier ()
321
331
dist .all_reduce (eval_epoch_loss , op = dist .ReduceOp .SUM )
332
+ eval_epoch_loss /= dist .get_world_size ()
322
333
if local_rank == 0 :
323
334
tensorboard_updates .add_scalars ("loss" , {"eval" : eval_epoch_loss }, total_train_steps )
324
335
Original file line number Diff line number Diff line change @@ -9,7 +9,7 @@ license = { file = "LICENSE" }
9
9
authors = [{ name = " Qualcomm Cloud AI ML Team" }]
10
10
keywords = [" transformers" , " Cloud AI 100" , " Inference" ]
11
11
classifiers = [
12
- " Programming Language :: Python :: 3" ,
12
+ " Programming Language :: Python :: 3" ,
13
13
" Development Status :: 5 - Development/Unstable" ,
14
14
" Intended Audience :: Developers" ,
15
15
" Intended Audience :: Education" ,
@@ -38,6 +38,7 @@ dependencies = [
38
38
" tensorboard" ,
39
39
" fire" ,
40
40
" py7zr" ,
41
+ " torchmetrics==1.7.0" ,
41
42
" torch==2.4.1; platform_machine=='aarch64'" ,
42
43
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
43
44
" torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'" ,
@@ -60,7 +61,7 @@ namespaces = false
60
61
61
62
[tool .setuptools .dynamic .version ]
62
63
attr = " QEfficient.__version__"
63
-
64
+
64
65
[tool .ruff ]
65
66
line-length = 120
66
67
# Enable the isort rules.
You can’t perform that action at this time.
0 commit comments