PPOtask for mlora_train (#279)
[feature] PPOtask for mlora_train
ck-gyj authored Jan 26, 2025
1 parent 0e2c4d6 commit 4abf120
Showing 10 changed files with 922 additions and 1 deletion.
81 changes: 81 additions & 0 deletions demo/ppo/ppo_case1.yaml
@@ -0,0 +1,81 @@
dispatcher:
  name: "default"
  concurrency_num: 2
datasets:
  - name: "ppo_data"
    data: "demo/data.json"
    prompt: "demo/ppo_prompt.yaml"
    prompt_type: "ppo"
    preprocess: "default"
adapters:
  - name: "lora_ppo_reward"
    type: "lora"
    path: "adapters/lora_ppo_reward"
    optimizer: "adamw"
    lr: 1e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "lora_ppo_critic"
    type: "lora"
    path: "adapters/lora_ppo_critic"
    optimizer: "adamw"
    lr: 5e-5
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "lora_ppo_actor"
    type: "lora"
    path: "adapters/lora_ppo_actor"
    optimizer: "adamw"
    lr: 5e-5
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "ppo"
    name: "task_0"
    adapter:
      reward_adapter: "lora_ppo_reward"
      actor_adapter: "lora_ppo_actor"
      critic_adapter: "lora_ppo_critic"
    reference: "base"
    dataset: "ppo_data"
    batch_size: 16
    mini_batch_size: 16
    num_epochs: 20
    K_epochs: 5
    optim_num: 2
    cutoff_len: 256
    save_step: 100
    gamma: 0.99
    lamdb: 0.99
    kl_coefficient: 0.99
    generate_num: 32
    critic_loss_type: "mse"
    actor_loss_type: "adv_loss"
    reward_loss_type: "reward_loss"
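The gamma and lamdb (lambda) values above are the standard discount and smoothing factors of generalized advantage estimation. As a point of reference, a minimal GAE sketch in PyTorch showing how a PPO trainer typically consumes them; gae_advantages is an illustrative helper, not mLoRA's actual implementation:

import torch

def gae_advantages(rewards, values, gamma=0.99, lam=0.99):
    # rewards: per-step rewards for one generated sequence, shape (T,).
    # values: critic estimates, shape (T + 1,); the extra entry bootstraps.
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(rewards.shape[0])):
        # TD residual, then the exponentially weighted running sum.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    returns = advantages + values[:-1]  # regression targets for the critic
    return advantages, returns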
8 changes: 8 additions & 0 deletions demo/ppo_prompt.yaml
@@ -0,0 +1,8 @@
template: |
  {% if Optional=="instruction" %}
  ### Instruction: {{ data_point['instruction'] + '\n'}}
  {% elif Optional=="chosen" %}
  ### Instruction: {{ data_point['instruction'] + '\n'}} ### Output: {{ data_point['chosen'] + '\n'}}
  {% else %}
  ### Instruction: {{ data_point['instruction'] + '\n'}} ### Output: {{ data_point['reject'] + '\n'}}
  {% endif %}
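The template selects one of three renderings via the Optional variable: instruction only (for generation), instruction plus chosen answer, or instruction plus rejected answer (for reward modeling). A minimal sketch of rendering it with PyYAML and Jinja2, using made-up data; mLoRA's own prompter may load and invoke the template differently:

import yaml
from jinja2 import Template

with open("demo/ppo_prompt.yaml") as f:
    template = Template(yaml.safe_load(f)["template"])

# Example data point; field names match the template, values are invented.
data_point = {
    "instruction": "Summarize the article.",
    "chosen": "A concise, on-topic summary.",
    "reject": "An off-topic reply.",
}

prompt_only = template.render(Optional="instruction", data_point=data_point)
chosen_text = template.render(Optional="chosen", data_point=data_point)
reject_text = template.render(Optional="reject", data_point=data_point)  # any other value hits the else branch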
2 changes: 2 additions & 0 deletions mlora/config/__init__.py
@@ -15,6 +15,7 @@
     CITTaskConfig,
     CPOTaskConfig,
     DPOTaskConfig,
+    PPOTaskConfig,
     TaskConfig,
     TrainTaskConfig,
 )
@@ -37,4 +38,5 @@
     "ADAPTERCONFIG_CLASS",
     "OptimizerConfig",
     "LRSchedulerConfig",
+    "PPOTaskConfig",
 ]
60 changes: 59 additions & 1 deletion mlora/config/task.py
@@ -27,7 +27,13 @@ def __init__(
         super().__init__(config)
         self.init(self.__params_map, config)
 
-        self.adapter_ = adapters[config["adapter"]]
+        if isinstance(config["adapter"], dict):
+            self.reward_adapter_ = adapters[config["adapter"]["reward_adapter"]]
+            self.actor_adapter_ = adapters[config["adapter"]["actor_adapter"]]
+            self.critic_adapter_ = adapters[config["adapter"]["critic_adapter"]]
+        else:
+            self.adapter_ = adapters[config["adapter"]]
 
         self.dataset_: DatasetConfig | None = datasets[config["dataset"]]

@@ -148,9 +154,61 @@ def __init__(
         self.temperature_ = float(self.temperature_)
 
 
+class PPOTaskConfig(TrainTaskConfig):
+    gamma_: float
+    lamdb_: float
+    K_epochs_: int
+    T_horizon_: int
+    critic_loss_type_: str
+    actor_loss_type_: str
+    reward_loss_type_: str
+    clip_rate_: float
+    generate_num_: int
+    reward_adapter_: AdapterConfig
+    critic_adapter_: AdapterConfig
+    actor_adapter_: AdapterConfig
+    kl_coefficient_: float
+    optim_num_: int
+
+    __params_map: Dict[str, str] = {
+        "gamma_": "gamma",
+        "lamdb_": "lamdb",
+        "K_epochs_": "K_epochs",
+        "optim_num_": "optim_num",
+        "critic_loss_type_": "critic_loss_type",
+        "actor_loss_type_": "actor_loss_type",
+        "reward_loss_type_": "reward_loss_type",
+        "generate_num_": "generate_num",
+        "kl_coefficient_": "kl_coefficient",
+    }
+
+    def __init__(
+        self,
+        config: Dict[str, str],
+        adapters: Mapping[str, AdapterConfig],
+        datasets: Mapping[str, DatasetConfig],
+    ):
+        super().__init__(config, adapters, datasets)
+        self.init(self.__params_map, config)
+
+        self.gamma_ = float(self.gamma_)
+        self.lamdb_ = float(self.lamdb_)
+        self.K_epochs_ = int(self.K_epochs_)
+        self.optim_num_ = int(self.optim_num_)
+        self.generate_num_ = int(self.generate_num_)
+        self.kl_coefficient_ = float(self.kl_coefficient_)
+
+        if config["reference"] not in adapters:
+            self.reference_ = None
+            logging.info("PPOTask - use the base model as reference model.")
+        else:
+            self.reference_ = adapters[config["reference"]]
+
+
 TASKCONFIG_CLASS: Dict[str, Type[TaskConfig]] = {
     "train": TrainTaskConfig,
     "dpo": DPOTaskConfig,
     "cpo": CPOTaskConfig,
     "cit": CITTaskConfig,
+    "ppo": PPOTaskConfig,
 }
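TASKCONFIG_CLASS maps each task entry's "type" string to its config class. A sketch of the dispatch; raw_task stands for one fully populated item from the "tasks:" list, and adapters/datasets for the already-parsed maps, so the constructor call is shown commented out:

from mlora.config.task import TASKCONFIG_CLASS

config_cls = TASKCONFIG_CLASS["ppo"]  # -> PPOTaskConfig
# task_config = config_cls(raw_task, adapters, datasets)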
3 changes: 3 additions & 0 deletions mlora/executor/task/__init__.py
@@ -4,6 +4,7 @@
 from .cit_task import CITTask
 from .cpo_task import CPOTask
 from .dpo_task import DPOTask
+from .ppo_task import PPOTask
 from .task import Task
 from .train_task import TrainTask
 
@@ -12,6 +13,7 @@
     "dpo": DPOTask,
     "cpo": CPOTask,
     "cit": CITTask,
+    "ppo": PPOTask,
 }
 
@@ -32,5 +34,6 @@ def register_task_class(type_name: str, task: Type[Task]):
     "DPOTask",
     "CPOTask",
     "CITTask",
+    "PPOTask",
     "register_task_class",
 ]
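Besides the built-in table, register_task_class lets downstream code add task types at runtime. A usage sketch, where MyPPOTask is a hypothetical subclass:

from mlora.executor.task import PPOTask, register_task_class

class MyPPOTask(PPOTask):
    pass  # override hooks as needed

# "my_ppo" then becomes a valid task "type" in the YAML config.
register_task_class("my_ppo", MyPPOTask)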