VersaPRM: Multi-Domain Process Reward Model via Synthetic Reasoning Data


Abstract

Process Reward Models (PRMs) have proven effective at enhancing mathematical reasoning for Large Language Models (LLMs) by leveraging increased inference-time computation. However, they are predominantly trained on mathematical data, and their generalizability to non-mathematical domains has not been rigorously studied. In response, this work first shows that current PRMs have poor performance in other domains. To address this limitation, we introduce VersaPRM, a multi-domain PRM trained on synthetic reasoning data generated using our novel data generation and annotation method. VersaPRM achieves consistent performance gains across diverse domains. For instance, in the MMLU-Pro category of Law, VersaPRM, via weighted majority voting, achieves a 7.9% performance gain over the majority voting baseline, surpassing Qwen2.5-Math-PRM's gain of 1.3%. We further contribute to the community by open-sourcing all data, code, and models for VersaPRM.

Link to paper: https://arxiv.org/abs/2502.06737
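
The weighted majority voting mentioned in the abstract aggregates many sampled CoTs per question, weighting each candidate's final answer by a PRM-derived score. Below is a minimal sketch of that aggregation; the specific way of collapsing per-step scores into one weight (here, the minimum step score) is an assumption for illustration, not necessarily the exact setup used in the paper.

```python
# Hedged sketch of PRM-weighted majority voting over sampled CoTs.
# Each candidate is (final_answer, list of per-step PRM scores). Using the
# minimum step score as the weight is an illustrative choice.
from collections import defaultdict

def weighted_majority_vote(candidates):
    """candidates: list of (answer, step_scores) tuples."""
    votes = defaultdict(float)
    for answer, step_scores in candidates:
        weight = min(step_scores) if step_scores else 0.0  # collapse step scores
        votes[answer] += weight
    return max(votes, key=votes.get)

# Example: three sampled CoTs for one question with answers "A" and "B".
candidates = [
    ("A", [0.9, 0.8, 0.95]),
    ("B", [0.6, 0.4, 0.7]),
    ("B", [0.5, 0.3, 0.6]),
]
print(weighted_majority_vote(candidates))  # "A" (weight 0.8 vs. 0.7 for "B")
```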

🚀 Models

We provide a suite of VersaPRM and math PRM models trained with different configurations; a hedged scoring sketch follows the list.

  • VersaPRM – trained from Llama-PRM800K on MMLU-Pro-CoT-Train-Labeled
  • VersaPRM-Base-8B – trained from Llama-3.1-8B-Instruct on MMLU-Pro-CoT-Train-Labeled
  • VersaPRM-Aug – trained from Llama-PRM800K on MMLU-Pro-CoT-Train-Labeled with counterfactual augmentations
  • Qwen-PRM800K – Qwen model fine-tuned on the PRM800K dataset
  • Llama-PRM800K – Llama model fine-tuned on the PRM800K dataset
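
The sketch below shows one generic way to score reasoning steps with a causal-LM PRM using Hugging Face transformers. It is not the repository's official inference code: the model path, the newline step separator, and the "+"/"-" judgement tokens are placeholders that should be replaced with the conventions documented for the released checkpoints.

```python
# Hedged sketch: score each reasoning step with a causal-LM PRM.
# Model ID, step separator, and label tokens below are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/VersaPRM"  # placeholder: substitute the released checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()

question = "Which doctrine allows a court to decline to hear a case?"
steps = [
    "Step 1: The question asks about a discretionary dismissal doctrine.",
    "Step 2: Forum non conveniens lets a court decline jurisdiction.",
]

# Assume the PRM emits a '+' / '-' judgement after each step (a common PRM
# convention, not necessarily this repo's exact format).
plus_id = tokenizer.convert_tokens_to_ids("+")
minus_id = tokenizer.convert_tokens_to_ids("-")

scores = []
prefix = question + "\n"
for step in steps:
    prefix += step + "\n"
    inputs = tokenizer(prefix, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]  # next-token logits at step end
    p_plus = torch.softmax(logits[[plus_id, minus_id]], dim=-1)[0].item()
    scores.append(p_plus)

print(scores)  # per-step correctness probabilities in [0, 1]
```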

📊 Datasets

  • Evaluation CoT dataset – Size: 249k rows. Use case: evaluation CoTs generated by Llama-3.1-8B-Instruct, with 128 responses sampled per question. (Some questions have fewer, since we remove CoTs from which no final answer can be extracted.)
  • MMLU-Pro-CoT-Train-Labeled – Size: 84.1k rows. Use case: stepwise-labeled CoT training dataset used to fine-tune VersaPRM. The CoTs are generated by Llama-3.1-8B-Instruct, with 16 responses sampled per question. (Some questions have fewer, since we remove invalid CoTs.)

A loading sketch for these datasets follows below.
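
If the datasets are hosted on the Hugging Face Hub, loading and grouping them might look like the following. The dataset identifier and column name are placeholders; the actual repository names and schema should be taken from the released dataset cards.

```python
# Hedged sketch: load a CoT dataset with the `datasets` library and group the
# sampled responses by question. Dataset ID and column names are placeholders.
from collections import defaultdict
from datasets import load_dataset

train = load_dataset("ORG/MMLU-Pro-CoT-Train-Labeled", split="train")  # placeholder ID
print(train[0])  # inspect one stepwise-labeled CoT example

# Group sampled CoTs by question so per-question voting can be applied later.
by_question = defaultdict(list)
for row in train:
    by_question[row["question_id"]].append(row)  # placeholder column name

print(f"{len(by_question)} unique questions")
```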

Folder Structure

  • aws_batch_inference/: Scripts for performing batch inference using AWS services.
  • counterfactual_augmentation/: Scripts for creating and processing counterfactual augmentation batches.
  • evaluation/: Scripts and resources for evaluating model performance, including metric calculations and batch output merging.
  • figures/: Contains visual representations and figures used in the project.
  • model_train/: Scripts and configurations for training models, including data preparation and training scripts.
  • search_algs/: Scripts for running beam search and MCTS (TODO: add readme); a conceptual beam-search sketch follows this list.
  • synth_cot_generation/: Scripts for generating synthetic chain-of-thought (CoT) data for training and evaluation.
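
As a rough illustration of how a PRM can guide step-level search (the actual implementations live in search_algs/), here is a conceptual beam-search sketch. The generator and scorer are stubs standing in for an LLM sampler and a trained PRM; none of these function names correspond to code in this repository.

```python
# Conceptual sketch of PRM-guided step-level beam search (not the repo's code).
# `propose_next_steps` and `prm_score` are stubs; replace with real model calls.
import random

def propose_next_steps(partial_cot, k):
    """Stub: sample k candidate next reasoning steps for a partial CoT."""
    return [partial_cot + [f"step draft {random.random():.3f}"] for _ in range(k)]

def prm_score(cot):
    """Stub: return a scalar PRM score for a (partial) chain of thought."""
    return random.random()

def beam_search(question, beam_width=4, expansions=4, max_steps=6):
    beams = [[f"Question: {question}"]]
    for _ in range(max_steps):
        candidates = []
        for cot in beams:
            candidates.extend(propose_next_steps(cot, expansions))
        # Keep the highest-scoring partial chains according to the PRM.
        candidates.sort(key=prm_score, reverse=True)
        beams = candidates[:beam_width]
    return beams[0]  # best chain found within the step budget

print(beam_search("Which doctrine lets a court decline jurisdiction?"))
```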