Skip to content

Commit 3aaa2d8

Browse files
[QEff. Finetune]: Added support to sync gradients across devices during optimizer step only. (#477)
Disabling gradient synchronization is necessary when using gradient_accumulation_steps > 1 with DDP enabled. Currently, we sync gradients at every loss.backward() call, which happens at every step. With gradient accumulation, the weights are updated only during the opt.step() call, so the gradients across devices should be synchronized only at that step. The model.no_sync() context manager solves this issue. Here, we do not use it; instead, we set ddp_model.require_backward_grad_sync to True or False depending on which step we are at. --------- Signed-off-by: Meet Patel <[email protected]> Signed-off-by: meetkuma <[email protected]>
1 parent 2ba491d commit 3aaa2d8

File tree

2 files changed

+78
-59
lines changed

2 files changed

+78
-59
lines changed

QEfficient/finetune/utils/helper.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
#
66
# -----------------------------------------------------------------------------
77
import os
8+
from contextlib import nullcontext
9+
10+
import torch
11+
12+
try:
13+
import torch_qaic.debug as qaic_debug # noqa: F401
14+
except ImportError as e:
15+
print(f"Warning: {e}. Moving ahead without these qaic modules.")
16+
817

918
TASK_TYPE = ["generation", "seq_classification"]
1019
PEFT_METHOD = ["lora"]
@@ -14,3 +23,34 @@
1423

1524
def get_num_ddp_devices():
1625
return int(os.getenv("WORLD_SIZE", 1))
26+
27+
28+
def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
    """Build the context manager used to wrap a forward pass.

    Args:
        use_autocast: When truthy, a ``torch.autocast`` context is returned;
            otherwise a no-op ``nullcontext`` is returned.
        device_type: Forwarded to ``torch.autocast`` as ``device_type``.
        dtype: Forwarded to ``torch.autocast`` as ``dtype``.

    Returns:
        A ``torch.autocast`` instance or a ``contextlib.nullcontext`` instance.
    """
    # Guard clause: autocast disabled -> hand back a context that does nothing.
    if not use_autocast:
        return nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)
30+
31+
32+
def get_op_verifier_ctx(
    use_op_by_op_verifier,
    train_device,
    dump_dir,
    step,
    ref_device="cpu",
    ref_dtype=torch.float32,
    atol=1e-1,
    rtol=1e-5,
    use_ref_output_on_mismatch=True,
):
    """Return the op-by-op verifier context for one step, or a no-op context.

    Args:
        use_op_by_op_verifier: When falsy, a ``nullcontext`` is returned and
            none of the other arguments are used.
        train_device: Passed to ``qaic_debug.DispatchFilterConfig.default``.
        dump_dir: Base dump directory; ``"_<step>"`` is appended to it and the
            result is passed as ``dump_root_dir``.
        step: Step index used to build the per-step dump directory name.
        ref_device: Forwarded unchanged to ``qaic_debug.OpByOpVerifierMode``.
        ref_dtype: Forwarded unchanged to ``qaic_debug.OpByOpVerifierMode``.
        atol: Forwarded unchanged to ``qaic_debug.OpByOpVerifierMode``.
        rtol: Forwarded unchanged to ``qaic_debug.OpByOpVerifierMode``.
        use_ref_output_on_mismatch: Forwarded unchanged to
            ``qaic_debug.OpByOpVerifierMode``.

    Returns:
        A ``qaic_debug.OpByOpVerifierMode`` instance when verification is
        enabled, otherwise a ``contextlib.nullcontext`` (whose ``as`` target
        is ``None``).

    Raises:
        ImportError: If verification is requested but ``torch_qaic.debug``
            cannot be imported.
    """
    if not use_op_by_op_verifier:
        return nullcontext()

    # The file-level import of torch_qaic.debug only prints a warning on failure,
    # which would leave `qaic_debug` unbound here. Resolve it locally so a missing
    # module surfaces as a descriptive ImportError instead of a bare NameError.
    try:
        import torch_qaic.debug as qaic_debug
    except ImportError as e:
        raise ImportError(
            "use_op_by_op_verifier is enabled but 'torch_qaic.debug' could not be imported."
        ) from e

    filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
    # One dump directory per step; the f-string also accepts path-like dump_dir values.
    step_dump_dir = f"{dump_dir}_{step}"
    return qaic_debug.OpByOpVerifierMode(
        ref_device=ref_device,
        ref_dtype=ref_dtype,
        atol=atol,
        rtol=rtol,
        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
        filter_config=filter_config,
        dump_root_dir=step_dump_dir,
    )

QEfficient/finetune/utils/train_utils.py

Lines changed: 38 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import json
99
import os
1010
import time
11-
from contextlib import nullcontext
1211
from datetime import datetime
12+
from functools import partial
1313
from typing import Dict, List, Tuple
1414

1515
import torch
@@ -19,6 +19,7 @@
1919
from tqdm import tqdm
2020

2121
from QEfficient.finetune.configs.training import TrainConfig
22+
from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx
2223

