
Commit b114665

Meet Patel (quic-meetkuma) authored and committed
Fixed collate_fn for bs > 1. It now works for bs > 1 with Llama as well on a single device.
Signed-off-by: Meet Patel <[email protected]>
1 parent d3854ff commit b114665

File tree

4 files changed: +9 −19 lines


QEfficient/cloud/finetune.py

+2 −6

@@ -133,13 +133,9 @@ def main(**kwargs):
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
-    ctx_len = train_config.context_length
-    if ctx_len is None and hasattr(model.config, "max_position_embeddings"):
-        ctx_len = model.config.max_position_embeddings
+    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length)
 
-    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len)
-
-    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len)
+    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length)
 
     # TODO: vbaddi, check if its necessary to do this?
     # dataset_train = ConcatDataset(
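
Note: with the fallback to model.config.max_position_embeddings removed, the configured context_length is forwarded as-is (possibly None). Combined with the imdb_dataset.py change below, a None context_length leaves each sample at its natural token length and defers padding to the batch collator, instead of padding everything to the model's maximum position count (very large for Llama-style configs). A minimal sketch of that difference at tokenization time; the checkpoint name and the max length of 32 are illustrative, not taken from the repo:

# Hedged sketch: per-sample padding to a fixed maximum vs. leaving samples unpadded.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
text = "a short review"

# Old behaviour (conceptually): every sample padded up to a fixed maximum,
# previously derived from model.config.max_position_embeddings.
padded = tokenizer(text, padding="max_length", max_length=32, truncation=True)

# New behaviour: no per-sample padding; the batch collator pads later.
unpadded = tokenizer(text)

print(len(padded["input_ids"]), len(unpadded["input_ids"]))  # 32 vs. the sample's own length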

QEfficient/finetune/dataset/imdb_dataset.py

+2 −3

@@ -28,11 +28,10 @@ def tokenize_add_label(sample):
         data = tokenizer(
             sample["text"],
             add_special_tokens=True,
-            max_length=context_length,
-            pad_to_max_length=True,
+            max_length=tokenizer.model_max_length,
         )
 
-        data["labels"] = sample["label"]
+        data["labels"] = [sample["label"]]
         return data
 
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
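
Note: wrapping the scalar class id in a list gives each sample a length-1 labels sequence, which DataCollatorForSeq2Seq can collate into a (batch_size, 1) tensor; train_utils.py later unwraps it with batch["labels"][:, 0]. A small sketch, with a toy two-sample batch and an illustrative checkpoint (not part of the commit):

# Hedged sketch: a length-1 "labels" list collates into a (batch_size, 1) tensor.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint

def tokenize_add_label(sample):
    # mirrors the updated imdb_dataset.py: cap at the tokenizer's own limit, no per-sample padding
    data = tokenizer(sample["text"], add_special_tokens=True, max_length=tokenizer.model_max_length)
    data["labels"] = [sample["label"]]  # a list, not a bare scalar, so the collator can stack it
    return data

features = [
    tokenize_add_label({"text": "great movie", "label": 1}),
    tokenize_add_label({"text": "terrible, overlong, and boring", "label": 0}),
]

batch = DataCollatorForSeq2Seq(tokenizer)(features)
print(batch["input_ids"].shape)  # padded to the longest sample in the batch
print(batch["labels"][:, 0])     # tensor([1, 0]) -- recovered later in train_utils.py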

QEfficient/finetune/utils/config_utils.py

+1 −6

@@ -88,19 +88,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                 num_replicas=dist.get_world_size(),
                 shuffle=False,
             )
-            if train_config.task_type == "seq_classification":
-                kwargs["collate_fn"] = default_data_collator
-            else:
-                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
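
Note: always using DataCollatorForSeq2Seq is what makes bs > 1 work here. default_data_collator only stacks per-sample tensors, so it fails as soon as two samples in a batch have different lengths, whereas DataCollatorForSeq2Seq pads input_ids, attention_mask, and labels to the longest sample in the batch. A minimal sketch with hand-written toy features (illustrative checkpoint and token ids, not part of the commit):

# Hedged sketch: default_data_collator vs. DataCollatorForSeq2Seq on ragged batches.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, default_data_collator

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint

features = [
    {"input_ids": [101, 2307, 102], "attention_mask": [1, 1, 1], "labels": [1]},
    {"input_ids": [101, 6659, 1010, 11771, 102], "attention_mask": [1, 1, 1, 1, 1], "labels": [0]},
]

# default_data_collator(features) would raise here: it calls torch.tensor() on the
# ragged input_ids lists, which only works when every sample shares one length.
collator = DataCollatorForSeq2Seq(tokenizer)
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 5]) -- padded to the batch maximum
print(batch["labels"].shape)     # torch.Size([2, 1])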

QEfficient/finetune/utils/train_utils.py

+4 −4

@@ -193,7 +193,7 @@ def train(
                 loss = model_outputs.loss  # Forward call
                 if train_config.task_type == "seq_classification":
                     logits = model_outputs.logits
-                    labels = batch["labels"]
+                    labels = batch["labels"][:, 0]
                     preds = torch.nn.functional.softmax(logits, dim=-1)
                     acc_helper.forward(preds, labels)
             print("Mismatches detected:", verifier.get_perop_mismatch_count())
@@ -202,7 +202,7 @@ def train(
                 loss = model_outputs.loss  # Forward call
                 if train_config.task_type == "seq_classification":
                     logits = model_outputs.logits
-                    labels = batch["labels"]
+                    labels = batch["labels"][:, 0]
                     preds = torch.nn.functional.softmax(logits, dim=-1)
                     acc_helper.forward(preds, labels)
 
@@ -305,7 +305,7 @@ def train(
             dist.barrier()
             dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
             train_epoch_loss /= dist.get_world_size()
-
+
         if train_config.task_type == "seq_classification":
             accuracy = acc_helper.compute()
             if train_config.enable_ddp:
@@ -515,7 +515,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
             outputs = model(**batch)
             loss = outputs.loss
             logits = outputs.logits
-            labels = batch["labels"]
+            labels = batch["labels"][:, 0]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
             preds = torch.nn.functional.softmax(logits, dim=-1)
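
Note: after collation, labels has shape (batch_size, 1), so [:, 0] recovers one class id per sample to pair with the (batch_size, num_classes) softmax over the logits. A toy sketch of the shape handling; the torchmetrics-based acc_helper is an assumption made for illustration only, mirroring the acc_helper.forward(preds, labels) call in the diff:

# Hedged sketch: unwrapping (batch_size, 1) labels for the accuracy metric.
import torch
from torchmetrics.classification import MulticlassAccuracy  # assumed metric, for illustration

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])  # (batch_size=2, num_classes=2)
labels = torch.tensor([[0], [1]])                 # collated shape: (batch_size, 1)

acc_helper = MulticlassAccuracy(num_classes=2)
preds = torch.nn.functional.softmax(logits, dim=-1)
acc_helper.forward(preds, labels[:, 0])           # labels[:, 0] -> tensor([0, 1])
print(acc_helper.compute())                       # tensor(1.) on this toy batch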
