Commit b1fb57d

Added finetuning support for BERT based models on IMDB dataset. (#292)
- Added support for the SequenceClassification task for BERT based models.
- Added support for the IMDB dataset.
- Introduced a new flag --task_type with default value "generation". For BERT training, it can be set to "seq_classification".

Signed-off-by: Meet Patel <[email protected]>
1 parent 7bc9d1e commit b1fb57d
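
A hedged sketch of launching the new flow programmatically (the model checkpoint is an assumption; main(**kwargs), the dataset name, and task_type come from this commit; on the command line the same override is exposed through the new --task_type flag described above):

# Hypothetical usage sketch: finetune a BERT-style checkpoint on IMDB for
# sequence classification (checkpoint name is an assumption).
from QEfficient.cloud.finetune import main

main(
    model_name="bert-base-uncased",   # assumed BERT-style checkpoint
    dataset="imdb_dataset",
    task_type="seq_classification",
)

The default task_type remains "generation", so existing causal-LM finetuning runs are unaffected.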

File tree

9 files changed: +183 -96 lines changed


QEfficient/cloud/finetune.py

+26 -8

@@ -38,7 +38,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")


-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 # Suppress all warnings
 warnings.filterwarnings("ignore")

@@ -56,6 +56,7 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config, kwargs)
     device = train_config.device

     # dist init

@@ -78,12 +79,30 @@ def main(**kwargs):
     # Load the pre-trained model and setup its configuration
     # config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )

     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(

@@ -127,7 +146,6 @@ def main(**kwargs):
         model.print_trainable_parameters()

     # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_processer = tokenizer

     # Load and preprocess the dataset for training and validation
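
The seq_classification branch above freezes the pre-trained backbone and trains only the remaining parameters (the classification head), cast to float32 so their gradients are not accumulated in half precision. A standalone sketch of that pattern, assuming a generic BERT-style checkpoint (the model name is illustrative; the freeze/cast logic mirrors the diff):

# Sketch: freeze the backbone of a sequence-classification model and keep
# only the head trainable in fp32 (checkpoint name is an assumption).
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
    torch_dtype=torch.float16,
)

# For BERT models, `base_model_prefix` is "bert"; freezing that submodule
# leaves only the classifier head with requires_grad=True.
for param in getattr(model, model.base_model_prefix).parameters():
    param.requires_grad = False

# Cast the still-trainable parameters to float32 so gradients and optimizer
# state are kept in full precision.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")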

QEfficient/finetune/configs/dataset_config.py

+8

@@ -37,6 +37,14 @@ class gsm8k_dataset:
     test_split: str = "test"


+@dataclass
+class imdb_dataset:
+    dataset: str = "imdb_dataset"
+    train_split: str = "train"
+    test_split: str = "test"
+    num_labels: int = 2
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"

QEfficient/finetune/configs/training.py

+1

@@ -29,6 +29,7 @@ class train_config:
     use_autocast: bool = True
     val_batch_size: int = 1
     dataset = "samsum_dataset"
+    task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     use_peft: bool = True  # use parameter efficient fine tuning
     from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint

QEfficient/finetune/dataset/dataset_config.py

+4 -4

@@ -18,11 +18,11 @@
     get_dataset as get_grammar_dataset,
 )
 from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset
-from QEfficient.finetune.dataset.samsum_dataset import (
-    get_preprocessed_samsum as get_samsum_dataset,
+from QEfficient.finetune.dataset.imdb_dataset import (
+    get_preprocessed_imdb as get_imdb_dataset,
 )
 from QEfficient.finetune.dataset.samsum_dataset import (
-    get_samsum_collate_fn,
+    get_preprocessed_samsum as get_samsum_dataset,
 )

 DATASET_PREPROC = {

@@ -31,8 +31,8 @@
     "samsum_dataset": get_samsum_dataset,
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
+    "imdb_dataset": get_imdb_dataset,
 }
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
-    "samsum_dataset": get_samsum_collate_fn,
 }
QEfficient/finetune/dataset/imdb_dataset.py

+39

@@ -0,0 +1,39 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+from itertools import chain
+
+import datasets
+
+
+def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
+    dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
+
+    if split == "test":
+        # Test set contains 25000 samples. Not all are required.
+        # 0-12499 are 0 labeled samples, 12500-24999 are 1 labeled samples.
+        dataset = dataset.select(chain(range(0, 500), range(12500, 13000)))
+
+    # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data.
+    dataset = dataset.shuffle(seed=42)
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    def tokenize_add_label(sample):
+        data = tokenizer(
+            sample["text"],
+            add_special_tokens=True,
+            max_length=tokenizer.model_max_length,
+        )
+
+        data["labels"] = [sample["label"]]
+        return data
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+    return dataset
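
A hedged usage sketch for the new helper (the tokenizer checkpoint is an assumption; the function signature and config fields come from this commit):

# Sketch: build the tokenized IMDB train split directly.
from transformers import AutoTokenizer

from QEfficient.finetune.configs.dataset_config import imdb_dataset
from QEfficient.finetune.dataset.imdb_dataset import get_preprocessed_imdb

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
dataset_config = imdb_dataset()

# Each example carries tokenized inputs plus a single-element "labels" list (0 or 1).
train_ds = get_preprocessed_imdb(dataset_config, tokenizer, split=dataset_config.train_split)
print(len(train_ds), train_ds[0].keys())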

QEfficient/finetune/dataset/samsum_dataset.py

-21

@@ -6,8 +6,6 @@
 # -----------------------------------------------------------------------------

 import datasets
-import torch
-from torch.nn.utils.rnn import pad_sequence


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):

@@ -48,22 +46,3 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

     return dataset
-
-
-def collate_fn(batch):
-    eos_token = batch[0]["input_ids"][-1]
-
-    input_ids = pad_sequence(
-        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    attn_mask = pad_sequence(
-        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
-    )
-    labels = pad_sequence(
-        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
-
-
-def get_samsum_collate_fn(dataset_processer, dataset_config):
-    return collate_fn

QEfficient/finetune/utils/config_utils.py

+1 -4

@@ -15,7 +15,6 @@
     LoraConfig,
     PrefixTuningConfig,
 )
-from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq

 import QEfficient.finetune.configs.dataset_config as datasets

@@ -88,16 +87,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                     num_replicas=dist.get_world_size(),
                     shuffle=False,
                 )
-                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
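
With the samsum-specific collate function and default_data_collator removed, DataCollatorForSeq2Seq is the single collator for every dataloader. A minimal sketch of why it covers both task types (the checkpoint and token ids below are illustrative assumptions): it pads "input_ids"/"attention_mask" to the longest sequence in the batch and pads "labels" with -100, which is a no-op for the single-label classification case.

# Sketch: one collator handles variable-length generation batches and
# single-label classification batches alike (example values are assumptions).
from transformers import AutoTokenizer
from transformers.data import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForSeq2Seq(tokenizer)

batch = collator(
    [
        {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "labels": [1]},
        {"input_ids": [101, 2204, 3185, 102], "attention_mask": [1, 1, 1, 1], "labels": [0]},
    ]
)
print(batch["input_ids"].shape, batch["labels"].shape)  # e.g. torch.Size([2, 4]) torch.Size([2, 1])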
