Add trust remote code to tokenizer (#50)
* add

* update docker

* remove remote code on DPO script

* up
natolambert authored Mar 5, 2024
1 parent f062e54 commit b72307b
Showing 8 changed files with 44 additions and 15 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ When updating the `Dockerfile`, make sure to see the instructions at the top to

In development, we have the following docker images (most recent first as it's likely what you need).
TODO: we should log the git commit affiliated with each of these, or delete them when outdated.
- - `nathanl/herm_v5`: chat template loading from tokenizer fixes + DPO additions.
+ - `nathanl/herm_v6`: chat template loading from tokenizer fixes + DPO additions.

Deprecated:
- `nathanl/herm_dpo`: for adding functionality with DPO sweeps, fix minor bugs (last updated 24 Feb.)
herm/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -15,7 +15,7 @@
__version__ = "0.1.0.dev"
from .chattemplates import * # noqa
from .dpo import DPOInference
- from .models import REWARD_MODEL_CONFIG
+ from .models import DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG
from .utils import (
load_bon_dataset,
load_eval_dataset,
@@ -26,6 +26,7 @@

__all__ = [
DPOInference,
+ DPO_MODEL_CONFIG,
load_bon_dataset,
load_eval_dataset,
prepare_dialogue,
herm/models/__init__.py (15 changes: 15 additions & 0 deletions)
@@ -13,7 +13,11 @@
# limitations under the License.

from transformers import (
+ AutoModelForCausalLM,
AutoModelForSequenceClassification,
+ AutoTokenizer,
+ LlamaTokenizer,
+ MixtralForCausalLM,
T5ForConditionalGeneration,
pipeline,
)
@@ -103,3 +107,14 @@
"model_type": "Seq. Classifier",
},
}

+ DPO_MODEL_CONFIG = {
+     "default": {
+         "model_builder": AutoModelForCausalLM.from_pretrained,
+         "tokenizer_builder": AutoTokenizer.from_pretrained,
+     },
+     "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": {
+         "model_builder": MixtralForCausalLM.from_pretrained,
+         "tokenizer_builder": LlamaTokenizer.from_pretrained,
+     },
+ }
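For reference, a minimal sketch of how these builder entries are consumed; it mirrors the lookup this commit adds to `scripts/run_dpo.py` further down, and the model name and `trust_remote_code=True` value are illustrative, taken from the eval config below:

```python
from herm import DPO_MODEL_CONFIG

# Pick per-model builders, falling back to the "default" entry.
model_name = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
config = DPO_MODEL_CONFIG.get(model_name, DPO_MODEL_CONFIG["default"])

model_builder = config["model_builder"]          # MixtralForCausalLM.from_pretrained
tokenizer_builder = config["tokenizer_builder"]  # LlamaTokenizer.from_pretrained

# trust_remote_code is threaded through from the --trust_remote_code CLI flag.
tokenizer = tokenizer_builder(model_name, trust_remote_code=True)
```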
scripts/configs/eval_dpo_configs.yaml (6 changes: 3 additions & 3 deletions)
@@ -63,16 +63,16 @@ mistralai/Mixtral-8x7B-Instruct-v0.1:
ref_model: mistralai/Mixtral-8x7B-v0.1
tokenizer: mistralai/Mixtral-8x7B-Instruct-v0.1
chat_template:
- batch_size: 2
+ batch_size: 1
num_gpus: 4
trust_remote_code: False
NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO:
ref_model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT
tokenizer: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
chat_template:
- batch_size: 2
+ batch_size: 1
num_gpus: 4
- trust_remote_code: False
+ trust_remote_code: True
NousResearch/Nous-Hermes-2-Mistral-7B-DPO:
ref_model: teknium/OpenHermes-2.5-Mistral-7B
tokenizer: NousResearch/Nous-Hermes-2-Mistral-7B-DPO
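As a sketch of how these per-model settings are consumed (assuming PyYAML; the exact CLI flag names are inferred from the config keys and from the `--trust_remote_code` flag that `scripts/submit_eval_jobs.py` appends below):

```python
import yaml

# Load the per-model DPO eval settings and build the script arguments,
# forwarding --trust_remote_code only when the config enables it.
with open("scripts/configs/eval_dpo_configs.yaml") as f:
    configs = yaml.safe_load(f)

model = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
cfg = configs[model]

args = f"--model {model} --ref_model {cfg['ref_model']} --batch_size {cfg['batch_size']}"
if cfg["trust_remote_code"]:
    args += " --trust_remote_code"
print(args)
```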
scripts/run_bon.py (4 changes: 3 additions & 1 deletion)
@@ -91,6 +91,8 @@ def main():
else:
config = REWARD_MODEL_CONFIG["default"]
logger.info(f"Using reward model config: {config}")
+ if args.trust_remote_code:
+     logger.info("Loading model with Trust Remote Code")

# Default entries
# "model_builder": AutoModelForSequenceClassification.from_pretrained,
@@ -113,7 +115,7 @@
############################
logger.info("*** Load dataset ***")
tokenizer_path = args.tokenizer if args.tokenizer else args.model
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code)
dataset = load_bon_dataset(
best_of=args.best_of,
conv=conv,
scripts/run_dpo.py (20 changes: 15 additions & 5 deletions)
@@ -25,10 +25,9 @@
from fastchat.conversation import get_conv_template
from huggingface_hub import HfApi
from tqdm import tqdm
- from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.trainer.utils import DPODataCollatorWithPadding

- from herm import DPOInference, load_eval_dataset, save_to_hub
+ from herm import DPO_MODEL_CONFIG, DPOInference, load_eval_dataset, save_to_hub

# get token from HF_TOKEN env variable, but if it doesn't exist pass none
HF_TOKEN = os.getenv("HF_TOKEN", None)
@@ -84,6 +83,17 @@ def main():
transformers.utils.logging.enable_explicit_format()

logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")
+ if args.trust_remote_code:
+     logger.info("Loading model with Trust Remote Code")
+
+ if args.model in DPO_MODEL_CONFIG:
+     config = DPO_MODEL_CONFIG[args.model]
+ else:
+     config = DPO_MODEL_CONFIG["default"]
+ logger.info(f"Using dpo model config: {config}")
+
+ model_builder = config["model_builder"]
+ tokenizer_builder = config["tokenizer_builder"]

assert args.model != args.ref_model, "policy and reference model should be different"
# load chat template
@@ -103,7 +113,7 @@
############################
logger.info("*** Load dataset ***")
tokenizer_path = args.tokenizer if args.tokenizer else args.model
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ tokenizer = tokenizer_builder(tokenizer_path, trust_remote_code=args.trust_remote_code)
tokenizer.pad_token = tokenizer.eos_token
# if no BOS token, set as pad token, e.g. QWEN models
if tokenizer.bos_token is None:
@@ -134,7 +144,7 @@ def main():
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
- model = AutoModelForCausalLM.from_pretrained(
+ model = model_builder(
args.model,
trust_remote_code=args.trust_remote_code,
**model_kwargs,
@@ -148,7 +158,7 @@ def main():
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
- ref_model = AutoModelForCausalLM.from_pretrained(
+ ref_model = model_builder(
args.ref_model,
trust_remote_code=args.trust_remote_code,
**model_kwargs_ref,
scripts/run_rm.py (4 changes: 3 additions & 1 deletion)
@@ -81,6 +81,8 @@ def main():
transformers.utils.logging.enable_explicit_format()

logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")
+ if args.trust_remote_code:
+     logger.info("Loading model with Trust Remote Code")

# load chat template
chat_template = args.chat_template
@@ -113,7 +115,7 @@
############################
logger.info("*** Load dataset ***")
tokenizer_path = args.tokenizer if args.tokenizer else args.model
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code)
if not custom_dialogue: # not needed for PairRM / SteamSHP
tokenizer.truncation_side = "left" # copied from Starling, but few samples are above context length
dataset, subsets = load_eval_dataset(
scripts/submit_eval_jobs.py (5 changes: 2 additions & 3 deletions)
@@ -29,7 +29,7 @@
)
argparser.add_argument("--eval_dpo", action="store_true", default=False, help="Evaluate DPO model suite")
argparser.add_argument("--eval_on_bon", action="store_true", default=False, help="Evaluate on BON preference sets")
argparser.add_argument("--image", type=str, default="nathanl/herm_v5", help="Beaker image to use")
argparser.add_argument("--image", type=str, default="nathanl/herm_v6", help="Beaker image to use")
argparser.add_argument("--cluster", type=str, default="ai2/allennlp-cirrascale", help="Beaker cluster to use")
argparser.add_argument("--upload_to_hub", action="store_false", default=True, help="Upload to results to HF hub")
argparser.add_argument("--model", type=str, default=None, help="Specific model to evaluate if not sweep")
@@ -116,8 +116,7 @@
if model_config["chat_template"] is not None:
d["tasks"][0]["arguments"][0] += f" --chat_template {model_config['chat_template']}"
if model_config["trust_remote_code"]:
- if not eval_dpo:  # TODO create trust remote code option in DPO script
-     d["tasks"][0]["arguments"][0] += " --trust_remote_code"
+ d["tasks"][0]["arguments"][0] += " --trust_remote_code"
if not upload_to_hub:
d["tasks"][0]["arguments"][0] += " --do_not_save"
if eval_on_pref_sets: