prepare.py
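"""Prepare the GSM8K dataset for supervised fine-tuning.

Each sample is formatted into a prompt wrapped in special marker tokens,
tokenized, filtered by maximum sequence length, and the result is saved
to ./prepared along with a small report.json.
"""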
import json
from functools import partial
from transformers import AutoTokenizer
from datasets import load_dataset
from args import parse_args, print_args, TaskArguments, HuggingFaceHubArguments, SFTArguments
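

# create_prompt_formats() builds the training text for each GSM8K sample using
# the following template (parts joined with newlines):
#   <START_TEXT>
#   ### Instruction:
#   You are a math assistant. Solve the following problem.
#   ### Problem:
#   <question>
#   <END_TEXT>
#   <START_REPR>
#   Let's think step by step,
#   <answer, with GSM8K's "#### <number>" final-answer marker rewritten as "### Answer:">
#   <END_REPR>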
def create_prompt_formats(sample):
    """
    Format the 'question' and 'answer' fields of a GSM8K sample into a single
    training prompt and store it in the 'text' field.
    :param sample: Sample dictionary
    """
    START_TEXT = "<START_TEXT>"
    END_TEXT = "<END_TEXT>"
    START_REPR = "<START_REPR>"
    END_REPR = "<END_REPR>"
    QUESTION_KEY = "### Instruction:\nYou are a math assistant. Solve the following problem.\n### Problem:"

    question = f"{START_TEXT}\n{QUESTION_KEY}\n{sample['question']}\n{END_TEXT}"
    # GSM8K answers end with "#### <final number>"; rewrite that marker as "### Answer:"
    answer = f"{START_REPR}\nLet's think step by step,\n{sample['answer'].replace('####', '### Answer:')}\n{END_REPR}"

    parts = [part for part in [question, answer] if part]
    formatted_prompt = "\n".join(parts)

    sample["text"] = formatted_prompt
    return sample


def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenize a batch of formatted prompts.
    """
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )


# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_dataset(tokenizer: AutoTokenizer, dataset, max_seq_length=2048, seed=42):
    """Format and tokenize the dataset so it is ready for training.
    :param tokenizer (AutoTokenizer): Model tokenizer
    :param dataset (DatasetDict): Raw GSM8K dataset
    :param max_seq_length (int): Maximum number of tokens to keep per sample
    :param seed (int): Shuffling seed
    """
    # Add prompt to each sample
    print("Preprocessing dataset...")
    dataset = dataset.map(create_prompt_formats)  # , batched=True)

    # Tokenize each batch and remove the raw 'question' and 'answer' columns
    _preprocessing_function = partial(preprocess_batch, max_length=max_seq_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["question", "answer"],
    )

    # Filter out samples whose input_ids exceed max_seq_length
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_seq_length)

    # Record the longest tokenized sequence that survived filtering (across all splits)
    report = {
        "max_seq_length": max([len(sample["input_ids"]) for split in dataset for sample in dataset[split]]),
    }
    print(dataset)

    # Shuffle dataset
    dataset = dataset.shuffle(seed=seed)

    return dataset, report


def prepare_tokenizer(tokenizer: AutoTokenizer):
    special_tokens = ["<START_TEXT>", "<END_TEXT>", "<START_REPR>", "<END_REPR>"]
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
    tokenizer.add_tokens(special_tokens, special_tokens=True)
    return tokenizer
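

# NOTE: because prepare_tokenizer() adds new tokens, the training script is expected
# to resize the model's embedding matrix (e.g. model.resize_token_embeddings(len(tokenizer)))
# before fine-tuning; that step is not part of this preparation script.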


def main():
    # Parse args
    args = parse_args(TaskArguments, HuggingFaceHubArguments, SFTArguments)
    print_args(args)

    tokenizer = args.tokenizer()
    tokenizer = prepare_tokenizer(tokenizer)

    # Preprocess dataset
    dataset = load_dataset("gsm8k", "main")
    if args.task.clear_data_cache:
        dataset.cleanup_cache_files()
    dataset, report = preprocess_dataset(tokenizer, dataset, args.sft.max_seq_length)
    print("Dataset report:", report)

    dataset.save_to_disk("./prepared")
    with open("./prepared/report.json", "w") as f:
        json.dump(report, f)


if __name__ == "__main__":
    main()
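
# Usage sketch (exact CLI flags come from the dataclasses in args.py; the model-id
# flag shown below is hypothetical):
#   python prepare.py --model_name <hub tokenizer id> --max_seq_length 2048
# The prepared splits can then be reloaded in the training script with
# datasets.load_from_disk("./prepared").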