add DPO and SFT of TRL support in Gaudi and example (#601)

* add DPO and SFT of TRL support in Gaudi and example
  Signed-off-by: Wang, Yi A <[email protected]>
* upgrade SFTTrainer/DPO trainer and stack_llama_2 example to v0.7.6
  Signed-off-by: Wang, Yi A <[email protected]>

Signed-off-by: Wang, Yi A <[email protected]>
Showing 9 changed files with 1,219 additions and 0 deletions.
@@ -0,0 +1,72 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model

## Prerequisites

Install all the dependencies in the `requirements.txt`:

```
$ pip install -U -r requirements.txt
```

## Training

There are two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:

```
python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \
    --output_dir="./sft" \
    --max_steps=500 \
    --logging_steps=10 \
    --save_steps=10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --learning_rate=1e-4 \
    --lr_scheduler_type="cosine" \
    --warmup_steps=100 \
    --weight_decay=0.05 \
    --optim="paged_adamw_32bit" \
    --bf16 \
    --remove_unused_columns=False \
    --run_name="sft_llama2" \
    --report_to=none \
    --use_habana \
    --use_lazy_mode
```
2. Run the DPO trainer using the model saved by the previous step:
```
python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \
    --model_name_or_path="sft/final_merged_checkpoint" \
    --output_dir="dpo" \
    --report_to=none
```

## Merging the adapters
To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2"
```
Unlike the upstream TRL script, the copy of `merge_peft_adapter.py` in this example does not push the merged model to the Hugging Face Hub; the merged model and tokenizer are saved locally under `stack-llama-2`.

## Running the model
We can load the DPO-trained LoRA adapters that were saved by the DPO training step via:
```py
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
model.generate(...)
```
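
For a quick end-to-end check, the sketch below (an illustrative addition, not part of the original example) builds a prompt in the same `"Question: ... \n\nAnswer: "` format used during training and decodes the generation. The tokenizer name is assumed to be the base Llama-2 tokenizer used elsewhere in this example, and the prompt text and generation settings are arbitrary.

```py
# Minimal sketch, assuming the meta-llama/Llama-2-7b-hf tokenizer and the
# prompt template from dpo_llama2.py; prompt and generation settings are illustrative.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Same prompt structure as the training data: "Question: " + <question> + "\n\nAnswer: "
prompt = "Question: How do I merge two dictionaries in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```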
@@ -0,0 +1,231 @@
# Copied from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, adapted to run on Gaudi2
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiDPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the name or path of the SFT model"},
    )
    tokenizer_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-2-7b-hf",
        metadata={"help": "the name or path of the tokenizer"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "whether to use gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    sanity_check: bool = False,
    cache_dir: str = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
    )
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )
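
# Illustrative example (not part of the original script) of one record produced by
# the mapping above, where response_j is the preferred answer and response_k the
# rejected one; the texts are hypothetical:
#   {
#       "prompt":   "Question: How do I read a file in Python?\n\nAnswer: ",
#       "chosen":   "Use the built-in open() function ...",
#       "rejected": "You could shell out to cat ...",
#   }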

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    # 1. initialize training arguments:
    training_args = GaudiTrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        use_habana=True,
        use_lazy_mode=True,
        use_hpu_graphs_for_training=True,
        use_hpu_graphs_for_inference=True,
    )
    # 2. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    model_ref.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # 3. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 4. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    # 5. initialize the DPO trainer
    dpo_trainer = GaudiDPOTrainer(
        model,
        model_ref,
        gaudi_config=gaudi_config,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()

    # 7. save
    dpo_trainer.save_model(script_args.output_dir)
@@ -0,0 +1,50 @@
# Copied from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py.
# The only difference is the removal of model.push_to_hub.
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
    """
    The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
    merged model.
    """

    adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
    base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
    output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
    # The sequence classification task is used for the reward model in PPO
    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
    )

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

model = model.merge_and_unload()

model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
@@ -0,0 +1,5 @@
trl == 0.7.6
peft == 0.6.2
datasets
wandb
tyro