PPOtask for mlora_train #279

Merged
merged 17 commits on Jan 26, 2025
81 changes: 81 additions & 0 deletions demo/ppo/ppo_case1.yaml
@@ -0,0 +1,81 @@
dispatcher:
  name: "default"
  concurrency_num: 2
datasets:
  - name: "ppo_data"
    data: "demo/data.json"
    prompt: "demo/ppo_prompt.yaml"
    prompt_type: "ppo"
    preprocess: "default"
adapters:
  - name: "lora_ppo_reward"
    type: "lora"
    path: "adapters/lora_ppo_reward"
    optimizer: "adamw"
    lr: 1e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "lora_ppo_critic"
    type: "lora"
    path: "adapters/lora_ppo_critic"
    optimizer: "adamw"
    lr: 5e-5
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "lora_ppo_actor"
    type: "lora"
    path: "adapters/lora_ppo_actor"
    optimizer: "adamw"
    lr: 5e-5
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "ppo"
    name: "task_0"
    adapter:
      reward_adapter: "lora_ppo_reward"
      actor_adapter: "lora_ppo_actor"
      critic_adapter: "lora_ppo_critic"
    reference: "base"
    dataset: "ppo_data"
    batch_size: 16
    mini_batch_size: 16
    num_epochs: 20
    K_epochs: 5
    optim_num: 2
    cutoff_len: 256
    save_step: 100
    gamma: 0.99
    lamdb: 0.99
    kl_coefficient: 0.99
    generate_num: 32
    critic_loss_type: "mse"
    actor_loss_type: "adv_loss"
    reward_loss_type: "reward_loss"
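
For orientation, gamma and lamdb are the discount and GAE-lambda hyperparameters that PPO implementations typically feed into generalized advantage estimation, and kl_coefficient usually scales a KL penalty against the reference policy. The sketch below shows textbook GAE using those two values; it is illustrative only and not a claim about mLoRA's actual PPO code (the function and variable names here are invented).

# Illustrative only: standard GAE advantage estimation, the role gamma and
# lamdb (lambda) usually play in a PPO task; not mLoRA's implementation.
from typing import List

def gae_advantages(rewards: List[float], values: List[float],
                   gamma: float = 0.99, lam: float = 0.99) -> List[float]:
    # `values` carries one extra bootstrap entry: len(values) == len(rewards) + 1
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages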
8 changes: 8 additions & 0 deletions demo/ppo_prompt.yaml
@@ -0,0 +1,8 @@
template: |
  {% if Optional=="instruction" %}
  ### Instruction: {{ data_point['instruction'] + '\n'}}
  {% elif Optional=="chosen" %}
  ### Instruction: {{ data_point['instruction'] + '\n'}} ### Output: {{ data_point['chosen'] + '\n'}}
  {% else %}
  ### Instruction: {{ data_point['instruction'] + '\n'}} ### Output: {{ data_point['reject'] + '\n'}}
  {% endif %}
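
The template branches on an Optional flag to emit either the bare instruction, the chosen completion, or the rejected one. A minimal sketch of rendering it standalone with Jinja2 follows; the sample record is invented for illustration, since demo/data.json is not part of this diff.

# Sketch: render demo/ppo_prompt.yaml outside mLoRA. The data_point record below
# is an assumed example; the real records live in demo/data.json.
import yaml
from jinja2 import Template

with open("demo/ppo_prompt.yaml") as f:
    template = Template(yaml.safe_load(f)["template"])

data_point = {
    "instruction": "Summarize the paragraph.",
    "chosen": "A concise summary.",
    "reject": "An off-topic reply.",
}
for opt in ("instruction", "chosen", "reject"):
    print(template.render(Optional=opt, data_point=data_point))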
2 changes: 2 additions & 0 deletions mlora/config/__init__.py
@@ -15,6 +15,7 @@
    CITTaskConfig,
    CPOTaskConfig,
    DPOTaskConfig,
    PPOTaskConfig,
    TaskConfig,
    TrainTaskConfig,
)
@@ -37,4 +38,5 @@
"ADAPTERCONFIG_CLASS",
"OptimizerConfig",
"LRSchedulerConfig",
"PPOTaskConfig",
]
60 changes: 59 additions & 1 deletion mlora/config/task.py
@@ -27,7 +27,13 @@ def __init__(
        super().__init__(config)
        self.init(self.__params_map, config)

        self.adapter_ = adapters[config["adapter"]]
        if isinstance(config["adapter"], dict):
            self.reward_adapter_ = adapters[config["adapter"]["reward_adapter"]]
            self.actor_adapter_ = adapters[config["adapter"]["actor_adapter"]]
            self.critic_adapter_ = adapters[config["adapter"]["critic_adapter"]]
        else:
            self.adapter_ = adapters[config["adapter"]]

        self.dataset_: DatasetConfig | None = datasets[config["dataset"]]
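
This hunk lets config["adapter"] take two shapes: the existing single adapter name, or a role-to-adapter mapping for PPO. A hypothetical illustration of the two shapes, with names taken from the demo YAML above:

# Hypothetical illustration of the two "adapter" shapes this hunk handles;
# the dict keys mirror the tasks section of demo/ppo/ppo_case1.yaml.
single_adapter_task = {"adapter": "lora_ppo_actor"}
ppo_task = {
    "adapter": {
        "reward_adapter": "lora_ppo_reward",
        "actor_adapter": "lora_ppo_actor",
        "critic_adapter": "lora_ppo_critic",
    }
}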


@@ -148,9 +154,61 @@ def __init__(
        self.temperature_ = float(self.temperature_)


class PPOTaskConfig(TrainTaskConfig):
    gamma_: float
    lamdb_: float
    K_epochs_: int
    T_horizon_: int
    critic_loss_type_: str
    actor_loss_type_: str
    reward_loss_type_: str
    clip_rate_: float
    generate_num_: int
    reward_adapter_: AdapterConfig
    critic_adapter_: AdapterConfig
    actor_adapter_: AdapterConfig
    kl_coefficient_: float
    optim_num_: int

    __params_map: Dict[str, str] = {
        "gamma_": "gamma",
        "lamdb_": "lamdb",
        "K_epochs_": "K_epochs",
        "optim_num_": "optim_num",
        "critic_loss_type_": "critic_loss_type",
        "actor_loss_type_": "actor_loss_type",
        "reward_loss_type_": "reward_loss_type",
        "generate_num_": "generate_num",
        "kl_coefficient_": "kl_coefficient",
    }

    def __init__(
        self,
        config: Dict[str, str],
        adapters: Mapping[str, AdapterConfig],
        datasets: Mapping[str, DatasetConfig],
    ):
        super().__init__(config, adapters, datasets)
        self.init(self.__params_map, config)

        self.gamma_ = float(self.gamma_)
        self.lamdb_ = float(self.lamdb_)
        self.K_epochs_ = int(self.K_epochs_)
        self.optim_num_ = int(self.optim_num_)
        self.generate_num_ = int(self.generate_num_)
        self.kl_coefficient_ = float(self.kl_coefficient_)

        if config["reference"] not in adapters:
            self.reference_ = None
            logging.info("PPOTask - use the base model as reference model.")
        else:
            self.reference_ = adapters[config["reference"]]


TASKCONFIG_CLASS: Dict[str, Type[TaskConfig]] = {
    "train": TrainTaskConfig,
    "dpo": DPOTaskConfig,
    "cpo": CPOTaskConfig,
    "cit": CITTaskConfig,
    "ppo": PPOTaskConfig,
}
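
Because the demo config sets reference: "base", which is not one of the adapter names, reference_ ends up None and the base model serves as the frozen reference policy, matching the log message above. A small sketch of how the TASKCONFIG_CLASS table is presumably consulted when the tasks section of a YAML file is parsed; the build_task_config wrapper is assumed for illustration, only the mapping and the constructors appear in this diff.

# Assumed loader sketch: dispatch a task entry to its config class by "type".
# Only TASKCONFIG_CLASS and the config constructors come from this PR.
def build_task_config(task_conf, adapters, datasets):
    config_cls = TASKCONFIG_CLASS[task_conf["type"]]  # "ppo" -> PPOTaskConfig
    return config_cls(task_conf, adapters, datasets)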
3 changes: 3 additions & 0 deletions mlora/executor/task/__init__.py
@@ -4,6 +4,7 @@
from .cit_task import CITTask
from .cpo_task import CPOTask
from .dpo_task import DPOTask
from .ppo_task import PPOTask
from .task import Task
from .train_task import TrainTask

@@ -12,6 +13,7 @@
"dpo": DPOTask,
"cpo": CPOTask,
"cit": CITTask,
"ppo": PPOTask,
}


@@ -32,5 +34,6 @@ def register_task_class(type_name: str, task: Type[Task]):
"DPOTask",
"CPOTask",
"CITTask",
"PPOTask",
"register_task_class",
]
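
The register_task_class hook visible in this hunk lets additional task types be added to the task-type mapping shown above at runtime. A minimal, hypothetical usage sketch follows; MyPPOVariantTask is invented here and is not part of the PR.

# Hypothetical extension example; only register_task_class and PPOTask are real.
from mlora.executor.task import PPOTask, register_task_class

class MyPPOVariantTask(PPOTask):
    pass

register_task_class("my_ppo", MyPPOVariantTask)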