Merge pull request #119 from amosproj/98-implement-a-script-for-llm-fine-tuning

98 implement a script for llm fine tuning
julioc-p authored Jul 10, 2024
2 parents ce9477f + ce79baf commit 9742120
Showing 3 changed files with 121 additions and 95 deletions.
60 changes: 60 additions & 0 deletions src/hpc_scripts/CustomSFTTrainer.py
@@ -0,0 +1,60 @@
from typing import List, Union
from trl import SFTTrainer
import optuna
from transformers.trainer_utils import HPSearchBackend, BestRun, PREFIX_CHECKPOINT_DIR, default_compute_objective
import os
import gc
import torch


class CustomSFTTrainer(SFTTrainer):

    @staticmethod
    def run_hp_search_optuna(trainer, n_trials, direction, **kwargs):

        def _objective(trial, checkpoint_dir=None):
            checkpoint = None
            if checkpoint_dir:
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)
            #################
            # UPDATES START
            #################
            if not checkpoint:
                # free GPU memory
                del trainer.model
                gc.collect()
                torch.cuda.empty_cache()
            trainer.objective = None
            trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
            return trainer.objective

        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        study = optuna.create_study(direction=direction, **kwargs)
        study.optimize(_objective, n_trials=n_trials,
                       timeout=timeout, n_jobs=n_jobs)
        best_trial = study.best_trial
        return BestRun(str(best_trial.number), best_trial.value, best_trial.params)

    def hyperparameter_search(
        self,
        hp_space,
        n_trials,
        direction,
        compute_objective=default_compute_objective,
    ) -> Union[BestRun, List[BestRun]]:

        self.hp_search_backend = HPSearchBackend.OPTUNA
        self.hp_space = hp_space
        self.hp_name = None
        self.compute_objective = compute_objective
        best_run = CustomSFTTrainer.run_hp_search_optuna(
            self, n_trials, direction)
        self.hp_search_backend = None
        return best_run
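
The class above reuses the Optuna search loop from transformers but swaps SFTTrainer in as the base. A minimal usage sketch, assuming a model_init and TrainingArguments like those in hyperparameter_optimization.py below (all other names are placeholders, not part of this commit):

# Sketch only: placeholder names, not part of this commit.
def hp_space(trial):
    # maps an optuna Trial to TrainingArguments overrides
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

trainer = CustomSFTTrainer(
    model=model,              # initial model; rebuilt per trial via model_init
    args=training_arguments,  # transformers.TrainingArguments
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    model_init=model_init,    # called at the start of each trial
)
best_run = trainer.hyperparameter_search(
    hp_space=hp_space,
    n_trials=5,
    direction="minimize",     # minimize the computed objective (eval loss by default)
)
print(best_run.hyperparameters)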
122 changes: 41 additions & 81 deletions src/hpc_scripts/hyperparameter_optimization.py
@@ -1,124 +1,78 @@
 # imports
 import transformers
+import gc
+from transformers import (AutoModelForCausalLM,
+                          AutoTokenizer,
+                          TrainingArguments,
+                          BitsAndBytesConfig
+                          )
-from trl import SFTTrainer
-from peft import LoraConfig
+from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import HfApi, login
-from transformers.hyperparameter_search import HPSearchBackend
-from transformers.trainer import *
-import optuna
-import gc
 import torch
+import CustomSFTTrainer
+import random
 import os

 HF_TOKEN = os.getenv('HF_TOKEN', 'add_hf_token')
 api = HfApi()
 login(HF_TOKEN, add_to_git_credential=True)


 gc.collect()
 torch.cuda.empty_cache()


-def run_hp_search_optuna(trainer, n_trials, direction, **kwargs):
-
-    def _objective(trial, checkpoint_dir=None):
-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
-        #################
-        # UPDATES START
-        #################
-        if not checkpoint:
-            # free GPU memory
-            del trainer.model
-            gc.collect()
-            torch.cuda.empty_cache()
-        trainer.objective = None
-        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
-        # If there hasn't been any evaluation during the training loop.
-        if getattr(trainer, "objective", None) is None:
-            metrics = trainer.evaluate()
-            trainer.objective = trainer.compute_objective(metrics)
-        return trainer.objective
-
-    timeout = kwargs.pop("timeout", None)
-    n_jobs = kwargs.pop("n_jobs", 1)
-    study = optuna.create_study(direction=direction, **kwargs)
-    study.optimize(_objective, n_trials=n_trials,
-                   timeout=timeout, n_jobs=n_jobs)
-    best_trial = study.best_trial
-    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
-
-
-def hyperparameter_search(
-    self,
-    hp_space,
-    n_trials,
-    direction,
-    compute_objective=default_compute_objective,
-) -> Union[BestRun, List[BestRun]]:
-
-    trainer.hp_search_backend = HPSearchBackend.OPTUNA
-    self.hp_space = hp_space
-    trainer.hp_name = None
-    trainer.compute_objective = compute_objective
-    best_run = run_hp_search_optuna(trainer, n_trials, direction)
-    self.hp_search_backend = None
-    return best_run
-
-
-transformers.trainer.Trainer.hyperparameter_search = hyperparameter_search
-
-
 # defining hyperparameter search space for optuna


 def optuna_hp_space(trial):
     return {
         "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
         "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
         "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 15),
         "weight_decay": trial.suggest_loguniform("weight_decay", 1e-6, 1e-2),
         "gradient_clipping": trial.suggest_float("gradient_clipping", 0.1, 0.5),
     }

 # Define a function to calculate BLEU score


 # configuration arguments
-model_id = "google/gemma-2-27b-it"
+model_id = "google/gemma-2-9b-it"

-# model init function for the trainer
+# bits and bytes config
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)


 def model_init(trial):
-    return AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, quantization_config=bnb_config, device_map="auto")
+    model = prepare_model_for_kbit_training(model)
+    model = get_peft_model(model, lora_config)
+    return model


 # tokenizer load
 tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='right')

 # Loading training and evaluation data
-training_dataset = load_dataset(
-    "Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train[:7500]")
-eval_dataset = load_dataset(
-    "Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train[7500:8000]")
+dataset = load_dataset(
+    "Kubermatic/Merged_QAs", split="train")
+
+random.seed(42)
+random_indices = random.sample(range(len(dataset)), k=500)
+
+training_indices = random_indices[:400]
+eval_indices = random_indices[400:500]
+training_dataset = dataset.filter(
+    lambda _, idx: idx in training_indices, with_indices=True)
+eval_dataset = dataset.filter(
+    lambda _, idx: idx in eval_indices, with_indices=True)

 max_seq_length = 1024


 output_dir = "trained_model"
 training_arguments = TrainingArguments(
     output_dir=output_dir,
-    num_train_epochs=1,
+    num_train_epochs=3,
     gradient_checkpointing=True,
     per_device_train_batch_size=1,
     gradient_accumulation_steps=8,
@@ -163,11 +117,14 @@ def formatting_func(example):
         output_texts.append(text)
     return output_texts

-# instantiation of the trainer

+# Passing model
+model = model_init(None)


-trainer = SFTTrainer(
-    model=model_id,
+# instantiation of the trainer
+trainer = CustomSFTTrainer(
+    model=model,
     train_dataset=training_dataset,
     eval_dataset=eval_dataset,
     args=training_arguments,
@@ -178,10 +135,13 @@ def formatting_func(example):
     model_init=model_init,
 )

+# avoid placing model on device as it is already placed on device in model_init
+trainer.place_model_on_device = False
+
 best_trial = trainer.hyperparameter_search(
     direction="minimize",
     hp_space=optuna_hp_space,
-    n_trials=20,
+    n_trials=5,
 )

 print(best_trial)
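
hyperparameter_search returns a transformers.trainer_utils.BestRun, so the winning trial can be read off field by field; the tuned learning_rate and weight_decay hard-coded in model_training.py below appear to come from such a run. A short sketch (not part of the commit):

# Sketch: inspecting the BestRun produced above.
print(best_trial.run_id)           # optuna trial number, as a string
print(best_trial.objective)        # best objective value (the eval loss, given direction="minimize")
print(best_trial.hyperparameters)  # winning values drawn from optuna_hp_space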
34 changes: 20 additions & 14 deletions src/hpc_scripts/model_training.py
@@ -17,22 +17,23 @@


 # training pipeline taken from https://huggingface.co/blog/gemma-peft
-model_id = "google/gemma-2-27b-it"
+model_id = "google/gemma-2-9b-it"

 bnb_config = BitsAndBytesConfig(
-    load_in_8bit=True,
-    bnb_8bit_quant_type="nf4",
-    bnb_8bit_compute_dtype=torch.bfloat16
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
 )

+dataset = load_dataset(
+    "Kubermatic/Merged_QAs", split="train")
+dataset.shuffle(42)
+dataset = dataset.train_test_split(train_size=0.20, test_size=0.04)
+
 tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='right')
 # TODO: Check if this can be changed to AutoModelForQuestionAnswering with GEMMA
 model = AutoModelForCausalLM.from_pretrained(
-    model_id, quantization_config=bnb_config, device_map="auto")
-
-# Training Data
-dataset = load_dataset(
-    "Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train")
+    model_id, quantization_config=bnb_config, device_map="auto", attn_implementation='eager')


 # Training (hyper)parameters (initial config taken from: https://medium.com/@lucamassaron/sherlock-holmes-q-a-enhanced-with-gemma-2b-it-fine-tuning-2907b06d2645)
@@ -44,15 +45,15 @@

 training_arguments = TrainingArguments(
     output_dir=output_dir,
-    num_train_epochs=3,
+    num_train_epochs=5,
     gradient_checkpointing=True,
-    per_device_train_batch_size=16,
+    per_device_train_batch_size=4,
     gradient_accumulation_steps=8,
     optim="paged_adamw_32bit",
     save_steps=0,
     logging_steps=10,
-    learning_rate=5e-4,
-    weight_decay=0.001,
+    learning_rate=1.344609154868106e-05,
+    weight_decay=0.00019307024914471071,
     fp16=True,
     bf16=False,
     max_grad_norm=0.3,
@@ -63,6 +64,10 @@
     report_to="tensorboard",
     disable_tqdm=False,
     load_best_model_at_end=True,
+    eval_accumulation_steps=1,
+    evaluation_strategy='steps',
+    eval_steps=500,
+    per_device_eval_batch_size=4
     # debug="underflow_overflow"
 )

@@ -96,13 +101,14 @@ def formatting_func(example):

 trainer = SFTTrainer(
     model=model,
-    train_dataset=dataset,
+    train_dataset=dataset["train"],
     args=training_arguments,
     peft_config=lora_config,
     formatting_func=formatting_func,
     tokenizer=tokenizer,
     max_seq_length=max_seq_length,
     callbacks=[EarlyStoppingCallback(early_stopping_patience=15)],
+    eval_dataset=dataset["test"],
 )
 trainer.train()
 print("Model is trained")
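
Both scripts pass a formatting_func whose body is collapsed in this view; only its last two lines are visible above. As a rough illustration of what TRL's SFTTrainer expects from such a function, here is a hypothetical reconstruction (the dataset columns and the Gemma chat template are assumptions, not taken from this commit):

# Hypothetical sketch; the column names ("Question", "Answer") and the
# prompt template are assumptions, not taken from this commit.
def formatting_func(example):
    output_texts = []
    for i in range(len(example["Question"])):
        text = (f"<start_of_turn>user\n{example['Question'][i]}<end_of_turn>\n"
                f"<start_of_turn>model\n{example['Answer'][i]}<end_of_turn>")
        output_texts.append(text)
    return output_texts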
