
SakanaAI/treequest


TreeQuest


A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.

Quick Start

import random
from pathlib import Path

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str

# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None: # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"

    score = random.random()  # A score for the new state; it should be normalized to the [0, 1] range.
    return new_state, score

# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")

# 5. Visualize the search tree.
output_file_basename = Path("ab_mcts_a_search_tree")
tq.render(search_tree, output_file_basename, format="html")  # Generates `ab_mcts_a_search_tree.html`

Alternatively, you can use an ask–tell interface with batched AB-MCTS sampling steps:

import random
import treequest as tq

State = str

def generate(parent_state: State | None) -> tuple[State, float]:
    ...

generate_fns = {"Action A": generate}
actions = list(generate_fns.keys())

# We use batch_size=5 here
batch_size = 5

# Run AB-MCTS sampling steps with 5 worker processes in parallel
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

total_budget = 50
num_steps = total_budget // batch_size
for _ in range(num_steps):
    # ask_batch returns the updated tree and a list of `Trial` objects, each with action, parent_state, and trial_id attributes
    search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)

    for trial in trials:
        result = generate_fns[trial.action](trial.parent_state)
        # Call tell method with trial_id to update search_tree
        search_tree = algo.tell(search_tree, trial.trial_id, result)

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
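In the loop above, num_steps = total_budget // batch_size silently drops any remainder when the budget is not a multiple of the batch size (for example, 52 // 5 gives 10 steps, i.e. only 50 generations). A minimal sketch, assuming you want to spend the whole budget, is to shrink the final batch. split_budget is a hypothetical helper, not part of TreeQuest:

```python
def split_budget(total_budget: int, batch_size: int) -> list[int]:
    """Split a generation budget into batch sizes; the last batch absorbs the remainder."""
    if total_budget < 0 or batch_size <= 0:
        raise ValueError("total_budget must be >= 0 and batch_size must be > 0")
    full_batches, remainder = divmod(total_budget, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        sizes.append(remainder)  # A smaller final batch instead of dropping generations.
    return sizes


print(split_budget(52, 5))  # [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2]
```

Each size can then be passed in turn as the batch_size argument of ask_batch.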

Each step call can be slow, particularly for AB-MCTS-M; if you encounter slow execution, prefer ask_batch over step. Note that a large batch_size can skew the shape of the search tree (i.e., the tree may become too wide), so avoid overly large values; see PROFILING.md for example trees. We recommend batch_size <= 5 as a starting point.
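In the ask–tell example, the trials of a batch are evaluated one after another. Because LLM calls are usually I/O-bound, one option is to run the generation functions in a thread pool and then tell the results back; this relies on tell being order-independent and keyed by trial_id, as described under Batch Semantics and Concurrency. The sketch below is an assumption-laden illustration: FakeTrial is a hypothetical stand-in exposing only the action, parent_state, and trial_id attributes mentioned above (so it runs without TreeQuest), and evaluate_concurrently is not a TreeQuest function.

```python
import concurrent.futures
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class FakeTrial:
    """Hypothetical stand-in with the attributes TreeQuest's Trial is documented to have."""
    trial_id: str
    action: str
    parent_state: Any


def evaluate_concurrently(
    trials: list[FakeTrial],
    generate_fns: dict[str, Callable[[Any], tuple[Any, float]]],
    max_workers: int = 5,
) -> dict[Any, tuple[Any, float]]:
    """Run each trial's generate function in a thread pool; map trial_id -> (state, score)."""
    results: dict[Any, tuple[Any, float]] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_id = {
            pool.submit(generate_fns[t.action], t.parent_state): t.trial_id for t in trials
        }
        # Futures are yielded in completion order, not submission order.
        for future in concurrent.futures.as_completed(future_to_id):
            results[future_to_id[future]] = future.result()
    return results


def toy_generate(parent_state):
    return f"State after {parent_state}", 0.5


trials = [FakeTrial(f"trial-{i}", "Action A", None) for i in range(3)]
results = evaluate_concurrently(trials, {"Action A": toy_generate})
# With TreeQuest, each result would then be told back (order does not matter):
# for trial_id, result in results.items():
#     search_tree = algo.tell(search_tree, trial_id, result)
print(sorted(results))  # ['trial-0', 'trial-1', 'trial-2']
```

Threads parallelize your own generation calls; this is separate from max_process_workers, which the example above uses to run the algorithm's sampling steps in parallel processes.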

Features

  • Easy-to-use API with customizable node generation and node scoring logic.
  • AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS support (see our paper for algorithm details).
  • Checkpointing and resuming searches.

Installation

First, install uv. Then you can install TreeQuest with the following command:

uv add "treequest"

Alternatively, you can use pip to install TreeQuest:

pip install "treequest"

There are optional dependencies for ABMCTS-M and visualization features. You can install them with:

uv add "treequest[abmcts-m]"  # For ABMCTS-M
uv add "treequest[vis]"  # For visualization features
uv add "treequest[all]"  # For all optional features

Usage

Using an LLM as a Node Generator

You can use any object as a node state. You only need to define a generation function that takes the parent state as an argument and returns a (state, score) tuple:

import dataclasses

import treequest as tq

@dataclasses.dataclass
class State:
    llm_answer: str
    score: float

def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)

    return state, state.score

def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...

def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})
    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")
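The score returned alongside each state should lie in [0, 1]. For tasks with automatic checks (for example, unit tests for generated code), one natural choice is the fraction of checks an answer passes. A minimal sketch; score_answer and the example checks are hypothetical and only illustrate the idea:

```python
from typing import Callable


def score_answer(answer: str, checks: list[Callable[[str], bool]]) -> float:
    """Return the fraction of checks the answer passes, which always lies in [0, 1]."""
    if not checks:
        return 0.0
    passed = 0
    for check in checks:
        try:
            passed += bool(check(answer))
        except Exception:
            # A crashing check counts as a failure rather than aborting the search.
            pass
    return passed / len(checks)


checks = [
    lambda a: "42" in a,
    lambda a: len(a) < 100,
    lambda a: a.strip().endswith("."),
]
print(score_answer("The answer is 42.", checks))  # 1.0
```

The result can be stored in State.score and returned from generate.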

Using Multiple LLMs (and Beyond)

TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:

from functools import partial

import treequest as tq

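# llm_name is keyword-only so that partial(generate, llm_name=...) leaves
# parent_state as the single positional argument of the resulting callable.
def generate(parent_state=None, *, llm_name: str):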
    """
    Call LLM API using litellm, vllm, etc., to generate a new node
    """
    ...
    return new_state, new_score

llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create dict of different actions backed by different LLMs.
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)

The variation is not limited to LLM types; you can use different prompts, actions, scoring logic, etc. in generate_fns.

Batch Semantics and Concurrency

  • Algorithms are stateless objects; the evolving tree/search state is returned from init_tree, step, ask, and tell.
  • ask_batch(state, batch_size, actions) returns the updated state together with exactly batch_size Trial objects to expand next.
    • Non-queue algorithms (e.g., ABMCTSM, ABMCTSA, MultiArmedBanditUCB) return exactly batch_size Trials.
    • Queue-based algorithms (e.g., StandardMCTS, BestFirstSearchAlgo, TreeOfThoughtsBFS) precompute a set of parent/action pairs and duplicate them if needed to fill batch_size.
  • tell(state, trial_id, (new_state, score)) reflects the result for the corresponding Trial.
    • Order-independent: you can call tell in any order; reflection is tied to trial_id.
    • Idempotent: calling tell twice on the same trial_id does not add extra nodes.
    • For queue-based algorithms, Trials told beyond the maximum number of children a parent node can have (e.g., (# actions) * samples_per_action for StandardMCTS) become INVALID and are not reflected.
  • Scores are expected to be normalized to the [0, 1] range.
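Many raw signals are not naturally in [0, 1]: a pass count out of a known total, or an unbounded reward-model logit. A minimal sketch of common ways to map them into range; the helper names and the choice of transform are assumptions, not TreeQuest functionality:

```python
import math


def clamp01(x: float) -> float:
    """Clip a value that should already be in [0, 1] but may overshoot."""
    return min(1.0, max(0.0, x))


def min_max(x: float, lo: float, hi: float) -> float:
    """Linearly rescale x from a known range [lo, hi] to [0, 1], clipping outliers."""
    if hi <= lo:
        raise ValueError("hi must be greater than lo")
    return clamp01((x - lo) / (hi - lo))


def sigmoid(x: float, temperature: float = 1.0) -> float:
    """Squash an unbounded score (e.g., a logit) into [0, 1], preserving order."""
    z = x / temperature
    # Branch on the sign so math.exp never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


print(min_max(7, lo=0, hi=10))  # 0.7
print(sigmoid(0.0))  # 0.5
```

min_max requires known bounds; sigmoid handles unbounded inputs but compresses differences between very large or very small scores, which the temperature parameter can partly offset.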

Algorithms

ABMCTS-A: ABMCTS with Node Aggregation

ABMCTS-A uses node aggregation for adaptive branching:

import treequest as tq

# Instantiate the ABMCTS-A algorithm.
ab_mcts_a = tq.ABMCTSA()

search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)

ABMCTS-M: ABMCTS with Mixed Models

ABMCTS-M leverages PyMC's mixed modeling capabilities:

import treequest as tq

# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()

search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)

NOTE: To run AB-MCTS-M, you need to install extra dependencies with the treequest[abmcts-m] or treequest[all] option.

Visualization

TreeQuest provides visualization utilities to render the search tree. You can visualize the search tree using the tq.render function as shown in the Quick Start section. Visualization requires optional dependencies, which you can install with either uv add "treequest[vis]" or uv add "treequest[all]".

import base64
import dataclasses
from io import BytesIO
from pathlib import Path

import treequest as tq
from PIL import Image

@dataclasses.dataclass
class State:
    text: str
    image: Image.Image

def state_formatter_html(state: State) -> str:
    """Formats the state for HTML visualization."""
    buffer = BytesIO()
    state.image.save(buffer, format="WEBP")
    img_str = base64.b64encode(buffer.getvalue()).decode()
    # HTML representation is allowed (ensure safety when using untrusted content)
    return f'<p>{state.text}</p><br/><img src="data:image/webp;base64,{img_str}" width=100% />'

generate_fns = {
    "action_1": ...  # Define your node generation functions (actions) here
}

algo = tq.ABMCTSA()
search_tree = algo.init_tree()
output_dir = Path("progress")
output_dir.mkdir(exist_ok=True)
for step in range(1, 11):
    search_tree = algo.step(search_tree, generate_fns)
    # We can check the progress by rendering the tree every 5 steps.
    if step % 5 != 0:
        continue
    tq.render(
        search_tree,
        output_basename=output_dir / f"search_tree_step{step}",
        format="pdf",  # For "pdf", "svg", "png", "jpg" and "jpeg" formats, using graphviz
        state_formatter=lambda state: state.text,  # For non-HTML formats, use text only
    )  # Generates `progress/search_tree_step5.pdf`, `progress/search_tree_step10.pdf`
tq.render(
    search_tree,
    output_basename="search_tree",
    format="html",  # For HTML format, use HTML with d3.js visualization
    state_formatter=state_formatter_html,  # Use HTML formatter
)  # Generates `search_tree.html`

IMPORTANT: When using the HTML format, handle the generated HTML file securely, especially if the state formatter emits raw HTML content. Avoid opening untrusted HTML files in your browser: for example, cross-site scripting (XSS) attacks can occur if a state contains malicious HTML/JavaScript code.
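One way to reduce this risk when states contain untrusted text (such as raw LLM output) is to escape it with the standard library's html.escape before embedding it in the markup returned by the state formatter. A minimal sketch; this State has only a text field for brevity, and safe_state_formatter_html is a hypothetical name:

```python
import dataclasses
import html


@dataclasses.dataclass
class State:
    text: str


def safe_state_formatter_html(state: State) -> str:
    """Escape untrusted text so tags and scripts render as literal text."""
    return f"<p>{html.escape(state.text)}</p>"


print(safe_state_formatter_html(State(text="<script>alert(1)</script>")))
# <p>&lt;script&gt;alert(1)&lt;/script&gt;</p>
```

It can be passed as state_formatter=safe_state_formatter_html to tq.render with format="html". Escaping only protects the text you pass through it; any markup you build yourself (like the img tag in the example above) should not contain unescaped untrusted values.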

Requirements

  • Python 3.11+

Contributing

Contributions are welcome! Please see CONTRIBUTING.md for development tips.

Citation

@inproceedings{inoue2025wider,
  title={Wider or Deeper? Scaling {LLM} Inference-Time Compute with Adaptive Branching Tree Search},
  author={Yuichi Inoue and Kou Misaki and Yuki Imajuku and So Kuroki and Taishi Nakamura and Takuya Akiba},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=jAsr5GHt3P}
}

License

Apache 2.0

Third-party notices

TreeQuest bundles a few assets whose original authors retain copyright:

  • D3.js v7 by Mike Bostock and contributors, distributed under the ISC License. The full license text is included in src/treequest/vis/assets/d3.LICENSE.txt alongside the d3.v7.min.js binary.
  • Colormap samples extracted from matplotlib and seaborn. The data and license references are embedded in src/treequest/vis/assets/colormaps.json.

These notices must be preserved in any redistribution of TreeQuest, including any compiled artifacts.