
Commit 702c406

transformers-train-llama2.py
committed · 1 parent e20aaa8 · commit 702c406

File tree

1 file changed: +120, -0 lines changed


llm/transformers-train-llama2.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

print("Initializing...")

import os
from datasets import DatasetDict, disable_progress_bar
from datasets.arrow_dataset import Dataset
from transformers.utils import logging
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers import pipeline, LlamaConfig, LlamaForCausalLM
from difflib import SequenceMatcher

# ========== Setup ========== #
os.environ["WANDB_DISABLED"] = "true"
os.environ["DS_ACCELERATOR"] = "cpu"
logging.set_verbosity(logging.ERROR)
disable_progress_bar()

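# The settings above keep the run quiet and CPU-friendly: WANDB_DISABLED turns off
# Weights & Biases reporting, DS_ACCELERATOR="cpu" points DeepSpeed (if installed)
# at its CPU backend, and the last two calls silence transformers logging and the
# datasets progress bars.
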
# ========== Generate Toy Data ========== #
length = 8
input_text = " ".join(str(i) for i in range(length))
print("input:", input_text)

ds = DatasetDict({
    "train": Dataset.from_list([{"text": input_text + ' 0 0'}]),
    "valid": Dataset.from_list([{"text": input_text + ' 0 0'}])
})

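# Note: train and valid are the same single example -- the digits 0..7 with two
# extra zeros appended -- so evaluation only measures how well the model memorizes
# its one training string.
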
#samples = [{"text": " ".join(str(i + j) for i in range(8))} for j in range(100)]
#ds = DatasetDict({
#    "train": Dataset.from_list(samples),
#    "valid": Dataset.from_list(samples[:20])
#})

#print(samples)

# ========== Model Configuration ========== #
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=128,
    rms_norm_eps=1e-6,
    initializer_range=0.02,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2
)

model = LlamaForCausalLM(config)

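# Note: LlamaForCausalLM(config) builds the architecture with freshly initialized
# weights -- no pretrained checkpoint is loaded. With hidden_size=512, 4 layers and
# a 32000-token vocabulary this is roughly 40M parameters, most of them in the
# embedding and LM-head matrices (rough estimate, assuming the default untied
# embeddings).
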
# ========== Tokenizer ========== #
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Set pad_token if missing
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    print("Added pad_token.")

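# Why the pad token matters: DataCollatorForLanguageModeling pads batches and masks
# padded label positions, which requires tokenizer.pad_token to be set. Adding
# '[PAD]' also grows the vocabulary beyond the config's vocab_size=32000, hence the
# resize below. (The TinyLlama chat tokenizer may already define a pad token, in
# which case this branch is a no-op.)
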
# Resize model to match tokenizer
model.resize_token_embeddings(len(tokenizer))

# Save tokenizer (optional)
tokenizer.save_pretrained("tokenizer.llama")

# ========== Tokenization ========== #
def tokenize_function0(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=length)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=length, return_tensors=None, add_special_tokens=True)


tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text"])

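# Note: the LLaMA tokenizer typically maps each digit to its own token, so the
# training string "0 1 2 3 4 5 6 7 0 0" plus BOS exceeds max_length=8 and is
# truncated here -- the model effectively only ever sees a prefix of the sequence.
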
# ========== Data Collator ========== #
dc = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

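# With mlm=False the collator performs causal-LM collation: labels are a copy of
# input_ids with pad positions replaced by -100 (ignored by the loss), so no
# separate "labels" column is needed in the dataset.
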
# ========== Training Arguments ========== #
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=20,
    learning_rate=2e-4,
    logging_steps=1,
    save_steps=100,
    save_total_limit=2,
    warmup_steps=5,
    weight_decay=0.01,
    dataloader_num_workers=0,
    fp16=False,
    optim="adamw_torch",
    logging_dir='./logs',
    report_to="none"
)

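# Nominal effective batch size = per_device_train_batch_size *
# gradient_accumulation_steps = 16 sequences per optimizer step, although with a
# one-example dataset each epoch only supplies a single batch, so most of that
# accumulation budget goes unused here.
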
# ========== Training ========== #
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=dc,
)

print("Training...")
110+
train_output = trainer.train()
111+
print("Training output:", train_output)
112+
113+
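# trainer.train() returns a TrainOutput named tuple (global_step, training_loss,
# metrics), so the print above shows the final step count and average loss.
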
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

output = generator(input_text, max_length=length + 10, num_return_sequences=1)[0]['generated_text']
print("Similarity:", SequenceMatcher(None, input_text, output).ratio())
print(f"Generated output: '{output}'")

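# The text-generation pipeline returns the prompt followed by the continuation
# (return_full_text defaults to True), so the SequenceMatcher ratio (0..1) mostly
# reflects how much new text was generated beyond the echoed prompt.
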
model.save_pretrained('trained_model.llama')
tokenizer.save_pretrained('trained_model.llama')
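# Both artifacts are written to the same 'trained_model.llama' directory, so the
# run can later be reloaded with LlamaForCausalLM.from_pretrained(...) and
# AutoTokenizer.from_pretrained(...) pointed at that path.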
