Skip to content

Commit b88b758

Browse files
[QEff Finetune]: Added support for gradient checkpointing in the finetuning script. (#338)
Added --gradient_checkpointing new CLI flag to enable this feature. Currently this is enabled for all the HF models which has "supports_gradient_checkpointing" attribute set to True. --------- Signed-off-by: Meet Patel <[email protected]>
1 parent 4f5dbef commit b88b758

File tree

5 files changed

+42
-7
lines changed

5 files changed

+42
-7
lines changed

QEfficient/cloud/finetune.py

+12
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,18 @@ def main(**kwargs):
103103
# print the datatype of the model parameters
104104
# print(get_parameter_dtypes(model))
105105

106+
# Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
107+
# Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
108+
# apply gradient checkpointing related hooks to the input embeddings. Without this we will get
109+
# "No inf checks were recorded for this optimizer." error.
110+
# Enable gradient checkpointing
111+
if train_config.gradient_checkpointing:
112+
# Note: below attribute and method is only available in HuggingFace Transformer models.
113+
if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
114+
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
115+
else:
116+
raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
117+
106118
if train_config.use_peft:
107119
# Load the pre-trained peft model checkpoint and setup its configuration
108120
if train_config.from_peft_checkpoint:

QEfficient/finetune/configs/training.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class train_config:
1515
batch_size_training: int = 1
1616
context_length: int = None
1717
gradient_accumulation_steps: int = 4
18+
gradient_checkpointing: bool = False
1819
num_epochs: int = 1
1920
max_train_step: int = 0
2021
max_eval_step: int = 0

QEfficient/finetune/dataset/dataset_config.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from QEfficient.finetune.dataset.samsum_dataset import (
2222
get_preprocessed_samsum as get_samsum_dataset,
2323
)
24+
from QEfficient.finetune.dataset.samsum_dataset import (
25+
get_samsum_collate_fn,
26+
)
2427

2528
DATASET_PREPROC = {
2629
"alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
2932
"gsm8k_dataset": get_gsm8k_dataset,
3033
"custom_dataset": get_custom_dataset,
3134
}
32-
DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
35+
DATALOADER_COLLATE_FUNC = {
36+
"custom_dataset": get_data_collator,
37+
"samsum_dataset": get_samsum_collate_fn,
38+
}

QEfficient/finetune/dataset/samsum_dataset.py

+21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# -----------------------------------------------------------------------------
77

88
import datasets
9+
import torch
10+
from torch.nn.utils.rnn import pad_sequence
911

1012

1113
def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
4648
dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
4749

4850
return dataset
51+
52+
53+
def collate_fn(batch):
54+
eos_token = batch[0]["input_ids"][-1]
55+
56+
input_ids = pad_sequence(
57+
[torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
58+
)
59+
attn_mask = pad_sequence(
60+
[torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
61+
)
62+
labels = pad_sequence(
63+
[torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
64+
)
65+
return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
66+
67+
68+
def get_samsum_collate_fn(dataset_processer, dataset_config):
69+
return collate_fn

QEfficient/finetune/utils/train_utils.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,7 @@ def train(
178178
# adjust atol & rtol this as required
179179
atol=1e-1,
180180
use_ref_output_on_mismatch=True,
181-
# report all mismatches
182-
max_failures=None,
183-
# generate unittest for each op once
184-
repeat_same_op=True,
181+
filter_config=qaic_debug.DispatchFilterConfig.default(device),
185182
dump_root_dir=train_config.dump_root_dir + str(step),
186183
) as verifier:
187184
loss = model(**batch).loss # Forward call
@@ -297,8 +294,6 @@ def train(
297294
eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
298295
model, train_config, eval_dataloader, local_rank, tokenizer, device
299296
)
300-
dist.barrier()
301-
dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
302297
if local_rank == 0:
303298
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
304299

0 commit comments

Comments
 (0)