Support loading model from wandb (#184)
vwxyzjn authored Sep 25, 2024
1 parent f5b5d9f commit 984b799
Showing 2 changed files with 79 additions and 45 deletions.
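The new `--wandb_run` option is the heart of this change: when given, `actual_main` looks up the W&B run and pulls the Hugging Face repo id and revision out of its config instead of requiring `--model` on the command line. A minimal sketch of that lookup, assuming a hypothetical run path (the `entity/project/run_id` form accepted by `wandb.Api().run`) and a run whose config contains `hf_repo_id` and `hf_repo_revision`, as the diff below expects:

import wandb

# Hypothetical run path in the "entity/project/run_id" form accepted by wandb.Api().run
run = wandb.Api().run("my-entity/my-project/abc123")

# The two config keys the new code reads to decide what to evaluate
model = run.config["hf_repo_id"]           # Hugging Face repo id pushed by the training run
revision = run.config["hf_repo_revision"]  # the exact commit/branch to load
print(f"Would evaluate {model} at revision {revision}")

The same `revision` is then threaded through `AutoTokenizer.from_pretrained` and the model builder below, so the evaluated checkpoint matches what the training run pushed.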
123 changes: 78 additions & 45 deletions rewardbench/rewardbench.py
@@ -14,19 +14,21 @@

# Run RewardBench (evaluate any reward model on any dataset)

import argparse
import json
import logging
import os
import sys
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import transformers
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, HfArgumentParser

from rewardbench import (
DPO_MODEL_CONFIG,
@@ -36,38 +38,61 @@
)


def main():
parser = argparse.ArgumentParser(description="Evaluate a reward model.")

@dataclass
class Args:
# core args
parser.add_argument("--dataset", type=str, default="allenai/reward-bench", help="The dataset to evaluate on.")
parser.add_argument("--split", type=str, default=None, help="The split to evaluate on.")
parser.add_argument("--model", type=str, required=True, help="The model to evaluate.")
parser.add_argument("--ref_model", type=str, default=None, help="The reference model to compare against.")
parser.add_argument("--tokenizer", type=str, default=None, help="The tokenizer to use (defaults to model).")
parser.add_argument(
"--chat_template",
type=str,
default=None,
help="The chat template to use (defaults to from tokenizer, from chattemplate).",
)
parser.add_argument(
"--not_quantized", action="store_true", help="disable quantization for models that are quantized by default"
)
dataset: str = "allenai/reward-bench"
"""The dataset to evaluate on."""
split: Optional[str] = None
"""The split to evaluate on."""
model: Optional[str] = None
"""The model to evaluate."""
revision: Optional[str] = None
"""The model revision to evaluate."""
ref_model: Optional[str] = None
"""The reference model to compare against."""
tokenizer: Optional[str] = None
"""The tokenizer to use (defaults to model)."""
chat_template: Optional[str] = None
"""The chat template to use (defaults to from tokenizer, from chattemplate)."""
not_quantized: bool = False
"""Disable quantization for models that are quantized by default."""

# wandb args
wandb_run: Optional[str] = None
"""The wandb run to extract model and revision from."""

# inference args
parser.add_argument("--batch_size", type=int, default=8, help="The batch size to use.")
parser.add_argument("--max_length", type=int, default=512, help="The max length to use.")
batch_size: int = 8
"""The batch size to use."""
max_length: int = 512
"""The max length to use."""

# system args
parser.add_argument("--load_json", action="store_true", default=False, help="Load dataset as json.")
parser.add_argument("--trust_remote_code", action="store_true", default=False, help="Trust remote code.")
parser.add_argument("--debug", action="store_true", default=False, help="Debug mode.")
parser.add_argument("--output_dir", type=str, default="results/", help="The output directory to save results.")
parser.add_argument("--save_all", action="store_true", default=False, help="Save all results.")
parser.add_argument(
"--force_truncation", action="store_true", default=False, help="Force truncation (for if model errors)."
)
args = parser.parse_args()
load_json: bool = False
"""Load dataset as json."""
trust_remote_code: bool = False
"""Trust remote code."""
debug: bool = False
"""Debug mode."""
output_dir: str = "results/"
"""The output directory to save results."""
save_all: bool = False
"""Save all results."""
force_truncation: bool = False
"""Force truncation (for if model errors)."""


def main():
parser = HfArgumentParser((Args))
actual_main(*parser.parse_args_into_dataclasses())


def actual_main(args: Args):
if args.wandb_run is not None:
wandb_run = wandb.Api().run(args.wandb_run)
args.model = wandb_run.config["hf_repo_id"]
args.revision = wandb_run.config["hf_repo_revision"]

###############
# Setup logging
@@ -148,7 +173,9 @@ def main():
#########################
logger.info("*** Load dataset ***")
tokenizer_path = args.tokenizer if args.tokenizer else args.model
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=args.trust_remote_code, revision=args.revision
)
if args.dataset == "allenai/reward-bench":
logger.info("Running core eval dataset.")
from rewardbench import load_eval_dataset
@@ -256,7 +283,9 @@ def main():
# note, device map auto does not work for quantized models
model_kwargs = {"device_map": "auto"}

model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
model = model_builder(
args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
)
reward_pipe = pipeline_builder(
"text-classification", # often not used
model=model,
@@ -363,19 +392,23 @@ def main():
if os.path.exists(output_path):
os.remove(output_path)

final_results = {
"accuracy": accuracy,
"num_prompts": len(results),
"model": args.model,
"ref_model": args.ref_model,
"tokenizer": tokenizer_path,
"chat_template": args.chat_template,
"extra_results": results_grouped if args.dataset == "allenai/reward-bench" else None,
}
with open(output_path, "w") as f:
json.dump(
{
"accuracy": accuracy,
"num_prompts": len(results),
"model": args.model,
"ref_model": args.ref_model,
"tokenizer": tokenizer_path,
"chat_template": args.chat_template,
"extra_results": results_grouped if args.dataset == "allenai/reward-bench" else None,
},
f,
)
json.dump(final_results, f)

if args.wandb_run is not None:
for key in final_results:
wandb_run.summary[f"rewardbench/{key}"] = final_results[key]
wandb_run.update()
print(f"Logged metrics to {wandb_run.url}")

# if save_all is passed, save a large jsonl with all scores_chosen, scores_rejected
if args.save_all:
@@ -389,7 +422,7 @@ def main():

with open(output_path, "w") as f:
for chosen, rejected in zip(scores_chosen, scores_rejected):
f.write(json.dumps({"chosen": scores_chosen, "rejected": scores_rejected}) + "\n")
f.write(json.dumps({"chosen": chosen, "rejected": rejected}) + "\n")


if __name__ == "__main__":
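On the logging side, the diff above mirrors the saved `final_results` dict into the W&B run summary under a `rewardbench/` prefix and calls `wandb_run.update()` to persist it. A minimal sketch of reading those metrics back later, assuming the same hypothetical run path as above and that the evaluation has already written the summary:

import wandb

run = wandb.Api().run("my-entity/my-project/abc123")  # hypothetical run path
# Keys are written as "rewardbench/<field>" by the loop in actual_main
accuracy = run.summary["rewardbench/accuracy"]
num_prompts = run.summary["rewardbench/num_prompts"]
print(f"accuracy={accuracy} over {num_prompts} prompts")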
1 change: 1 addition & 0 deletions setup.py
@@ -61,6 +61,7 @@
"tiktoken==0.6.0", # added for llama 3
"transformers==4.43.4", # pinned at llama 3
"trl>=0.8.2", # fixed transformers import error, for DPO
"wandb", # for loading model path / reivisions from wandb
],
extras_require={
"generative": [
