-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
44 lines (31 loc) · 1.07 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import inspect

from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM

from config import CFG
def get_model(model_name):
    """Load a pretrained causal language model.

    Args:
        model_name: Checkpoint identifier or local path passed straight to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.
    """
    return AutoModelForCausalLM.from_pretrained(model_name)
# `TrainingArguments` renamed `evaluation_strategy` to `eval_strategy`
# (transformers 4.41), and later releases dropped the old spelling, at which
# point passing it raises TypeError when this module is imported. Pick
# whichever keyword the installed version actually accepts.
_EVAL_STRATEGY_KWARG = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

# Shared training configuration; consumed by `training_function` below.
# NOTE(review): output_dir says "gpt2" while the model name comes from CFG —
# confirm the two stay in agreement.
training_args = TrainingArguments(
    output_dir="./fine_tuned_gpt2/",
    learning_rate=1e-4,
    weight_decay=0.01,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=10,
    # Set push_to_hub=True to upload checkpoints to the Hugging Face Hub.
    # Keep at most one checkpoint on disk; with load_best_model_at_end the
    # Trainer may additionally retain the best checkpoint.
    save_total_limit=1,
    # Evaluation and saving both run once per epoch; load_best_model_at_end
    # requires the two strategies to match.
    save_strategy="epoch",
    load_best_model_at_end=True,
    **{_EVAL_STRATEGY_KWARG: "epoch"},
)
def training_function(tokenized_train_dataset, tokenized_val_dataset, *, run_training=False):
    """Build a Trainer for the configured model and optionally train it.

    The model is loaded with ``get_model(CFG['model']['model_name'])`` and
    paired with the module-level ``training_args``.

    Args:
        tokenized_train_dataset: Dataset passed to the Trainer as
            ``train_dataset``.
        tokenized_val_dataset: Dataset passed to the Trainer as
            ``eval_dataset``.
        run_training: If True, run ``trainer.train()`` before returning.
            Defaults to False, which preserves the earlier behaviour where
            the ``train()`` call was disabled.

    Returns:
        The constructed ``Trainer`` (trained only when ``run_training`` is
        True).
    """
    # NOTE(review): no data_collator is supplied, so the tokenized datasets
    # presumably already include `labels` for the causal-LM loss — confirm.
    trainer = Trainer(
        model=get_model(CFG["model"]["model_name"]),
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
    )

    if not run_training:
        # Do not report "started"/"completed" when no training happened.
        print("Training skipped (run_training=False)")
        return trainer

    print("Training is started")
    trainer.train()
    print("Training is completed")
    return trainer