From a7cf68b5923d42cc16f3a1fa2d9b2b863d140866 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Sun, 5 May 2024 11:32:53 -0700
Subject: [PATCH] Make RewardBench pip installable + runable! (#121)

---
 README.md                  |  32 +++-
 rewardbench/__init__.py    |   4 +-
 rewardbench/__main__.py    | 329 +++++++++++++++++++++++++++++++++
 rewardbench/rewardbench.py | 367 +++++++++++++++++++++++++++++++++++++
 rewardbench/utils.py       | 122 +++++++++++-
 setup.py                   |  19 +-
 tests/test_package.py      |  45 +++++
 7 files changed, 907 insertions(+), 11 deletions(-)
 create mode 100644 rewardbench/__main__.py
 create mode 100644 rewardbench/rewardbench.py
 create mode 100644 tests/test_package.py

diff --git a/README.md b/README.md
index f3cf3d9d..d67a345f 100644
--- a/README.md
+++ b/README.md
@@ -24,8 +24,36 @@ The two primary scripts to generate results (more in `scripts/`):
 2. `scripts/run_dpo.py`: Run evaluations for direct preference optimization (DPO) models (and other models using implicit rewards, such as KTO).
 3. `scripts/train_rm.py`: A basic RM training script built on [TRL](https://github.com/huggingface/trl).
 
-## Installation
-Please install `torch` on your system, and then install the following requirements.
+## Quick Usage
+RewardBench lets you quickly evaluate any reward model on any preference set.
+To install for quick usage, install with pip as:
+```
+pip install rewardbench
+```
+Then, run the following:
+```
+rewardbench --model={yourmodel} --dataset={yourdataset} --batch_size=8
+```
+For a DPO model, pass `--ref_model={yourrefmodel}` and the script will automatically switch to DPO-style evaluation.
+The script uses the tokenizer's chat template by default, but it can also use FastChat conversation templates.
+
+To run the core RewardBench evaluation set, run:
+```
+rewardbench --model={yourmodel}
+```
+
+Examples:
+1. Normal operation
+```
+rewardbench --model=OpenAssistant/reward-model-deberta-v3-large-v2 --dataset=allenai/ultrafeedback_binarized_cleaned --split=test_gen --chat_template=raw
+```
+2. DPO model from local dataset (note `--load_json`)
+```
+rewardbench --model=Qwen/Qwen1.5-0.5B-Chat --ref_model=Qwen/Qwen1.5-0.5B --dataset=/net/nfs.cirrascale/allennlp/jacobm/herm/data/berkeley-nectar-binarized-preferences-random-rejected.jsonl --load_json
+```
+
+## Full Installation
+To install from source, please install `torch` on your system, and then install the following requirements.
 ```
 pip install -e .
 ```
diff --git a/rewardbench/__init__.py b/rewardbench/__init__.py
index 448c3534..040ece23 100644
--- a/rewardbench/__init__.py
+++ b/rewardbench/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0.dev"
+__version__ = "0.1.1"
 from .chattemplates import *  # noqa
 from .dpo import DPOInference
 from .models import DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG
@@ -20,6 +20,7 @@
     check_tokenizer_chat_template,
     load_bon_dataset,
     load_eval_dataset,
+    load_preference_dataset,
     prepare_dialogue,
     prepare_dialogue_from_tokenizer,
     save_to_hub,
@@ -31,6 +32,7 @@
     DPO_MODEL_CONFIG,
     load_bon_dataset,
     load_eval_dataset,
+    load_preference_dataset,
     prepare_dialogue,
     prepare_dialogue_from_tokenizer,
     REWARD_MODEL_CONFIG,
diff --git a/rewardbench/__main__.py b/rewardbench/__main__.py
new file mode 100644
index 00000000..0fadd9c8
--- /dev/null
+++ b/rewardbench/__main__.py
@@ -0,0 +1,329 @@
+# Copyright 2023 AllenAI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run RewardBench (evaluate any reward model on any dataset)
+import argparse
+import json
+import logging
+import os
+import sys
+
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from rewardbench import (
+    DPO_MODEL_CONFIG,
+    REWARD_MODEL_CONFIG,
+    check_tokenizer_chat_template,
+    load_preference_dataset,
+)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate a reward model.")
+
+    # core args
+    parser.add_argument("--dataset", type=str, required=True, help="The dataset to evaluate on.")
+    parser.add_argument("--split", type=str, default=None, help="The split to evaluate on.")
+    parser.add_argument("--model", type=str, required=True, help="The model to evaluate.")
+    parser.add_argument("--ref_model", type=str, default=None, help="The reference model to compare against.")
+    parser.add_argument("--tokenizer", type=str, default=None, help="The tokenizer to use (defaults to model).")
+    parser.add_argument(
+        "--chat_template",
+        type=str,
+        default=None,
+        help="The chat template to use (defaults to the chat template from the tokenizer, or from chattemplates).",
+    )
+
+    # inference args
+    parser.add_argument("--batch_size", type=int, default=8, help="The batch size to use.")
+    parser.add_argument("--max_length", type=int, default=512, help="The max length to use.")
+
+    # system args
+    parser.add_argument("--load_json", action="store_true", default=False, help="Load dataset as json.")
+    parser.add_argument("--trust_remote_code", action="store_true", default=False, help="Trust remote code.")
+    parser.add_argument("--debug", action="store_true", default=False, help="Debug mode.")
+    parser.add_argument("--output_dir", type=str, default="results/", help="The output directory to save results.")
+    parser.add_argument("--save_all", action="store_true", default=False, help="Save all results.")
+    args = parser.parse_args()
+
+    ###############
+    # Setup logging
+    ###############
+    accelerator = Accelerator()
+    current_device = accelerator.process_index
+
+    logger = get_logger(__name__)
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    log_level = logging.INFO
+    logger.setLevel(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")
+    if args.trust_remote_code:
+        logger.info("Loading model with Trust Remote Code")
+
+    # basic checks from config
+    if args.ref_model:
+        is_dpo = True
+        MODEL_CONFIGS = DPO_MODEL_CONFIG
+        assert args.model != args.ref_model, "policy and reference model should be different"
+        from trl.trainer.utils import DPODataCollatorWithPadding
+
+        from rewardbench import DPOInference
+    else:
+        is_dpo = False
+        MODEL_CONFIGS = REWARD_MODEL_CONFIG
+
+    if args.chat_template:
+        from fastchat.conversation import 
get_conv_template + + conv = get_conv_template(args.chat_template) + else: + conv = None + + if args.model in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.model] + else: + config = MODEL_CONFIGS["default"] + logger.info(f"Using reward model config: {config}") + + # Default entries + # "model_builder": AutoModelForSequenceClassification.from_pretrained, + # "pipeline_builder": pipeline, + # "quantized": True, + # "custom_dialogue": False, + # "model_type": "Seq. Classifier" + + if not is_dpo: + quantized = config["quantized"] # only Starling isn't quantized for now + custom_dialogue = config["custom_dialogue"] + pipeline_builder = config["pipeline_builder"] + _ = config["model_type"] + if custom_dialogue: + raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.") + + model_builder = config["model_builder"] + + ######################### + # load dataset + ######################### + logger.info("*** Load dataset ***") + tokenizer_path = args.tokenizer if args.tokenizer else args.model + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code) + dataset = load_preference_dataset( + args.dataset, split=args.split, json=args.load_json, tokenizer=tokenizer, conv=conv + ) + + if args.debug: + dataset = dataset.select(range(10)) + + logger.info("*** Load reward model ***") + + ############################ + # Load DPO model pipeline + ############################ + if is_dpo: + tokenizer.pad_token = tokenizer.eos_token + # if no BOS token, set as pad token, e.g. QWEN models + if tokenizer.bos_token is None: + tokenizer.bos_token_id = tokenizer.eos_token_id + tokenizer.pad_token_id = tokenizer.eos_token_id + + model_kwargs = { + "load_in_8bit": True, + "device_map": "auto", + "torch_dtype": torch.float16 if torch.cuda.is_available() else None, + } + model = model_builder( + args.model, + trust_remote_code=args.trust_remote_code, + **model_kwargs, + ) + ref_model = model_builder( + args.ref_model, + trust_remote_code=args.trust_remote_code, + **model_kwargs, + ) + + # use internal inference functions in DPO trainer + dpo = DPOInference( + model, + ref_model, + tokenizer=tokenizer, + accelerator=accelerator, + # norm is norm, avg is average, sum is sum + ) + + # tokenize dataset + column_names = list(dataset.features) + + tokenized_dataset = dataset.map(dpo.tokenize_row, remove_columns=column_names) + dataloader = torch.utils.data.DataLoader( + tokenized_dataset, + batch_size=args.batch_size, + collate_fn=DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=dpo.label_pad_token_id, + is_encoder_decoder=dpo.is_encoder_decoder, + ), + # collate_fn = lambda x: x, # fix weird batching error + shuffle=False, + drop_last=False, + ) + + ############################ + # Load classifier model pipeline + ############################ + else: + reward_pipeline_kwargs = { + "batch_size": args.batch_size, # eval_args.inference_batch_size, + "truncation": True, + "padding": True, + "max_length": args.max_length, + "function_to_apply": "none", # Compute raw logits + "return_token_type_ids": False, + } + if quantized: + model_kwargs = { + "load_in_8bit": True, + "device_map": {"": current_device}, + "torch_dtype": torch.float16 if torch.cuda.is_available() else None, + } + else: + model_kwargs = {"device_map": {"": current_device}} + + model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code) + reward_pipe = pipeline_builder( + "text-classification", # often not used + 
model=model,
+            tokenizer=tokenizer,
+        )
+
+        # set pad token to eos token if not set
+        if reward_pipe.tokenizer.pad_token_id is None:
+            reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.eos_token_id
+            reward_pipe.tokenizer.pad_token_id = reward_pipe.tokenizer.eos_token_id
+        # For models whose config did not contain `pad_token_id`
+        if reward_pipe.model.config.pad_token_id is None:
+            reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.pad_token_id
+
+        # if using fastchat template (no template in tokenizer), make the RM tokenizer output an EOS token
+        if not check_tokenizer_chat_template(tokenizer):
+            reward_pipe.tokenizer.add_eos_token = True
+
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            drop_last=False,
+        )
+
+        dataloader, model = accelerator.prepare(dataloader, reward_pipe.model)
+        reward_pipe.model = model
+
+    ############################
+    # Run inference
+    ############################
+
+    results = []
+    scores_chosen = []
+    scores_rejected = []
+    for step, batch in enumerate(tqdm(dataloader, desc="RM batch steps")):
+        logger.info(f"RM inference step {step}/{len(dataloader)}")
+
+        if is_dpo:
+            rewards_chosen, rewards_rejected = dpo.inference_step(batch)
+        else:
+            rewards_chosen = reward_pipe(batch["text_chosen"], **reward_pipeline_kwargs)
+            rewards_rejected = reward_pipe(batch["text_rejected"], **reward_pipeline_kwargs)
+
+        # for each item in batch, record 1 if chosen > rejected
+        # extract score from dict within batched results (e.g. logits)
+        # [{'label': 'LABEL_1', 'score': 0.6826171875},... ]
+        if isinstance(rewards_chosen[0], dict):
+            score_chosen_batch = [result["score"] for result in rewards_chosen]
+            score_rejected_batch = [result["score"] for result in rewards_rejected]
+        # for classes that directly output scores (custom code)
+        else:
+            score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
+            score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+
+        # log results
+        [
+            results.append(1) if chosen > rejected else results.append(0)
+            for chosen, rejected in zip(score_chosen_batch, score_rejected_batch)
+        ]
+        scores_chosen.extend(score_chosen_batch)
+        scores_rejected.extend(score_rejected_batch)
+
+    ############################
+    # compile scores
+    ############################
+    # calculate accuracy
+    accuracy = sum(results) / len(results)
+    logger.info(f"Results: {accuracy}, on {len(results)} prompts")
+
+    ############################
+    # save scores
+    ############################
+    # save score in json to args.output_dir + args.model + ".json"
+    output_path = args.output_dir + args.model + ".json"
+    dirname = os.path.dirname(output_path)
+    os.makedirs(dirname, exist_ok=True)
+
+    # remove old data
+    if os.path.exists(output_path):
+        os.remove(output_path)
+
+    with open(output_path, "w") as f:
+        json.dump(
+            {
+                "accuracy": accuracy,
+                "num_prompts": len(results),
+                "model": args.model,
+                "ref_model": args.ref_model,
+                "tokenizer": tokenizer_path,
+                "chat_template": args.chat_template,
+            },
+            f,
+        )
+
+    # if save_all is passed, save a large jsonl with all scores_chosen, scores_rejected
+    if args.save_all:
+        output_path = args.output_dir + args.model + "_all.jsonl"
+        dirname = os.path.dirname(output_path)
+        os.makedirs(dirname, exist_ok=True)
+
+        # remove old data
+        if os.path.exists(output_path):
+            os.remove(output_path)
+
+        with open(output_path, "w") as f:
+            for chosen, rejected in zip(scores_chosen, scores_rejected):
+                f.write(json.dumps({"chosen": chosen, "rejected": rejected}) + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/rewardbench/rewardbench.py b/rewardbench/rewardbench.py
new file mode 100644
index 00000000..31cf8fc8
--- /dev/null
+++ b/rewardbench/rewardbench.py
@@ -0,0 +1,367 @@
+# Copyright 2023 AllenAI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run RewardBench (evaluate any reward model on any dataset)
+
+import argparse
+import json
+import logging
+import os
+import sys
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from rewardbench import (
+    DPO_MODEL_CONFIG,
+    REWARD_MODEL_CONFIG,
+    check_tokenizer_chat_template,
+    load_preference_dataset,
+)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate a reward model.")
+
+    # core args
+    parser.add_argument("--dataset", type=str, default="allenai/reward-bench", help="The dataset to evaluate on.")
+    parser.add_argument("--split", type=str, default=None, help="The split to evaluate on.")
+    parser.add_argument("--model", type=str, required=True, help="The model to evaluate.")
+    parser.add_argument("--ref_model", type=str, default=None, help="The reference model to compare against.")
+    parser.add_argument("--tokenizer", type=str, default=None, help="The tokenizer to use (defaults to model).")
+    parser.add_argument(
+        "--chat_template",
+        type=str,
+        default=None,
+        help="The chat template to use (defaults to the chat template from the tokenizer, or from chattemplates).",
+    )
+
+    # inference args
+    parser.add_argument("--batch_size", type=int, default=8, help="The batch size to use.")
+    parser.add_argument("--max_length", type=int, default=512, help="The max length to use.")
+
+    # system args
+    parser.add_argument("--load_json", action="store_true", default=False, help="Load dataset as json.")
+    parser.add_argument("--trust_remote_code", action="store_true", default=False, help="Trust remote code.")
+    parser.add_argument("--debug", action="store_true", default=False, help="Debug mode.")
+    parser.add_argument("--output_dir", type=str, default="results/", help="The output directory to save results.")
+    parser.add_argument("--save_all", action="store_true", default=False, help="Save all results.")
+    args = parser.parse_args()
+
+    ###############
+    # Setup logging
+    ###############
+    accelerator = Accelerator()
+    current_device = accelerator.process_index
+
+    logger = get_logger(__name__)
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    log_level = logging.INFO
+    logger.setLevel(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")
+    if args.trust_remote_code:
+        
logger.info("Loading model with Trust Remote Code") + + # basic checks from config + if args.ref_model: + is_dpo = True + MODEL_CONFIGS = DPO_MODEL_CONFIG + assert args.model != args.ref_model, "policy and reference model should be different" + from trl.trainer.utils import DPODataCollatorWithPadding + + from rewardbench import DPOInference + else: + is_dpo = False + MODEL_CONFIGS = REWARD_MODEL_CONFIG + + if args.chat_template: + from fastchat.conversation import get_conv_template + + conv = get_conv_template(args.chat_template) + else: + conv = None + + if args.model in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.model] + else: + config = MODEL_CONFIGS["default"] + logger.info(f"Using reward model config: {config}") + + # Default entries + # "model_builder": AutoModelForSequenceClassification.from_pretrained, + # "pipeline_builder": pipeline, + # "quantized": True, + # "custom_dialogue": False, + # "model_type": "Seq. Classifier" + + if not is_dpo: + quantized = config["quantized"] # only Starling isn't quantized for now + custom_dialogue = config["custom_dialogue"] + pipeline_builder = config["pipeline_builder"] + _ = config["model_type"] + if custom_dialogue: + raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.") + + model_builder = config["model_builder"] + + ######################### + # load dataset + ######################### + logger.info("*** Load dataset ***") + tokenizer_path = args.tokenizer if args.tokenizer else args.model + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code) + if args.dataset == "allenai/reward-bench": + logger.info("Running core eval dataset.") + from rewardbench import load_eval_dataset + from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING + from rewardbench.utils import calculate_scores_per_section + + # primary set compiles slightly more information + dataset, subsets = load_eval_dataset( + core_set=True, + conv=conv, + custom_dialogue_formatting=False, + tokenizer=tokenizer, + logger=logger, + keep_columns=["text_chosen", "text_rejected", "prompt"], + ) + else: + dataset = load_preference_dataset( + args.dataset, split=args.split, json=args.load_json, tokenizer=tokenizer, conv=conv + ) + + if args.debug: + dataset = dataset.select(range(10)) + + logger.info("*** Load reward model ***") + + ############################ + # Load DPO model pipeline + ############################ + if is_dpo: + tokenizer.pad_token = tokenizer.eos_token + # if no BOS token, set as pad token, e.g. 
QWEN models + if tokenizer.bos_token is None: + tokenizer.bos_token_id = tokenizer.eos_token_id + tokenizer.pad_token_id = tokenizer.eos_token_id + + model_kwargs = { + "load_in_8bit": True, + "device_map": "auto", + "torch_dtype": torch.float16 if torch.cuda.is_available() else None, + } + model = model_builder( + args.model, + trust_remote_code=args.trust_remote_code, + **model_kwargs, + ) + ref_model = model_builder( + args.ref_model, + trust_remote_code=args.trust_remote_code, + **model_kwargs, + ) + + # use internal inference functions in DPO trainer + dpo = DPOInference( + model, + ref_model, + tokenizer=tokenizer, + accelerator=accelerator, + # norm is norm, avg is average, sum is sum + ) + + # tokenize dataset + column_names = list(dataset.features) + + tokenized_dataset = dataset.map(dpo.tokenize_row, remove_columns=column_names) + dataloader = torch.utils.data.DataLoader( + tokenized_dataset, + batch_size=args.batch_size, + collate_fn=DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=dpo.label_pad_token_id, + is_encoder_decoder=dpo.is_encoder_decoder, + ), + # collate_fn = lambda x: x, # fix weird batching error + shuffle=False, + drop_last=False, + ) + + ############################ + # Load classifier model pipeline + ############################ + else: + reward_pipeline_kwargs = { + "batch_size": args.batch_size, # eval_args.inference_batch_size, + "truncation": True, + "padding": True, + "max_length": args.max_length, + "function_to_apply": "none", # Compute raw logits + "return_token_type_ids": False, + } + if quantized: + model_kwargs = { + "load_in_8bit": True, + "device_map": {"": current_device}, + "torch_dtype": torch.float16 if torch.cuda.is_available() else None, + } + else: + model_kwargs = {"device_map": {"": current_device}} + + model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code) + reward_pipe = pipeline_builder( + "text-classification", # often not used + model=model, + tokenizer=tokenizer, + ) + + # set pad token to eos token if not set + if reward_pipe.tokenizer.pad_token_id is None: + reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.eos_token_id + reward_pipe.tokenizer.pad_token_id = reward_pipe.tokenizer.eos_token_id + # For models whose config did not contains `pad_token_id` + if reward_pipe.model.config.pad_token_id is None: + reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.pad_token_id + + # if using fastchat template (no template in tokenizer), make the RM tokenizer output an EOS token + if not check_tokenizer_chat_template(tokenizer): + reward_pipe.tokenizer.add_eos_token = True + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + ) + + dataloader, model = accelerator.prepare(dataloader, reward_pipe.model) + reward_pipe.model = model + + ############################ + # Run inference + ############################ + + results = [] + scores_chosen = [] + scores_rejected = [] + for step, batch in enumerate(tqdm(dataloader, desc="RM batch steps")): + logger.info(f"RM inference step {step}/{len(dataloader)}") + + if is_dpo: + rewards_chosen, rewards_rejected = dpo.inference_step(batch) + else: + rewards_chosen = reward_pipe(batch["text_chosen"], **reward_pipeline_kwargs) + rewards_rejected = reward_pipe(batch["text_rejected"], **reward_pipeline_kwargs) + + # for each item in batch, record 1 if chosen > rejected + # extra score from dict within batched results (e.g. 
logits)
+        # [{'label': 'LABEL_1', 'score': 0.6826171875},... ]
+        if isinstance(rewards_chosen[0], dict):
+            score_chosen_batch = [result["score"] for result in rewards_chosen]
+            score_rejected_batch = [result["score"] for result in rewards_rejected]
+        # for classes that directly output scores (custom code)
+        else:
+            score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
+            score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+
+        # log results
+        [
+            results.append(1) if chosen > rejected else results.append(0)
+            for chosen, rejected in zip(score_chosen_batch, score_rejected_batch)
+        ]
+        scores_chosen.extend(score_chosen_batch)
+        scores_rejected.extend(score_rejected_batch)
+
+    ############################
+    # compile scores
+    ############################
+    # calculate accuracy
+    accuracy = sum(results) / len(results)
+    logger.info(f"Results: {accuracy}, on {len(results)} prompts")
+
+    if args.dataset == "allenai/reward-bench":
+        out_dataset = dataset.add_column("results", results)
+        if args.debug:
+            subsets = subsets[:10]
+        out_dataset = out_dataset.add_column("subsets", subsets)
+        out_dataset = out_dataset.to_pandas()  # I know this is meh
+
+        results_grouped = {}
+        present_subsets = np.unique(out_dataset["subsets"])
+        for subset in present_subsets:
+            subset_dataset = out_dataset[out_dataset["subsets"] == subset]
+            num_correct = sum(subset_dataset["results"])
+            num_total = len(subset_dataset["results"])
+            logger.info(f"{subset}: {num_correct}/{num_total} ({num_correct/num_total})")
+            results_grouped[subset] = num_correct / num_total
+
+        results_section = calculate_scores_per_section(EXAMPLE_COUNTS, SUBSET_MAPPING, results_grouped)
+        logger.info(f"Results: {results_section}")
+
+    ############################
+    # save scores
+    ############################
+    # save score in json to args.output_dir + args.model + ".json"
+    output_path = args.output_dir + args.model + ".json"
+    dirname = os.path.dirname(output_path)
+    os.makedirs(dirname, exist_ok=True)
+
+    # remove old data
+    if os.path.exists(output_path):
+        os.remove(output_path)
+
+    with open(output_path, "w") as f:
+        json.dump(
+            {
+                "accuracy": accuracy,
+                "num_prompts": len(results),
+                "model": args.model,
+                "ref_model": args.ref_model,
+                "tokenizer": tokenizer_path,
+                "chat_template": args.chat_template,
+                "extra_results": results_grouped if args.dataset == "allenai/reward-bench" else None,
+            },
+            f,
+        )
+
+    # if save_all is passed, save a large jsonl with all scores_chosen, scores_rejected
+    if args.save_all:
+        output_path = args.output_dir + args.model + "_all.jsonl"
+        dirname = os.path.dirname(output_path)
+        os.makedirs(dirname, exist_ok=True)
+
+        # remove old data
+        if os.path.exists(output_path):
+            os.remove(output_path)
+
+        with open(output_path, "w") as f:
+            for chosen, rejected in zip(scores_chosen, scores_rejected):
+                f.write(json.dumps({"chosen": chosen, "rejected": rejected}) + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/rewardbench/utils.py b/rewardbench/utils.py
index 8e023e33..a5f818f7 100644
--- a/rewardbench/utils.py
+++ b/rewardbench/utils.py
@@ -18,7 +18,7 @@
 from typing import Any, Dict, List, Union
 
 import pandas as pd
-from datasets import Dataset, Value, concatenate_datasets, load_dataset
+from datasets import Dataset, DatasetDict, Value, concatenate_datasets, load_dataset
 from fastchat.conversation import Conversation
 from huggingface_hub import HfApi
 from transformers import PreTrainedTokenizer
@@ -133,6 +133,126 @@ def map_conversations_testsets(example):
     return example
 
+def load_preference_dataset(
+    dataset_name: str,
+    split: str = "train",
+    json: bool = False,
+    conv: Conversation = None,
+    tokenizer: PreTrainedTokenizer = None,
+    logger: logging.Logger = None,
+) -> Dataset:
+    """
+    Load a preference dataset from the datasets library.
+
+    Expects the data in the following schema:
+    - prompt (string): question
+    - chosen (list): all turns of the conversation (including the prompt), chosen answer
+    - rejected (list): all turns of the conversation (including the prompt), rejected answer
+
+    Removes all excess columns, and only returns scores over the provided data in order.
+
+    Args:
+        dataset_name (str): The name of the dataset to load (HuggingFace or local directory)
+        split (str): The split of the dataset to load (train, validation, test, ...)
+
+    Returns:
+        dataset (Dataset): The loaded dataset with prompt, text_chosen, and text_rejected columns.
+            text_ indicates a full conversation ending with that turn
+    """
+    if json:
+        dataset = load_dataset("json", data_files=dataset_name)
+    else:
+        dataset = load_dataset(dataset_name, split=split)
+
+    # if datasetdict, flatten all splits
+    if isinstance(dataset, DatasetDict):
+        available_splits = list(dataset.keys())
+        datasets_to_combine = [dataset[split] for split in available_splits]
+        dataset = concatenate_datasets(datasets_to_combine)
+
+    # if has column question without prompt, rename question column to prompt
+    if "question" in dataset.column_names:
+        assert "prompt" not in dataset.column_names, "Both prompt and question columns found"
+        dataset = dataset.rename_column("question", "prompt")
+    if "input" in dataset.column_names:
+        assert "prompt" not in dataset.column_names, "Both prompt and input columns found"
+        dataset = dataset.rename_column("input", "prompt")
+
+    # switch to format used for data utils
+    # e.g. for evaluating this data https://huggingface.co/datasets/allenai/preference-test-sets
+    # python -m rewardbench.rewardbench --dataset allenai/preference-test-sets --split shp
+    features = dataset.features
+
+    def switch_format(example):
+        # chosen/rejected append {"role": "assistant", "content": chosen}
+        example["prompt"] = example["chosen"][:-1]
+        example["chosen"] = example["chosen"][-1]["content"]
+        example["rejected"] = example["rejected"][-1]["content"]
+        return example
+
+    # NOTE: We do NOT want to support every schema. These are the main three to start with
+    # 1. Prompt is in a list of previous turns, chosen and rejected are final message from assistant
+    # 2. Prompt is a string, chosen and rejected are full conversations with different final turns
+    # 3. Prompt is absent, chosen and rejected are full conversations with different final turns
+    # TODO implement system prompts correctly (though, often doesn't work for Reward Models)
+
+    # if prompt isn't a column,
+    if "prompt" not in dataset.column_names:
+        dataset = dataset.map(
+            switch_format,
+            num_proc=8,
+            load_from_cache_file=False,
+        )
+    # elif prompt is a list and not a str, same function works
+    elif not isinstance(features["prompt"], list):
+        dataset = dataset.map(
+            switch_format,
+            num_proc=8,
+            load_from_cache_file=False,
+        )
+
+    # update features
+    features = dataset.features
+
+    # assert the correct types
+    assert features["chosen"].dtype == "string", f"chosen is wrong type (should be string): {features['chosen']}"
+    assert features["rejected"].dtype == "string", f"rejected is wrong type (should be string): {features['rejected']}"
+
+    # tokenize the data
+    usable_tokenizer = check_tokenizer_chat_template(tokenizer)
+
+    # assert either conv is passed or tokenizer has chat_template
+    assert conv is not None or usable_tokenizer
+
+    if usable_tokenizer:
+        if logger is not None:
+            logger.info("*** Preparing dataset with HF Transformers ***")
+        # docs https://huggingface.co/docs/transformers/main/en/chat_templating
+        dataset = dataset.map(
+            prepare_dialogue_from_tokenizer,
+            fn_kwargs={"tokenizer": tokenizer},
+            num_proc=8,
+            load_from_cache_file=False,
+        )
+
+    # else use FastChat to get chat template
+    else:
+        if logger is not None:
+            logger.info("*** Preparing dataset with FastChat ***")
+        dataset = dataset.map(
+            prepare_dialogue,
+            fn_kwargs={"dialogue_template": conv},
+            num_proc=8,
+            load_from_cache_file=False,
+        )
+
+    # remove excess data
+    keep_columns = ["prompt", "text_chosen", "text_rejected"]
+    all_cols = dataset.column_names
+    dataset = dataset.remove_columns([c for c in all_cols if c not in keep_columns])
+    return dataset
+
+
 def load_eval_dataset(
     core_set: bool = True,
     custom_dialogue_formatting: bool = False,
diff --git a/setup.py b/setup.py
index 84cadd82..93f3a82e 100644
--- a/setup.py
+++ b/setup.py
@@ -14,43 +14,48 @@
 from setuptools import find_packages, setup
 
+# instructions for releasing a new version: update the version number, then follow
+# from step 6 of https://github.com/huggingface/diffusers/blob/49b959b5408b97274e2ee423059d9239445aea26/setup.py#L36C43-L38C1
+# this has not yet been pushed to test PyPI
 setup(
     name="rewardbench",
-    version="0.1.0.dev",
+    version="0.1.1",
     author="Nathan Lambert",
     author_email="nathanl@allenai.org",
     description="Tools for evaluating reward models",
+    entry_points={
+        "console_scripts": ["rewardbench=rewardbench.rewardbench:main"],
+    },
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
     url="https://github.com/allenai/rewardbench",
     packages=find_packages(),
     classifiers=[
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.10",
-        "License :: OSI Approved :: Apache 2.0 License",
+        "License :: OSI Approved :: Apache Software License",
         "Operating System :: OS Independent",
     ],
     python_requires=">=3.10",
    install_requires=[
         "accelerate",
         "bitsandbytes",
-        "black==24.3.0",
+        "black",
         "datasets",
         "deepspeed",
         "einops",
         "flake8>=6.0",
-        "fschat[model_worker,webui]",
+        "fschat",
         "huggingface_hub",
         "isort>=5.12.0",
         "pandas",
         "peft",
         "pytest",
-        # "ray",  # for generative llm multi-gpu
         "scipy",
+        "sentencepiece",
         "tabulate",  # dependency for markdown rendering in pandas
         "tokenizers",
+        "torch",
         "tiktoken==0.6.0",  # added for llama 3
-        # "transformers @ 
git+https://github.com/huggingface/transformers.git@851f253f4d3fa2414451eeaac82b7a9ad6084675", # noqa "transformers==4.40.0", # pinned at llama 3 "trl>=0.8.2", # fixed transformers import error # TODO consider vllm in setup, currently only in dockerfile diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 00000000..80f01423 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,45 @@ +# Copyright 2023 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# tests to make sure the code in the package is working as expected +import unittest + +from fastchat.conversation import get_conv_template +from transformers import AutoTokenizer + +from rewardbench import load_preference_dataset + + +class LoadAnyDataTest(unittest.TestCase): + """ + Simple scripts to make sure the loading scripts do not error. + """ + + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained("allenai/rlhf-test-tokenizer") + self.conv = get_conv_template("tulu") + + def test_load_standard_tokenizer(self): + load_preference_dataset( + "allenai/ultrafeedback_binarized_cleaned", split="test_prefs", tokenizer=self.tokenizer + ) + + def test_load_standard_conv(self): + load_preference_dataset("allenai/ultrafeedback_binarized_cleaned", split="test_prefs", conv=self.conv) + + def test_load_alt_tokenizer(self): + load_preference_dataset("allenai/preference-test-sets", split="shp", tokenizer=self.tokenizer) + + def test_load_alt_conv(self): + load_preference_dataset("allenai/preference-test-sets", split="shp", conv=self.conv)
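
For reference, here is a minimal sketch (not part of the patch) of calling the new `load_preference_dataset` helper directly from Python, mirroring `tests/test_package.py` above. The tokenizer and dataset names are simply the ones used in those tests; any preference dataset matching the documented schema should work.

```python
# A minimal sketch, assuming `pip install rewardbench` and internet access to the Hub.
from transformers import AutoTokenizer

from rewardbench import load_preference_dataset

# Tokenizer with a chat template (same one the tests use); a FastChat
# Conversation can be passed via `conv=` instead if no template exists.
tokenizer = AutoTokenizer.from_pretrained("allenai/rlhf-test-tokenizer")

# Returns a Dataset with prompt, text_chosen, and text_rejected columns,
# already formatted with the tokenizer's chat template.
dataset = load_preference_dataset(
    "allenai/preference-test-sets",
    split="shp",
    tokenizer=tokenizer,
)
print(len(dataset), dataset.column_names)
```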