add DPO and SFT of TRL support in Gaudi and example (#601)
* add DPO and SFT of TRL support in Gaudi and example

Signed-off-by: Wang, Yi A <[email protected]>

* upgrade SFTTrainer/DPO trainer and stack_llama_2 example to v0.7.6

Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored Dec 25, 2023
1 parent 5f772ac commit 79f6de3
Showing 9 changed files with 1,219 additions and 0 deletions.
72 changes: 72 additions & 0 deletions examples/trl/stack_llama_2/README.md
@@ -0,0 +1,72 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack Exchange llama-v2-7b model

## Prerequisites

Install all the dependencies listed in `requirements.txt`:

```
$ pip install -U -r requirements.txt
```
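
The scripts also import `optimum.habana` (see the imports in `dpo_llama2.py`), which is not listed in `requirements.txt`, so Optimum Habana itself must already be available in your environment. If it is not, one possible way to install it is:

```
$ pip install optimum-habana
```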


## Training

There are two main steps in the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:

```
python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \
    --output_dir="./sft" \
    --max_steps=500 \
    --logging_steps=10 \
    --save_steps=10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=1e-4 \
    --lr_scheduler_type="cosine" \
    --warmup_steps=100 \
    --weight_decay=0.05 \
    --optim="paged_adamw_32bit" \
    --bf16 \
    --remove_unused_columns=False \
    --run_name="sft_llama2" \
    --report_to=none \
    --use_habana \
    --use_lazy_mode
```
2. Run the DPO trainer using the model saved by the previous step:
```
python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \
    --model_name_or_path="sft/final_merged_checkpoint" \
    --output_dir="dpo" \
    --report_to=none
```
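All of the `ScriptArguments` fields defined in `dpo_llama2.py` (learning rate, batch sizes, `max_steps`, LoRA settings, and so on) can be overridden on the command line in the same way; the extra values below are purely illustrative:
```
python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \
    --model_name_or_path="sft/final_merged_checkpoint" \
    --output_dir="dpo" \
    --max_steps=500 \
    --logging_steps=10 \
    --report_to=none
```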
## Merging the adapters
To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script included in this example (copied from TRL, with the Hub push removed):
```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2"
```
Unlike the upstream TRL script, this version of `merge_peft_adapter.py` does not push the merged model to the Hugging Face Hub; it only saves it locally under the name given by `--output_name`.
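If you do want the merged model on the Hub, a minimal sketch is shown below; it assumes you are already logged in via `huggingface-cli login`, and the repository id is a placeholder:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged model and tokenizer saved by merge_peft_adapter.py,
# keeping the bf16 weights as they were saved
model = AutoModelForCausalLM.from_pretrained("stack-llama-2", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("stack-llama-2")

# Push both to a (placeholder) repository on the Hugging Face Hub
model.push_to_hub("your-username/stack-llama-2")
tokenizer.push_to_hub("your-username/stack-llama-2")
```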
## Running the model
We can load the DPO-trained LoRA adapters saved by the DPO training step via:
```py
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
model.generate(...)
```
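For a complete round trip, the `...` can be replaced with tokenized inputs. Continuing from the snippet above, a minimal sketch (the prompt text and decoding settings are illustrative, and the tokenizer checkpoint is assumed to be the base Llama-2 one):
```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Prompts follow the same "Question: ...\n\nAnswer: " format used during training
prompt = "Question: How do I merge two dictionaries in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```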
231 changes: 231 additions & 0 deletions examples/trl/stack_llama_2/dpo_llama2.py
@@ -0,0 +1,231 @@
# Copied from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, adapted to run on Gaudi2
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiDPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""

# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

# training parameters
model_name_or_path: Optional[str] = field(
default="../sft/results/final_checkpoint",
metadata={"help": "the location of the SFT model name or path"},
)
tokenizer_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={"help": "the location of the SFT model name or path"},
)
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "whether to use gradient checkpointing"}
)

lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default="wandb",
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    sanity_check: bool = False,
    cache_dir: str = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    Prompts are structured as follows:
        "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
    )
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    # 1. initialize training arguments:
    training_args = GaudiTrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        use_habana=True,
        use_lazy_mode=True,
        use_hpu_graphs_for_training=True,
        use_hpu_graphs_for_inference=True,
    )
    # 2. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model_ref.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # 3. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 4. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    # 5. initialize the DPO trainer
    dpo_trainer = GaudiDPOTrainer(
        model,
        model_ref,
        gaudi_config=gaudi_config,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()

    # 7. save
    dpo_trainer.save_model(script_args.output_dir)
50 changes: 50 additions & 0 deletions examples/trl/stack_llama_2/merge_peft_adapter.py
@@ -0,0 +1,50 @@
# Copied from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py.
# The only difference is the removal of model.push_to_hub.
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
"""
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
merged model.
"""

adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
    # The sequence classification task is used for the reward model in PPO
    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
    )

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

model = model.merge_and_unload()

model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
5 changes: 5 additions & 0 deletions examples/trl/stack_llama_2/requirements.txt
@@ -0,0 +1,5 @@
trl == 0.7.6
peft == 0.6.2
datasets
wandb
tyro