RiskAgent provides high-quality, evidence-based risk predictions for 387 risk scenarios across diverse complex diseases, including rare diseases and cancer.
- Training data and training script
- External evaluation on MedCalc-Bench
- RiskAgent-70B
- ...
We provide a simple demo of the risk agent pipeline in `evaluate/riskagent_demo.ipynb` and `evaluate/riskagent_demo_auto.ipynb`.
With a simple setup, it produces a risk-prediction report summary from given patient information using the RiskAgent model.
Install the necessary packages:

```shell
conda create -n riskagent python=3.9
conda activate riskagent
pip install -r requirements.txt
```
```python
from riskagent_pipeline import RiskAgentPipeline

pipeline = RiskAgentPipeline(
    model_type="openai",
    api_key="YOUR_OPENAI_API_KEY",
    deployment="gpt-4o"
)

test_case = """
A 54-year-old female patient with a history of hypertension and diabetes presents to the clinic complaining of palpitations and occasional light-headedness. Her medical record shows a previous stroke but no history of congestive heart failure or vascular diseases like myocardial infarction or peripheral artery disease.
"""

results = pipeline.process_case(test_case)
print("\n=== Final Assessment ===")
print(results['final_output'])
```
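For a case like the one above, an atrial-fibrillation stroke-risk calculator such as CHA₂DS₂-VASc is the kind of evidence-based tool the agent can select. As a standalone illustration of such a calculator (our own sketch, not the repository's implementation; the function and argument names are ours):

```python
def cha2ds2_vasc(age, female, chf, hypertension, stroke_tia, vascular, diabetes):
    """Standard CHA2DS2-VASc point values for atrial-fibrillation stroke risk."""
    score = 0
    score += 2 if age >= 75 else (1 if age >= 65 else 0)  # age bands
    score += 1 if female else 0       # sex category
    score += 1 if chf else 0          # congestive heart failure
    score += 1 if hypertension else 0
    score += 2 if stroke_tia else 0   # prior stroke / TIA
    score += 1 if vascular else 0     # MI, PAD, or aortic plaque
    score += 1 if diabetes else 0
    return score

# The 54-year-old female above: hypertension, diabetes, and a prior stroke.
print(cha2ds2_vasc(age=54, female=True, chf=False, hypertension=True,
                   stroke_tia=True, vascular=False, diabetes=True))  # → 5
```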
If your downstream application involves sensitive data, you can use the RiskAgent-1B/3B/8B/70B models for local inference.
The trained models can be found at:
| Model | Model size | Base Model |
|---|---|---|
| RiskAgent-1B | 1B | Llama-3.2-1B-Instruct |
| RiskAgent-3B | 3B | Llama-3.2-3B-Instruct |
| RiskAgent-8B | 8B | Meta-Llama-3-8B-Instruct |
| RiskAgent-70B (coming soon!) | 70B | Meta-Llama-3-70B-Instruct |
Note: before using our models, please ensure you have obtained the Llama license and access rights to the corresponding Llama models.
```python
from riskagent_pipeline import RiskAgentPipeline

pipeline = RiskAgentPipeline(
    model_type="llama3",
    model_path="LOCAL_PATH/RiskAgent-8B",
    device_map="cuda:0",
    verbose=True
)

test_case = """
A 54-year-old female patient with a history of hypertension and diabetes presents to the clinic complaining of palpitations and occasional light-headedness. Her medical record shows a previous stroke but no history of congestive heart failure or vascular diseases like myocardial infarction or peripheral artery disease.
"""

results = pipeline.process_case(test_case)
print("\n=== Final Assessment ===")
print(results['final_output'])
```
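Whichever backend you choose, `process_case` returns a dict with a `final_output` entry, so processing several patients is a short loop. A minimal sketch (the helper name `run_cases` is ours, not part of the package):

```python
def run_cases(pipeline, cases):
    """Run pipeline.process_case over a list of patient case descriptions,
    collecting the final report text for each one."""
    return [pipeline.process_case(case)["final_output"] for case in cases]

# usage: reports = run_cases(pipeline, [case_1, case_2])
```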
We provide instructions for reproducing the models and results reported in our paper.
The MedRisk benchmark comes in two versions (also available on Hugging Face):
- MedRisk-Quantity: `data/MedRisk-Quantity.xlsx`
- MedRisk-Qualitative: `data/MedRisk-Qualitative.xlsx`
Each instance in the dataset contains the following fields:
- `input_id`: unique ID for each instance.
- `cal_id`: the tool ID for this question.
- `question`: the question stem.
- `option_a`, `option_b`, `option_c`, `option_d`: the options for the question.
- `correct_answer`: the correct answer for the question.
- `split`: the split of the dataset, either `train`, `test`, or `val`.
- `relevant_tools`: the full available tool list, ordered by relevance to the question.
- `inputs`: the input parameters for the tool calculation (human-readable format).
- `inputs_raw`: the input parameters for the tool calculation (raw format).
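Given those fields, loading one benchmark file and selecting a split is straightforward with pandas. A minimal sketch assuming only the column names listed above (reading `.xlsx` with `pd.read_excel` requires the `openpyxl` engine to be installed):

```python
import pandas as pd

def get_split(df, split):
    """Keep only the rows of the requested split ('train', 'test', or 'val')."""
    return df[df["split"] == split].reset_index(drop=True)

# e.g.:
# df = pd.read_excel("data/MedRisk-Quantity.xlsx")
# test_df = get_split(df, "test")
```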
We also provide the training data in instruction-following format at `data/fine_tune/ft_data.zip`.
The `evaluate_baseline.py` script provides evaluation functions for OpenAI models and LLaMA-based models.
Run evaluation with LLaMA-based models:
```shell
python evaluate_baseline.py \
    --model_type llama3 \
    --model_path meta-llama/Meta-Llama-3-8B \
    --split test \
    --output_file llama3_pred_quantity.xlsx \
    --device_map "cuda:0" \
    --data_path data/MedRisk-Quantity.xlsx
```
- `model_path`: a Hugging Face model card or your local model path.
- `device_map`: `"auto"`, `"cuda:0"`, `"cuda:1"`, etc. Note: please run on a single GPU to avoid parallelism errors, i.e., `device_map="cuda:0"`.
- `model_type`: one of `llama2`, `llama3`, `gpt`.
- `data_path`: either `data/MedRisk-Quantity.xlsx` or `data/MedRisk-Qualitative.xlsx`.
Run evaluation with OpenAI models:

```shell
python evaluate_baseline.py \
    --model_type gpt \
    --api_key YOUR_API_KEY \
    --model_card gpt-4o \
    --split test \
    --output_file gpt4o_pred_quantity.xlsx \
    --data_path data/MedRisk-Quantity.xlsx
```
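The prediction files written via `--output_file` can then be scored against the benchmark's `correct_answer` field. A minimal accuracy sketch; the prediction column name `pred` is our assumption, so adjust it to the script's actual output schema:

```python
import pandas as pd

def accuracy(df, pred_col="pred", gold_col="correct_answer"):
    """Fraction of rows whose predicted option letter matches the gold answer
    (case-insensitive comparison)."""
    matches = df[pred_col].str.lower() == df[gold_col].str.lower()
    return matches.mean()

# e.g. print(accuracy(pd.read_excel("llama3_pred_quantity.xlsx")))
```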
We provide inference with both OpenAI models and open-source models (e.g., LLaMA) for our risk agent reasoning framework.
Run evaluation with LLaMA-based models on the MedRisk benchmark:
```shell
python evaluate_riskagent.py \
    --model_type llama3 \
    --model_path meta-llama/Meta-Llama-3-8B \
    --data_path data/MedRisk-Quantity.xlsx \
    --output_dir ./riskagent_llama3_quantity \
    --split test \
    --tool_lib_path data/tool_library.xlsx \
    --device_map "cuda:0"
```
Run evaluation with OpenAI models:

```shell
python evaluate_riskagent.py \
    --model_type openai \
    --deployment gpt-4o \
    --api_key YOUR_API_KEY \
    --data_path data/MedRisk-Quantity.xlsx \
    --output_dir ./riskagent_gpt4o_quantity \
    --split test
```
Run evaluation with OpenAI models via Azure:

```shell
python evaluate_riskagent.py \
    --model_type azure \
    --deployment gpt-4o \
    --api_key YOUR_API_KEY \
    --data_path data/MedRisk-Quantity.xlsx \
    --api_base YOUR_AZURE_ENDPOINT \
    --output_dir ./riskagent_gpt4o_quantity \
    --split test
```
If our repository is helpful to your work, please consider citing 📑 our paper. Thank you sincerely!
```bibtex
@article{liu2025riskagent,
  title={RiskAgent: Autonomous Medical AI Copilot for Generalist Risk Prediction},
  author={Liu, Fenglin and Wu, Jinge and Zhou, Hongjian and Gu, Xiao and Molaei, Soheila and Thakur, Anshul and Clifton, Lei and Wu, Honghan and Clifton, David A},
  journal={arXiv preprint arXiv:2503.03802},
  year={2025}
}
```
The Llama Family Models: Open and Efficient Foundation Language Models
LLaMA-Factory: Unified Efficient Fine-Tuning of 100+ Language Models