
Commit fc8d112

train-transformers.py
1 parent 180f8b6 commit fc8d112

File tree

1 file changed: +82 -0 lines changed


llm/train-transformers.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

print("initializing")

import os
from datasets import DatasetDict, disable_progress_bar
from datasets.arrow_dataset import Dataset
from transformers.utils import logging
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import AutoTokenizer
# https://huggingface.co/docs/transformers/en/main_classes/data_collator
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers import pipeline
from difflib import SequenceMatcher
import torch

# Keep the run local and quiet: no Weights & Biases, CPU accelerator, errors only.
os.environ["WANDB_DISABLED"] = "true"
os.environ["DS_ACCELERATOR"] = "cpu"
logging.set_verbosity(logging.ERROR)

# Toy corpus: a single counting sequence, e.g. "0 1 2 3" for length = 4.
length = 4
input = " ".join(str(i) for i in range(length))
print("input:", input)
input_list = input.split()

disable_progress_bar()
ds = DatasetDict({
    "train": Dataset.from_list([{"text": input}]),
    # Dataset.from_generator(gen),
    "valid": Dataset.from_list([{"text": input}]),
})

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 has no pad token by default; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], max_length=length)

tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text"])

print("model")
# Small GPT-2 initialized from scratch, sized down for the toy input.
# https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config
config = GPT2Config(vocab_size=tokenizer.vocab_size, n_positions=128, n_ctx=128,
                    n_embd=256, n_layer=2 * length, n_head=length)

model = GPT2LMHeadModel(config)

# Causal LM collator (mlm=False): labels are the input ids, shifted inside the model.
dc = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(
    run_name="test",
    num_train_epochs=4,
    output_dir="./results",
    overwrite_output_dir=True,
    eval_strategy="steps",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=dc,
)
print("training")
t = trainer.train()
print(t)

print("inference")
# Prompt with the first token ("0") and compare the completion against the training line.
gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = gen(input_list[0], max_length=length, num_return_sequences=1)[0]["generated_text"]
print(SequenceMatcher(None, input, output).ratio(), output)
tokenizer.save_pretrained("trained_model")

# Untie the output head from the input embeddings before saving, so the checkpoint
# keeps its own lm_head tensor (workaround for llama.cpp conversion):
# https://github.com/ggml-org/llama.cpp/issues/11345
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight.clone().detach())
model.config.tie_word_embeddings = False

model.save_pretrained("trained_model")
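
The saved trained_model directory can be reloaded with the standard from_pretrained API. A minimal round-trip sketch (not part of this commit; it assumes the script above was run first in the same working directory):

#!/usr/bin/env python3
# Hypothetical follow-up check: reload the checkpoint written by train-transformers.py
# and generate from the same prompt to confirm the saved weights behave like the
# in-memory model did.
from difflib import SequenceMatcher
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline

tokenizer = AutoTokenizer.from_pretrained("trained_model")
model = GPT2LMHeadModel.from_pretrained("trained_model")

gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
expected = "0 1 2 3"  # the single training line used above (length = 4)
output = gen("0", max_length=4, num_return_sequences=1)[0]["generated_text"]
print(SequenceMatcher(None, expected, output).ratio(), output)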
