refactor ppo example
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Jan 18, 2024
1 parent e8698cf commit b7ce2a5
Showing 6 changed files with 117 additions and 304 deletions.
78 changes: 78 additions & 0 deletions examples/trl/README.md
@@ -72,3 +72,81 @@ python run_generation.py \
--prompt "Here is my prompt"

```
## PPO pipeline
### Training
The following example walks through the creation of StackLLaMA 2: a Stack Exchange llama-v2-7b model.
There are three main steps to the PPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi sft.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--output_dir="./sft" \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--run_name="sft_llama2" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
2. Reward modeling on dialog pairs from the SE dataset, starting from llama-v2-7b-se, to create llama-v2-7b-se-rm:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi reward_modeling.py \
--model_name=./sft/final_merged_checkpoint \
--output_dir=./rm
```
To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="rm" --output_name="rm_merged_checkpoint"
```
3. RL fine-tuning of llama-v2-7b-se with the llama-v2-7b-se-rm reward model:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi ppo.py \
--log_with=wandb \
--model_name=./sft/final_merged_checkpoint \
--reward_model_name=./rm_merged_checkpoint \
--adafactor=False \
--output_max_length=128 \
--batch_size=8 \
--gradient_accumulation_steps=8 \
--batched_gen=True \
--ppo_epochs=4 \
--seed=0 \
--learning_rate=1.4e-5 \
--early_stopping=True \
--output_dir=llama-se-rl-finetune
```
To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL; a sketch of the underlying merge follows the command:
```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="llama-se-rl-finetune" --output_name="rl_merged_checkpoint"
```
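Under the hood, the merge simply folds the LoRA weights back into the base model with PEFT. A minimal sketch of that step, assuming the adapter directory produced by the PPO run above (illustrative paths, not the helper script itself):
```
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and attach the PPO-trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "llama-se-rl-finetune")  # adapter saved by ppo.py

# Fold the LoRA weights into the base weights and save a standalone checkpoint.
model = model.merge_and_unload()
model.save_pretrained("rl_merged_checkpoint")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf").save_pretrained("rl_merged_checkpoint")
```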
### Running the model
We can load the PPO-trained LoRA adapters that were saved by the PPO training step and run them through the [text-generation example](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation).
```
python run_generation.py \
--model_name_or_path ../trl/rl_merged_checkpoint/ \
--use_hpu_graphs --use_kv_cache --batch_size 1 --bf16 --max_new_tokens 100 \
--prompt "Here is my prompt"

```
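For a quick sanity check outside the Habana-optimized pipeline, the merged checkpoint can also be loaded with plain `transformers` generation. A minimal sketch, assuming the `rl_merged_checkpoint` directory from the merge step:
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("rl_merged_checkpoint")
model = AutoModelForCausalLM.from_pretrained("rl_merged_checkpoint", torch_dtype=torch.bfloat16)

# Generate a short continuation for a single prompt.
inputs = tokenizer("Here is my prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```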
23 changes: 16 additions & 7 deletions examples/trl/stack_llama/rl_training.py → examples/trl/ppo.py
@@ -1,6 +1,6 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/rl_training.py, enable it for Gaudi2
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional

import torch
from datasets import load_dataset
@@ -26,8 +26,8 @@ class ScriptArguments:

# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
tokenizer_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the tokenizer name"})
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
@@ -57,6 +57,13 @@ class ScriptArguments:

adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
use_habana: Optional[bool] = field(default=True, metadata={"help": "use habana for RL training"})
lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
lora_target_modules: List[str] = field(
default_factory=lambda: None,
metadata={"help": "Target modules for the LoRA method."},
)


adapt_PreTrainedModelWrapper_to_gaudi()
@@ -170,9 +177,10 @@ def collator(data):
# Now let's build the model, the reference model, and the tokenizer.
current_device = GaudiAccelerator().local_process_index
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=script_args.lora_target_modules,
bias="none",
task_type="CAUSAL_LM",
)
@@ -266,7 +274,6 @@ def collator(data):
output_length_sampler = LengthSampler(output_min_length, output_max_length)
else:
output_length_sampler = LengthSampler(output_max_length, output_max_length + 1)

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
if epoch >= config.total_ppo_epochs:
break
@@ -292,3 +299,5 @@ def collator(data):

if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")

ppo_trainer.save_pretrained(script_args.output_dir)
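
For orientation, the training loop implemented by `ppo.py` follows the standard TRL PPO pattern: sample responses for a batch of prompts, score them with the reward model, then take a PPO step. A simplified sketch of that loop (not the exact script; it assumes `ppo_trainer`, `tokenizer`, `reward_pipe`, `sent_kwargs`, `generation_kwargs`, and `output_length_sampler` are already built as in the script):
```
import torch

for epoch, batch in enumerate(ppo_trainer.dataloader):
    question_tensors = batch["input_ids"]

    # Sample a response for every prompt in the batch.
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Score the prompt + response pairs with the reward model pipeline.
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = reward_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(out[0]["score"]) for out in pipe_outputs]

    # One PPO optimization step on the sampled batch.
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```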
@@ -45,13 +45,13 @@ class ScriptArguments:
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
model_name: Optional[str] = field(
default="gpt2",
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
tokenizer_name: Optional[str] = field(
default=None,
default="meta-llama/Llama-2-7b-hf",
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
},
@@ -91,6 +91,17 @@ class ScriptArguments:
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
save_steps: Optional[int] = field(default=500, metadata={"help": "the saving frequency"})
eval_steps: Optional[int] = field(default=500, metadata={"help": "the evaluation frequency"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
lora_target_modules: List[str] = field(
default_factory=lambda: None,
metadata={"help": "Target modules for the LoRA method."},
)


parser = HfArgumentParser(ScriptArguments)
@@ -105,21 +116,18 @@ class ScriptArguments:
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
model_name_split = script_args.model_name.split("/")[-1]
output_name = (
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
)

training_args = GaudiTrainingArguments(
output_dir=output_name,
output_dir=script_args.output_dir,
learning_rate=script_args.learning_rate,
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
num_train_epochs=script_args.num_train_epochs,
weight_decay=script_args.weight_decay,
evaluation_strategy="steps",
eval_steps=500,
eval_steps=script_args.eval_steps,
save_strategy="steps",
save_steps=500,
save_steps=script_args.save_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
deepspeed=script_args.deepspeed,
@@ -128,7 +136,7 @@ class ScriptArguments:
label_names=[],
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
logging_steps=script_args.logging_steps,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
report_to="none",
@@ -140,13 +148,14 @@ class ScriptArguments:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token


peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=script_args.lora_target_modules,
bias="none",
)
torch.autograd.set_detect_anomaly(True)
model = AutoModelForSequenceClassification.from_pretrained(
@@ -310,4 +319,4 @@ def on_step_end(self, args, state, control, **kwargs):
trainer.train(script_args.resume_from_checkpoint)

print("Saving last checkpoint of the model")
trainer.save_model(output_name + "_peft_last_checkpoint")
trainer.save_model(script_args.output_dir)
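
The reward-modeling script fine-tunes a sequence-classification head on chosen/rejected answer pairs using a pairwise ranking loss: the score of the preferred answer should exceed the score of the rejected one. A minimal sketch of that loss (illustrative, not the trainer's exact code):
```
import torch
import torch.nn.functional as F

def pairwise_reward_loss(rewards_chosen: torch.Tensor, rewards_rejected: torch.Tensor) -> torch.Tensor:
    """Pairwise ranking loss for reward modeling:
    -log(sigmoid(r_chosen - r_rejected)), averaged over the batch."""
    return -F.logsigmoid(rewards_chosen - rewards_rejected).mean()

# Example: scores produced by the classification head for a batch of two pairs.
loss = pairwise_reward_loss(torch.tensor([1.3, 0.2]), torch.tensor([0.4, 0.9]))
```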
18 changes: 0 additions & 18 deletions examples/trl/stack_llama/README.md

This file was deleted.

50 changes: 0 additions & 50 deletions examples/trl/stack_llama/merge_peft_adapter.py

This file was deleted.

