train.py
import builtins
import os

from datasets import load_from_disk

from utils import get_local_rank
from prepare import prepare_tokenizer
from args import (
    parse_args,
    print_args,
    TaskArguments,
    HuggingFaceHubArguments,
    DistributedArguments,
    TrainingArguments,
    SFTArguments,
    DeepSpeedArguments,
    LoraArguments,
    BitsAndBytesArguments,
)


def print(*args, **kwargs):
    # Only the main process prints unless `all_ranks=True` is passed explicitly.
    if get_local_rank() == 0 or kwargs.pop("all_ranks", False):
        builtins.print(*args, **kwargs)


def main():
    args = parse_args(
        TaskArguments,
        HuggingFaceHubArguments,
        DistributedArguments,
        TrainingArguments,
        SFTArguments,
        DeepSpeedArguments,
        LoraArguments,
        BitsAndBytesArguments,
    )
    print_args(args)

    if not os.path.exists("./prepared"):
        raise ValueError("Dataset not prepared. Did you run `python prepare.py` first?")

    # Load dataset
    dataset = load_from_disk("./prepared")

    # Load tokenizer
    tokenizer = args.tokenizer()
    tokenizer = prepare_tokenizer(tokenizer)

    # Load base model
    model = args.model()

    # Make new learnable parameters for special tokens (added by `prepare_tokenizer`)
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    # Set Llama 2 specific parameters
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Set up the supervised fine-tuning trainer
    trainer = args.sft_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if "test" in dataset else None,
        dataset_text_field="text",
    )

    # Train model
    trainer.train()

    # Save model & tokenizer
    save_dir = args.task.save_dir
    trainer.model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    main()