
Commit 621398f

Fixed the collate_fn for bs > 1. With this change, bs > 1 also works fine for Llama on a single device.
Signed-off-by: Meet Patel <[email protected]>
1 parent d3c92a3 commit 621398f

File tree

4 files changed: +9 −19 lines changed


QEfficient/cloud/finetune.py

Lines changed: 2 additions & 6 deletions
@@ -145,13 +145,9 @@ def main(**kwargs):
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
-    ctx_len = train_config.context_length
-    if ctx_len is None and hasattr(model.config, "max_position_embeddings"):
-        ctx_len = model.config.max_position_embeddings
+    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length)
 
-    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len)
-
-    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len)
+    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length)
 
     # TODO: vbaddi, check if its necessary to do this?
     # dataset_train = ConcatDataset(
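
The removed fallback used to expand a missing context_length to the model's full position range before preprocessing; after this change, train_config.context_length (possibly None) is passed straight through, and length/padding decisions move into the dataset tokenizer and the batch collator. A minimal standalone sketch of what the dropped fallback did, assuming only a transformers install (the default LlamaConfig below is illustrative and is not the repo's model):

from transformers import LlamaConfig

config = LlamaConfig()  # stand-in for model.config; the default max_position_embeddings is 2048
ctx_len = None  # i.e. train_config.context_length was left unset
if ctx_len is None and hasattr(config, "max_position_embeddings"):
    ctx_len = config.max_position_embeddings  # old fallback: lengths target the full position range
print(ctx_len)  # 2048 here; recent Llama configs report far larger values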

QEfficient/finetune/dataset/imdb_dataset.py

Lines changed: 2 additions & 3 deletions
@@ -28,11 +28,10 @@ def tokenize_add_label(sample):
         data = tokenizer(
             sample["text"],
             add_special_tokens=True,
-            max_length=context_length,
-            pad_to_max_length=True,
+            max_length=tokenizer.model_max_length,
         )
 
-        data["labels"] = sample["label"]
+        data["labels"] = [sample["label"]]
         return data
 
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
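
Two things change here: per-sample padding to a fixed context_length is dropped (pad_to_max_length has long been deprecated in transformers in favour of padding="max_length"), and the scalar class id is wrapped in a one-element list so every example carries a labels sequence that a collator can pad and stack. A minimal sketch of the new tokenization behaviour, assuming an arbitrary Hugging Face checkpoint ("distilbert-base-uncased" is only an example) and with truncation enabled explicitly to keep the snippet self-contained:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint

def tokenize_add_label(sample):
    # Cap length at the tokenizer's own limit instead of a fixed context_length;
    # no padding happens here, so samples keep their natural lengths.
    data = tokenizer(
        sample["text"],
        add_special_tokens=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    # A one-element list gives every sample a labels "sequence" of length 1,
    # which a sequence collator can stack into a [batch, 1] tensor for bs > 1.
    data["labels"] = [sample["label"]]
    return data

print(tokenize_add_label({"text": "a short review", "label": 1}))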

QEfficient/finetune/utils/config_utils.py

Lines changed: 1 addition & 6 deletions
@@ -88,19 +88,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                     num_replicas=dist.get_world_size(),
                     shuffle=False,
                 )
-                if train_config.task_type == "seq_classification":
-                    kwargs["collate_fn"] = default_data_collator
-                else:
-                    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
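
With per-sample padding gone, default_data_collator can no longer build batches larger than one, because it only stacks tensors of identical shape. DataCollatorForSeq2Seq instead pads input_ids, attention_mask, and labels to the longest example in each batch. A minimal sketch of that behaviour, assuming an arbitrary tokenizer and hand-written features (the checkpoint name and token ids are illustrative only):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
features = [  # two samples of different lengths, each with a length-1 labels list
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "labels": [0]},
    {"input_ids": [101, 7592, 2088, 999, 102], "attention_mask": [1, 1, 1, 1, 1], "labels": [1]},
]
loader = DataLoader(features, batch_size=2, collate_fn=DataCollatorForSeq2Seq(tokenizer))
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([2, 5]): shorter sample padded to the longest in the batch
print(batch["labels"].shape)     # torch.Size([2, 1]): one class id per sample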

QEfficient/finetune/utils/train_utils.py

Lines changed: 4 additions & 4 deletions
@@ -194,7 +194,7 @@ def train(
                             loss = model_outputs.loss  # Forward call
                             if train_config.task_type == "seq_classification":
                                 logits = model_outputs.logits
-                                labels = batch["labels"]
+                                labels = batch["labels"][:, 0]
                                 preds = torch.nn.functional.softmax(logits, dim=-1)
                                 acc_helper.forward(preds, labels)
                         print("Mismatches detected:", verifier.get_perop_mismatch_count())
@@ -203,7 +203,7 @@ def train(
                         loss = model_outputs.loss  # Forward call
                         if train_config.task_type == "seq_classification":
                             logits = model_outputs.logits
-                            labels = batch["labels"]
+                            labels = batch["labels"][:, 0]
                             preds = torch.nn.functional.softmax(logits, dim=-1)
                             acc_helper.forward(preds, labels)
 
@@ -306,7 +306,7 @@ def train(
             dist.barrier()
             dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
             train_epoch_loss /= dist.get_world_size()
-
+
         if train_config.task_type == "seq_classification":
             accuracy = acc_helper.compute()
             if train_config.enable_ddp:
@@ -513,7 +513,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
             outputs = model(**batch)
             loss = outputs.loss
             logits = outputs.logits
-            labels = batch["labels"]
+            labels = batch["labels"][:, 0]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
             preds = torch.nn.functional.softmax(logits, dim=-1)
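
Since the collator now delivers labels as a padded [batch, label_length] tensor rather than a 1-D vector, the per-sample class id for sequence classification is read from the first column, which is what the [:, 0] indexing does. A minimal standalone sketch with made-up tensors:

import torch

labels = torch.tensor([[1], [0]])       # what the collator yields for two samples with length-1 label lists
class_ids = labels[:, 0]                # tensor([1, 0]): one class id per sample
logits = torch.randn(2, 2)              # stand-in for model_outputs.logits, shape [batch, num_classes]
preds = torch.nn.functional.softmax(logits, dim=-1)
print(class_ids, preds.argmax(dim=-1))  # the accuracy helper compares these per-sample predictions and ids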
