Dpo training (Do not merge) #63

Open

wants to merge 38 commits into base: main

Changes from all commits (38 commits)
2852c0a
add dpo training for arrow routing
shuishen112 Jul 12, 2024
fb47f71
add dpo train with preference data
shuishen112 Jul 12, 2024
665bfc4
add preference data module
shuishen112 Jul 13, 2024
330813c
add log for training and validation
shuishen112 Jul 15, 2024
ec73c5a
wip
shuishen112 Jul 15, 2024
93b5526
fix isort
shuishen112 Jul 15, 2024
1f97122
fix optimization with efficient checkpoint module
shuishen112 Jul 16, 2024
b4fa78b
wip
shuishen112 Jul 16, 2024
6b20a61
fix dpo training
shuishen112 Jul 17, 2024
8c61ba3
add simpo training
shuishen112 Jul 17, 2024
0f0be96
fix
Jul 17, 2024
3de4297
isort
shuishen112 Jul 17, 2024
d3b77a1
wip
Jul 17, 2024
6601b67
get cluster dataset by embedding
Jul 24, 2024
d69f03b
add args for dpo training
shuishen112 Jul 30, 2024
4a0d3d6
add log
shuishen112 Jul 31, 2024
b73de95
wip
shuishen112 Aug 7, 2024
7b9f417
add length normalization
shuishen112 Aug 7, 2024
0566d96
fix
shuishen112 Aug 7, 2024
1dba64d
make the reference model eval mode
Aug 21, 2024
38049bb
fix train_batch_size
shuishen112 Aug 21, 2024
75a80d9
wip
shuishen112 Aug 21, 2024
3904f11
add ultrafeedback data
shuishen112 Aug 28, 2024
d169661
add urtra data train
shuishen112 Aug 28, 2024
fe40732
add ultralfeedback sft
shuishen112 Sep 5, 2024
49ee734
training module with ultral sft
shuishen112 Sep 10, 2024
8f15188
fix save
shuishen112 Nov 16, 2024
bf45a61
fix arrow
shuishen112 Nov 16, 2024
2f08190
clean
shuishen112 Nov 17, 2024
8858e10
merge
shuishen112 Nov 17, 2024
b33316b
merge the code to the current version
shuishen112 Nov 18, 2024
96bb1c4
clean the code
shuishen112 Nov 18, 2024
54da222
clean the code
shuishen112 Nov 18, 2024
147878e
add orca datamodule
shuishen112 Nov 18, 2024
23b3b36
add orca sft training
shuishen112 Nov 18, 2024
19fb72c
update orca
shuishen112 Nov 19, 2024
9d9ab5b
chose finetune task
shuishen112 Nov 19, 2024
718b138
format
shuishen112 Nov 19, 2024
3 changes: 3 additions & 0 deletions mttl/arguments.py
@@ -410,6 +410,9 @@ class TrainingArgs(DataArgs):

    profile: bool = False  # if 'True' will profile the model training

    # preference/RL training objective, e.g. "dpo"
    rl_training: str = "dpo"

    @property
    def dataset_config(self):
        if self.dataset_type is not None:
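For context, the new `rl_training` flag selects the preference-optimization objective; the commit history adds both DPO and SimPO training as well as length normalization. Below is a minimal sketch of how such a flag might gate the loss, assuming summed per-sequence log-probabilities from the policy and a frozen reference model; the `preference_loss` helper and its signature are illustrative, not the PR's actual implementation.

import torch.nn.functional as F

def preference_loss(
    policy_chosen_logps,    # log p(chosen) under the policy, one value per example
    policy_rejected_logps,  # log p(rejected) under the policy
    ref_chosen_logps,       # same quantities under the frozen reference model
    ref_rejected_logps,
    beta=0.5,               # matches the new config default below
    rl_training="dpo",
):
    if rl_training == "dpo":
        # DPO: the reward margin is the difference of policy-vs-reference log-ratios.
        logits = (policy_chosen_logps - ref_chosen_logps) - (
            policy_rejected_logps - ref_rejected_logps
        )
    elif rl_training == "simpo":
        # SimPO drops the reference model; with length normalization the inputs
        # would be length-averaged rather than summed log-probabilities.
        logits = policy_chosen_logps - policy_rejected_logps
    else:
        raise ValueError(f"unknown rl_training: {rl_training}")
    return -F.logsigmoid(beta * logits).mean()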
352 changes: 352 additions & 0 deletions mttl/config.py
@@ -0,0 +1,352 @@
import argparse
import ast
import json
import os
from string import Template
from typing import Dict

from mttl.utils import logger, setup_logging


class Config:
    def __init__(self, filenames=None, kwargs=None, raise_error=True, silent=False):
        # Stores personalization of the config file in a dict (json serializable)
        self._updated_kwargs = {}
        self.filenames = filenames
        self._set_defaults()

        overwrite_logs = []
        if filenames:
            for filename in filenames.split("+"):
                if not os.path.exists(filename):
                    filename = os.path.join(
                        os.getenv("CONFIG_PATH", default="configs"), filename
                    )

                if not os.path.exists(filename) and ".json" not in filename:
                    filename = filename + ".json"

                with open(filename) as f:
                    file_kwargs = json.load(f)
                overwrite_logs += self.update_kwargs(
                    file_kwargs,
                    eval=False,
                    raise_error=raise_error,
                    silent=silent,
                )

        if kwargs:
            overwrite_logs += self.update_kwargs(
                kwargs, raise_error=raise_error, silent=silent
            )

        # setup logging to the output dir
        try:
            setup_logging(self.output_dir)
        except Exception as e:
            if raise_error:
                raise ValueError("Error setting up logging") from e
            elif not silent:
                logger.warning(f"Error setting up logging to {self.output_dir}")

        # log the overwrites
        for log in overwrite_logs:
            logger.warning(log)

        self.post_init(silent=silent)

    def post_init(self, silent=False):
        pass

    @classmethod
    def fromdict(cls, data):
        _ = data.pop("_updated_kwargs", None)
        return cls(kwargs=data, raise_error=False, silent=True)

    def asdict(self) -> Dict:
        from mttl.models.utils import convert_hps_to_dict

        return convert_hps_to_dict(self.__dict__)

    def was_overridden(self, key):
        return key in self._updated_kwargs

    def was_default(self, key):
        return key not in self._updated_kwargs

    def update_kwargs(self, kwargs, eval=True, raise_error=True, silent=False):
        overwrites_log = []
        for k, v in kwargs.items():
            if eval:
                try:
                    v = ast.literal_eval(v)
                except (ValueError, SyntaxError):
                    pass  # keep the raw string if it is not a Python literal

            if not hasattr(self, k) and raise_error:
                raise ValueError(f"{k} is not in the config")

            if eval and not silent:
                overwrites_log.append(f"Overwriting {k} to {v}")

            if isinstance(v, str) and "$" in v:
                # this raises an error if the env. var does not exist
                v = Template(v).substitute(os.environ)

            setattr(self, k, v)
            self._updated_kwargs[k] = v
        return overwrites_log

    def __getitem__(self, item):
        return getattr(self, item, None)

    def to_json(self):
        """
        Converts parameter values in config to json
        :return: json
        """
        import copy

        to_save = copy.deepcopy(self.__dict__)
        to_save.pop("_updated_kwargs")

        return json.dumps(to_save, indent=4, sort_keys=False)

    def save_config(self, output_dir):
        """
        Saves the config
        """
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, "config.json"), "w+") as fout:
            fout.write(self.to_json())
            fout.write("\n")

    @classmethod
    def parse(
        cls,
        extra_kwargs=None,
        raise_error=True,
        parent=None,
        return_parser=False,
        c=None,
    ):
        import itertools

        # don't build an argparse parser when called from a Jupyter notebook;
        # pass the config file name via `c` instead
        if c is None:
            parser = (
                argparse.ArgumentParser(parents=[parent])
                if parent
                else argparse.ArgumentParser()
            )
            parser.add_argument("-c", "--config_files", required=False)
            parser.add_argument("-k", "--kwargs", nargs="*", action="append")
            args = parser.parse_args()
        else:
            args = argparse.Namespace()
            args.kwargs = None
            args.config_files = c

        kwargs = {}
        if args.kwargs:
            kwargs_opts = list(itertools.chain(*args.kwargs))
            for value in kwargs_opts:
                key, _, value = value.partition("=")

                # allows multiple values for a given option when specified in the command line!
                if key in kwargs:
                    if not isinstance(kwargs[key], list):
                        kwargs[key] = [kwargs[key]]
                    kwargs[key].append(value)
                else:
                    kwargs[key] = value

        args.kwargs = kwargs
        if extra_kwargs:
            args.kwargs.update(extra_kwargs)

        config = cls(
            filenames=args.config_files, kwargs=args.kwargs, raise_error=raise_error
        )

        if return_parser:
            return config, args
        return config

    def _set_defaults(self):
        self.cache_dir = os.getenv("CACHE_DIR", "./cache")

        # Data config
        self.dataset = None
        self.custom_tasks_splits = None

        self.data_dir = os.getenv("TRAIN_DIR", "/tmp/")
        self.output_dir = os.getenv("OUTPUT_DIR", "./output")

        self.finetune_task_name = None
        self.example_to_ids_path = None  # path to clustering of data
        self.embeddings_path = None

        # NI related configs
        self.use_task_descriptions = False  # Use task descriptions
        self.max_num_instances_per_task = (
            100  # Max instances per training task (applies to NI)
        )
        self.num_pos_examples = (
            0  # Use some few-shot examples if possible (applies to NI)
        )

        self.task_prefix = None  # xfit has task prefixes detailing # of shots, seed, etc; this is automatically filled in at fine-tuning time
        self.exp_name = None
        self.wandb_project = None
        self.padding_side = "right"
        self.truncation_side = "right"
        self.max_input_length = 512
        self.max_output_length = 64
        self.num_beams = 4
        self.append_another_bos = False
        self.do_lowercase = False
        self.freeze_embeds = False

        # T0 related configs
        self.use_t0_templates_as_tasks = (
            False  # if True, then t0 consists of 313 tasks, otherwise 38
        )
        self.use_t0_few_shot_training_set = False  # if True, then use 100 examples per task during training + 100 examples per validation task

        # Filtering configs, useful for flan flat, etc.
        self.include_template_type = "zs_noopt"
        self.include_task_source = "P3,Flan2021,CoT"
        self.remove_phi_eval_tasks = False

        # Training config
        self.compute_strategy = None
        self.scheduler = "linear_decay_with_warmup"
        self.checkpoint = None  # load from checkpoint
        self.checkpoint_step = None  # load from checkpoint in format of global_stepX.pt
        self.backbone_checkpoint = None  # load the backbone from here
        self.train_batch_size = 8
        self.predict_batch_size = 32
        self.learning_rate = 1e-3
        self.warmup_proportion = 0.06
        self.trainable_param_names = ".*"
        self.non_trainable_param_names = None
        self.weight_decay = 0.01
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 0.1
        self.gradient_accumulation_steps = 1
        self.optimizer = "adamw"
        self.adafactor_scale_parameter = True
        self.adafactor_warmup_init = False
        self.adafactor_relative_step = False
        self.num_train_epochs = -1
        self.warmup_steps = -1
        self.total_steps = -1
        self.num_tasks_per_batch = None
        self.save_every = None
        self.eval_every = None
        self.eval_every_n_epoch = 1
        self.debug = False
        self.seed = 42
        self.eval_before_training = True

        self.subsample_train = None
        self.subsample_dev = None
        self.subsample_test = None
        self.subsample_per_task = False

        self.ni_online_eval = False  # zero-shot online eval for ni
        self.t0_online_eval = False  # zero-shot eval for t0
        self.early_stop_on_zero_shot = False  # zero-shot early stopping

        # auxiliary losses
        self.ortho_loss = 0.0  # orthogonality between skills
        self.task_loss = 0.0  # task prediction loss (mi between tasks and skills)
        self.l1_loss = 0.0  # sparsity of the logits
        self.mi_loss = (
            0.0  # mi between tasks and skills (difference of entropies method)
        )
        self.mc_loss = 0.0  # T-Few
        self.length_norm = 0.0  # T-Few
        self.unlikely_loss = 0.0  # T-Few
        self.poly_unlikely_loss = 0.0  # poly unlikelihood loss
        self.finetune_type = None  # ["F", "A", "Z", "MuZ", "Poly", "PolyRand"]
        self.finetune_skip_es = False  # skip early stopping while fine-tuning
        self.finetune_use_last_checkpoint = (
            False  # use always the best valid_perf checkpoint if available
        )

        self.model = None
        self.model_family = None  # model family, either "gpt" or "encdec"

        self.precision = "32"
        self.monitor_grad_alignment_on = None

        self.model_modifier = None
        self.adapter_type = None

        self.lora_rank = 16
        self.lora_dropout = 0.0
        self.lora_init_scale = 0.01
        self.lora_alpha = 1.0
        self.lora_warmup = False
        self.lora_init_b_random = False

        # n-skills for router-based methods
        self.n_skills = 8
        self.n_tasks = None
        self.task_names = None

        # which modules to modify and which layers for adapters
        self.modify_modules = None
        self.modify_layers = None
        self.tie_params = None  # "q_proj\\.lora_a|k_proj\\.lora_a|v_proj\\.lora_a" to share lora_a for q,k,v

        """
        router_granularity : how granular is the module selection
            coarsegrained : 1 single selector across all linear layers
            coderwise : 2 selectors (1 for encoder, 1 for decoder)
            blockwise : 1 selector for each block of K attention layers (and layernorm)
            layerwise : 1 selector for each attention layer (and layernorm)
            finegrained : 1 selector for every linear layer
        """
        self.router_granularity = "finegrained"  # router granularity
        self.router_selector = None  # router selector
        self.router_weight_decay = None  # router weight decay
        self.router_learning_rate = None

        # Polytropon related hyper-parameters
        self.n_splits = 1  # number of splits for poly-s
        self.router_selector_cluster_temp = 1.0  # temperature for the cluster selector
        self.poly_average_correction = False  # correct the poly average
        self.poly_use_shared_skill = False  # use one skill shared by all tasks
        self.skip_unseen_tokens = (
            True  # skip unseen tokens in PerTokenPoly during evaluation
        )

        self.module_logits_relaxed_bernoulli = True
        self.module_logits_straight_through = False
        self.module_logits_learning_rate = 0.1
        self.adapters_learning_rate = None
        self.adapters_weight_decay = None
        self.module_logits_dropout = 0.0
        self.module_logits_l2_norm = False

        self.augment_mmlu: bool = False

        # soft prompts
        self.soft_prompt_length: int = 10
        self.patch_last_k_layers: int = -1
        self.soft_prompt_mlp_dim: int = None
        self.soft_prompt_hidden_dim: int = None
        self.soft_prompt_learn_kv: bool = False
        self.prompt_placement: str = "prefix"
        self.add_routing_token: bool = False

        # rl training
        self.rl_training = "dpo"
        self.beta = 0.5
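A hypothetical usage sketch of the Config class added above; the file names, keys, and values are illustrative only.

import os

# "+"-joined file names are loaded in order; a path that does not exist is
# retried under $CONFIG_PATH and, failing that, with a ".json" suffix.
os.environ["OUTPUT_DIR"] = "./runs/dpo"
config = Config(
    filenames="base.json+dpo.json",
    kwargs={"learning_rate": "5e-5", "beta": "0.1"},  # strings go through ast.literal_eval
)
print(config.rl_training, config.beta)  # -> dpo 0.1

# From a training script the same machinery hangs off argparse:
#   python train.py -c base.json+dpo.json -k learning_rate=5e-5 -k rl_training=simpo
config = Config.parse()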
16 changes: 16 additions & 0 deletions mttl/datamodule/base.py
@@ -968,6 +968,12 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
        WinograndeMultiChoiceDataModule,
    )

    from mttl.datamodule.orca_data_module import OrcaDataModule
    from mttl.datamodule.ultrafeedback_data_module import UltrafeedbackSFTmodule

    # if we have a DataArgs object, we can directly create the datamodule
    if isinstance(args, DataArgs) and args.dataset_type is not None:
        dataset_config = args.dataset_config
@@ -1063,6 +1069,16 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
            pack_sequences=args.pack_sequences,
        )
        dm = FlatMultiTaskModule(config, for_generation=for_generation)
    elif "ultrachat" in dataset:
        config = DatasetConfig(
            **common_kwargs,
        )
        dm = UltrafeedbackSFTmodule(config, for_generation=for_generation)
    elif "orca" in dataset:
        config = DatasetConfig(
            **common_kwargs,
        )
        dm = OrcaDataModule(config, for_generation=for_generation)
    elif "mmlu" in dataset:
        config = MMLUDataConfig(
            **common_kwargs,
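`get_datamodule` dispatches on substrings of the dataset name, so the new branches fire for any dataset whose name contains "ultrachat" or "orca". A hypothetical example follows; the dataset names are placeholders, not datasets the PR necessarily registers.

args.dataset = "HuggingFaceH4/ultrachat_200k"           # contains "ultrachat"
dm = get_datamodule(args)                               # -> UltrafeedbackSFTmodule

args.dataset = "microsoft/orca-math-word-problems-200k" # contains "orca"
dm = get_datamodule(args)                               # -> OrcaDataModule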