Skip to content

Commit

Permalink
add PPO and stack_llama support
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <[email protected]>
  • Loading branch information
sywangyi committed Dec 28, 2023
1 parent cba23b3 commit e3840e0
Show file tree
Hide file tree
Showing 11 changed files with 1,934 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/trl/stack_llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
There were three main steps to the training process:
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
- `torchrun --nnodes 1 --nproc_per_node 8 supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --bf16 --output_dir ./llama-se`
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
- `torchrun --nnodes 1 --nproc_per_node 8 reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
- `torchrun --nnodes 1 --nproc_per_node 8 rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`


LoRA layers were using at all stages to reduce memory requirements.
At each stage the peft adapter layers were merged with the base model, using:
```shell
python merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
```
Note that this script requires `peft>=0.3.0`.

For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).
50 changes: 50 additions & 0 deletions examples/trl/stack_llama/merge_peft_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py.
# only difference is removal of model.push_to_hub
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
"""
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
merged model.
"""

adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
# The sequence classification task is used for the reward model in PPO
model = AutoModelForSequenceClassification.from_pretrained(
script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
)
else:
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

model = model.merge_and_unload()

model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
313 changes: 313 additions & 0 deletions examples/trl/stack_llama/reward_modeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/reward_modeling.py, enable it for Gaudi2

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
PreTrainedTokenizerBase,
TrainerCallback,
)
from transformers.utils import PaddingStrategy

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
"""

local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
resume_from_checkpoint: Optional[bool] = field(
default=False,
metadata={"help": "If you want to resume training where it left off."},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
},
)
per_device_train_batch_size: Optional[int] = field(default=4)
per_device_eval_batch_size: Optional[int] = field(default=1)
gradient_accumulation_steps: Optional[int] = field(default=1)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
model_name: Optional[str] = field(
default="gpt2",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
},
)
bf16: Optional[bool] = field(
default=True,
metadata={
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
},
)
num_train_epochs: Optional[int] = field(
default=1,
metadata={"help": "The number of training epochs for the reward model."},
)
train_subset: Optional[int] = field(
default=100000,
metadata={"help": "The size of the subset of the training data to use"},
)
eval_subset: Optional[int] = field(
default=50000,
metadata={"help": "The size of the subset of the eval data to use"},
)
gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Enables gradient checkpointing."},
)
optim: Optional[str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
lr_scheduler_type: Optional[str] = field(
default="linear",
metadata={"help": "The lr scheduler"},
)
max_length: Optional[int] = field(default=512)
eval_first_step: Optional[bool] = field(
default=False,
metadata={"help": "Whether to run eval after the first step"},
)


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Load the human stack-exchange-paired dataset for tuning the reward model.
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
if script_args.train_subset > 0:
train_dataset = train_dataset.select(range(script_args.train_subset))
eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train")
if script_args.eval_subset > 0:
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
model_name_split = script_args.model_name.split("/")[-1]
output_name = (
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
)

training_args = GaudiTrainingArguments(
output_dir=output_name,
learning_rate=script_args.learning_rate,
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
num_train_epochs=script_args.num_train_epochs,
weight_decay=script_args.weight_decay,
evaluation_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
deepspeed=script_args.deepspeed,
local_rank=script_args.local_rank,
remove_unused_columns=False,
label_names=[],
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
report_to="none",
use_habana=True,
use_lazy_mode=True,
)
# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token


peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
torch.autograd.set_detect_anomaly(True)
model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = not script_args.gradient_checkpointing
num_proc = 24 # Can adjust to be higher if you have more processors.
original_columns = train_dataset.column_names


# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other.
# Then tokenize the dataset.
def preprocess_function(examples):
new_examples = {
"input_ids_j": [],
"attention_mask_j": [],
"input_ids_k": [],
"attention_mask_k": [],
}
for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)

new_examples["input_ids_j"].append(tokenized_j["input_ids"])
new_examples["attention_mask_j"].append(tokenized_j["attention_mask"])
new_examples["input_ids_k"].append(tokenized_k["input_ids"])
new_examples["attention_mask_k"].append(tokenized_k["attention_mask"])

return new_examples


# preprocess the dataset and filter out QAs that are longer than script_args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
)

eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
)


# We need to define a special data collator that batches the data in our j vs k format.
@dataclass
class RewardDataCollatorWithPadding:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
features_j = []
features_k = []
for feature in features:
features_j.append(
{
"input_ids": feature["input_ids_j"],
"attention_mask": feature["attention_mask_j"],
}
)
features_k.append(
{
"input_ids": feature["input_ids_k"],
"attention_mask": feature["attention_mask_k"],
}
)
batch_j = self.tokenizer.pad(
features_j,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch_k = self.tokenizer.pad(
features_k,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch = {
"input_ids_j": batch_j["input_ids"],
"attention_mask_j": batch_j["attention_mask"],
"input_ids_k": batch_k["input_ids"],
"attention_mask_k": batch_k["attention_mask"],
"return_loss": True,
}
return batch


# Define the metric that we'll use for validation.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
predictions, _ = eval_pred
# Here, predictions is rewards_j and rewards_k.
# We want to see how much of the time rewards_j > rewards_k.
predictions = np.argmax(predictions, axis=0)
labels = np.zeros(predictions.shape)
return accuracy.compute(predictions=predictions, references=labels)


class RewardTrainer(GaudiTrainer):
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
def compute_loss(self, model, inputs, return_outputs=False):
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
if return_outputs:
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
return loss


gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

# Train the model, woohoo.
trainer = RewardTrainer(
model=model,
gaudi_config=gaudi_config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
data_collator=RewardDataCollatorWithPadding(
tokenizer=tokenizer, max_length=script_args.max_length, padding="max_length"
),
)


if script_args.eval_first_step:

class EvaluateFirstStepCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.global_step == 1:
control.should_evaluate = True

trainer.add_callback(EvaluateFirstStepCallback())

trainer.train(script_args.resume_from_checkpoint)

print("Saving last checkpoint of the model")
trainer.save_model(output_name + "_peft_last_checkpoint")
Loading

0 comments on commit e3840e0

Please sign in to comment.