2324
try:
2425
import torch_qaic # noqa: F401
@@ -110,6 +111,9 @@ def train(
110111
num_classes = model.classifier.out_features
111112
acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
112113

114+
autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
115+
op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
116+
113117
# Start the training loop
114118
for epoch in range(train_config.num_epochs):
115119
if loss_0_counter.item() == train_config.convergence_counter:
@@ -174,60 +178,38 @@ def train(
174178
break
175179
batch = {k: v.to(device) for k, v in batch.items()} # move the batch elements to qaic device
176180

177-
with (
178-
torch.autocast(device_type=device_type, dtype=torch.float16)
179-
if train_config.use_autocast
180-
else nullcontext()
181-
):
182-
# an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
183-
if train_config.opByOpVerifier:
184-
with qaic_debug.OpByOpVerifierMode(
185-
ref_device="cpu",
186-
ref_dtype=torch.float32,
187-
# adjust atol & rtol this as required
188-
atol=1e-1,
189-
use_ref_output_on_mismatch=True,
190-
filter_config=qaic_debug.DispatchFilterConfig.default(device),
191-
dump_root_dir=train_config.dump_root_dir + str(step),
192-
) as verifier:
193-
model_outputs = model(**batch)
194-
loss = model_outputs.loss # Forward call
195-
if (batch["labels"] != -100).sum() == 0:
196-
loss = loss.nan_to_num(nan=0.0)
197-
num_dummy_samples += train_config.train_batch_size
198-
else:
199-
num_dummy_samples_per_batch = (
200-
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
201-
)
202-
if num_dummy_samples_per_batch > 0:
203-
num_dummy_samples += num_dummy_samples_per_batch
204-
loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
205-
206-
if train_config.task_type == "seq_classification":
207-
logits = model_outputs.logits
208-
labels = batch["labels"][:, 0]
209-
preds = torch.nn.functional.softmax(logits, dim=-1)
210-
acc_helper.forward(preds, labels)
211-
print("Mismatches detected:", verifier.get_perop_mismatch_count())
181+
is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
182+
train_dataloader
183+
) - 1
184+
if train_config.enable_ddp:
185+
# Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
186+
# in DDP training we only need to sync gradients at the last micro step.
187+
# the official way to do this is with model.no_sync() context manager, but
188+
# using too many context managers may bloat the code and forces us to repeat code
189+
# looking at the source of that context manager, it just toggles this variable
190+
model.require_backward_grad_sync = is_optimizer_step
191+
192+
with autocast_ctx, op_verifier_ctx(step) as verifier:
193+
model_outputs = model(**batch)
194+
loss = model_outputs.loss # Forward call
195+
if (batch["labels"] != -100).sum() == 0:
196+
loss = loss.nan_to_num(nan=0.0)
197+
num_dummy_samples += train_config.train_batch_size
212198
else:
213-
model_outputs = model(**batch)
214-
loss = model_outputs.loss # Forward call
215-
if (batch["labels"] != -100).sum() == 0:
216-
loss = loss.nan_to_num(nan=0.0)
217-
num_dummy_samples += train_config.train_batch_size
218-
else:
219-
num_dummy_samples_per_batch = (
220-
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
221-
)
222-
if num_dummy_samples_per_batch > 0:
223-
num_dummy_samples += num_dummy_samples_per_batch
224-
loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
199+
num_dummy_samples_per_batch = (
200+
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
201+
)
202+
if num_dummy_samples_per_batch > 0:
203+
num_dummy_samples += num_dummy_samples_per_batch
204+
loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
225205

226-
if train_config.task_type == "seq_classification":
227-
logits = model_outputs.logits
228-
labels = batch["labels"][:, 0]
229-
preds = torch.nn.functional.softmax(logits, dim=-1)
230-
acc_helper.forward(preds, labels)
206+
if train_config.task_type == "seq_classification":
207+
logits = model_outputs.logits
208+
labels = batch["labels"][:, 0]
209+
preds = torch.nn.functional.softmax(logits, dim=-1)
210+
acc_helper.forward(preds, labels)
211+
if train_config.opByOpVerifier:
212+
print("Mismatches detected:", verifier.get_perop_mismatch_count())
231213

232214
total_loss += loss.detach().float()
233215

@@ -274,7 +256,7 @@ def train(
274256
else:
275257
loss.backward() # backward pass
276258

277-
if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
259+
if is_optimizer_step:
278260
if train_config.grad_scaler:
279261
scaler.step(optimizer)
280262
scaler.update()
@@ -468,6 +450,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
468450
device_type = torch.device(device).type
469451

470452
num_dummy_samples = 0
453+
autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
471454
for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
472455
# stop when the maximum number of eval steps is reached
473456
if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -478,11 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
478461
# Ensure no gradients are computed for this scope to save memory
479462
with torch.no_grad():
480463
# Forward pass and compute loss
481-
with (
482-
torch.autocast(device_type=device_type, dtype=torch.float16)
483-
if train_config.use_autocast
484-
else nullcontext()
485-
):
464+
with autocast_ctx:
486465
outputs = model(**batch)
487466
loss = outputs.loss
488467

0 commit comments

Comments
 (0)