Commit b4cf464

Merge branch 'main' into qnnCompilationQEFFBaseModel
2 parents 619a718 + b88b758 commit b4cf464

File tree: 6 files changed (+54, -9 lines)

QEfficient/cloud/finetune.py

+12 lines

@@ -103,6 +103,18 @@ def main(**kwargs):
     # print the datatype of the model parameters
     # print(get_parameter_dtypes(model))

+    # Note: this must be called before PeftModel.from_pretrained or get_peft_model, because
+    # both set model.is_gradient_checkpointing = True, which the peft library uses to attach
+    # gradient-checkpointing hooks to the input embeddings. Without this we would get a
+    # "No inf checks were recorded for this optimizer." error.
+    # Enable gradient checkpointing
+    if train_config.gradient_checkpointing:
+        # Note: the attribute and method below are only available on HuggingFace Transformers models.
+        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
+        else:
+            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
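A minimal sketch (not part of this commit) of the ordering the note above describes, assuming a plain HuggingFace Transformers + peft setup; the model name and LoRA settings are placeholders:

# Sketch only: enable gradient checkpointing before wrapping the model with PEFT,
# so that peft sees model.is_gradient_checkpointing == True and registers its hooks.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

if getattr(model, "supports_gradient_checkpointing", False):
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})

# Only after checkpointing is enabled, apply the PEFT wrapper.
peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", r=8))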

QEfficient/finetune/configs/training.py

+1 line

@@ -15,6 +15,7 @@ class train_config:
     batch_size_training: int = 1
     context_length: int = None
     gradient_accumulation_steps: int = 4
+    gradient_checkpointing: bool = False
     num_epochs: int = 1
     max_train_step: int = 0
     max_eval_step: int = 0
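For reference, a small sketch (not from this commit) of toggling the new flag, assuming train_config's fields all carry defaults so it can be instantiated directly:

# Sketch only: the new flag defaults to False and is switched on per run.
from QEfficient.finetune.configs.training import train_config

cfg = train_config()               # all fields take their defaults
cfg.gradient_checkpointing = True  # enable the new code path in finetune.py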

QEfficient/finetune/dataset/dataset_config.py

+7, -1 lines

@@ -21,6 +21,9 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
+from QEfficient.finetune.dataset.samsum_dataset import (
+    get_samsum_collate_fn,
+)

 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
 }
-DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator,
+    "samsum_dataset": get_samsum_collate_fn,
+}
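One plausible consumer of this registry (hypothetical, not shown in this diff): look up the dataset's collate builder by name and hand the result to a DataLoader, falling back to default collation when nothing is registered. build_dataloader and its arguments are illustrative names, not identifiers from this repo.

# Sketch only: how a registry like DATALOADER_COLLATE_FUNC could be consumed.
from torch.utils.data import DataLoader

def build_dataloader(dataset, dataset_name, tokenizer, dataset_config, batch_size):
    # e.g. dataset_name == "samsum_dataset" -> get_samsum_collate_fn(tokenizer, dataset_config)
    collate_builder = DATALOADER_COLLATE_FUNC.get(dataset_name)
    collate = collate_builder(tokenizer, dataset_config) if collate_builder else None
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate)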

QEfficient/finetune/dataset/samsum_dataset.py

+21 lines

@@ -6,6 +6,8 @@
 # -----------------------------------------------------------------------------

 import datasets
+import torch
+from torch.nn.utils.rnn import pad_sequence


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

     return dataset
+
+
+def collate_fn(batch):
+    eos_token = batch[0]["input_ids"][-1]
+
+    input_ids = pad_sequence(
+        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    attn_mask = pad_sequence(
+        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
+    )
+    labels = pad_sequence(
+        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
+
+
+def get_samsum_collate_fn(dataset_processer, dataset_config):
+    return collate_fn
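A self-contained usage sketch (not part of this commit) showing what the padding collate produces for a toy batch; the token ids are fabricated, whereas a real batch would come from the tokenized SAMSum dataset:

# Sketch only: illustrate the right-padding behaviour of a collate like the one above.
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"input_ids": [5, 6, 7, 2], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 7, 2]},
    {"input_ids": [8, 2], "attention_mask": [1, 1], "labels": [-100, 2]},
]

eos_token = batch[0]["input_ids"][-1]  # last id of the first sample doubles as the pad value
input_ids = pad_sequence(
    [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
)
print(input_ids)        # tensor([[5, 6, 7, 2], [8, 2, 2, 2]], dtype=torch.int32)
print(input_ids.shape)  # torch.Size([2, 4]): shorter samples are right-padded to the longest one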

QEfficient/finetune/utils/train_utils.py

+3, -7 lines

@@ -83,6 +83,7 @@ def train(
     best_val_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
+    device_type = device.split(":")[0]

     tensorboard_updates = None
     if train_config.enable_ddp:
@@ -95,7 +96,7 @@
     if device.startswith("qaic"):
         scaler = QAicGradScaler()
     else:
-        scaler = GradScaler()
+        scaler = GradScaler(device_type)

     loss_0_counter = torch.tensor([0]).to(device)

@@ -177,10 +178,7 @@
                     # adjust atol & rtol this as required
                     atol=1e-1,
                     use_ref_output_on_mismatch=True,
-                    # report all mismatches
-                    max_failures=None,
-                    # generate unittest for each op once
-                    repeat_same_op=True,
+                    filter_config=qaic_debug.DispatchFilterConfig.default(device),
                     dump_root_dir=train_config.dump_root_dir + str(step),
                 ) as verifier:
                     loss = model(**batch).loss  # Forward call
@@ -296,8 +294,6 @@
                 eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
                     model, train_config, eval_dataloader, local_rank, tokenizer, device
                 )
-                dist.barrier()
-                dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
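A short sketch (not from this commit) of the device-aware scaler construction the second hunk switches to, assuming GradScaler here resolves to torch.amp.GradScaler (PyTorch 2.3+), which takes the device family as its first argument:

# Sketch only: derive the device family from a "device:index" string and build the scaler.
import torch
from torch.amp import GradScaler

device = "cuda:0"                   # placeholder; could also be "cpu"
device_type = device.split(":")[0]  # strip the index -> "cuda"
scaler = GradScaler(device_type)    # scaler tied to that device family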

scripts/replicate_kv_head/replicate_kv_heads.py

+10, -1 lines

@@ -14,6 +14,7 @@
 from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear


 def duplicate_weights_for_linear_layer(
@@ -49,6 +50,15 @@ def duplicate_weights_for_linear_layer(
             1,
         ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
         layer.out_features = layer.out_features * repeat
+
+    elif isinstance(layer, FP8DeQuantLinear):
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+        layer.weight_scale.data = torch.repeat_interleave(
+            layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+        ).view(new_kv_heads * head_dim, -1)
+
     else:
         layer.weight.data = torch.repeat_interleave(
             layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
@@ -65,7 +75,6 @@ def main(args):
     model_kwargs = {"attn_implementation": "eager"}
     if args.num_hidden_layers:
         model_kwargs["num_hidden_layers"] = args.num_hidden_layers
-
     model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

     # Undo the effect of replace_transformers_quantizers
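A toy illustration (not part of this commit) of the repeat_interleave reshaping used above to duplicate KV-head weight rows; the sizes are made up:

# Sketch only: duplicate each KV head's rows in a small weight matrix.
import torch

orig_kv_heads, head_dim, hidden_size, repeat = 2, 3, 4, 2
new_kv_heads = orig_kv_heads * repeat

# A k_proj/v_proj-style weight of shape (orig_kv_heads * head_dim, hidden_size).
weight = torch.arange(orig_kv_heads * head_dim * hidden_size, dtype=torch.float32).view(
    orig_kv_heads * head_dim, hidden_size
)

# Group rows by KV head, duplicate each head `repeat` times, then flatten back.
new_weight = torch.repeat_interleave(
    weight.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)

print(weight.shape, "->", new_weight.shape)  # torch.Size([6, 4]) -> torch.Size([12, 4])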
