
Commit e682111

Removed samsum collate_fn as it is dead code.
1 parent fd5ec6a

File tree

1 file changed (+0, -21 lines)

QEfficient/finetune/dataset/samsum_dataset.py

@@ -6,8 +6,6 @@
 # -----------------------------------------------------------------------------
 
 import datasets
-import torch
-from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -48,22 +46,3 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
-
-
-def collate_fn(batch):
-    eos_token = batch[0]["input_ids"][-1]
-
-    input_ids = pad_sequence(
-        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    attn_mask = pad_sequence(
-        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
-    )
-    labels = pad_sequence(
-        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
-
-
-def get_samsum_collate_fn(dataset_processer, dataset_config):
-    return collate_fn
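
For reference, the deleted helper padded every field of a batch out to the longest sequence before handing the batch to a data loader. A minimal runnable sketch of the same behavior follows; the name pad_collate, the batch_size value, and the DataLoader wiring are illustrative assumptions, not part of this repository:

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    # Pad input_ids and labels with the batch's trailing EOS token and the
    # attention mask with zeros, matching the behavior of the removed helper.
    eos_token = batch[0]["input_ids"][-1]

    def pad(key, value, dtype):
        # Right-pad each per-sample list to the longest sequence in the batch.
        return pad_sequence(
            [torch.tensor(b[key], dtype=dtype) for b in batch],
            batch_first=True,
            padding_value=value,
        )

    return {
        "input_ids": pad("input_ids", eos_token, torch.int32),
        "attention_mask": pad("attention_mask", 0, torch.int32),
        "labels": pad("labels", eos_token, torch.long),
    }


# Hypothetical usage with torch.utils.data.DataLoader, where `dataset` is the
# output of get_preprocessed_samsum:
# loader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)

Note the padding choice the original code made: token ids and labels were padded with the sample's EOS token rather than a dedicated pad token, while the attention mask was zero-filled so padded positions are ignored.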
