TreeQuest



A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.
Quick Start
```python
import random

import treequest as tq

State = str


def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None:
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"
    score = random.random()
    return new_state, score


algo = tq.ABMCTSA()
search_tree = algo.init_tree()
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")
```
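The next example is a library-free toy that only mimics the shape of the init_tree / step / top_k loop, so you can see how a search consumes the generation function. Each step expands one node by calling `generate`, and results are ranked by score. The expansion rule here is a placeholder assumption and is not AB-MCTS: it picks a uniformly random node and a random action. `ToyTree`, `toy_step`, and `toy_top_k` are hypothetical names, not TreeQuest APIs.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, Optional

State = str
GenerateFn = Callable[[Optional[State]], tuple[State, float]]


@dataclass
class Node:
    state: Optional[State]  # None only for the root
    score: float            # the root has no real score, so it uses -inf
    children: list["Node"] = field(default_factory=list)


@dataclass
class ToyTree:
    root: Node = field(default_factory=lambda: Node(None, float("-inf")))

    def all_nodes(self) -> list[Node]:
        # Depth-first traversal over every node, root included.
        stack, out = [self.root], []
        while stack:
            node = stack.pop()
            out.append(node)
            stack.extend(node.children)
        return out


def toy_step(tree: ToyTree, generate_fns: dict[str, GenerateFn], rng: random.Random) -> ToyTree:
    # Placeholder selection rule (NOT AB-MCTS): expand a random node with a random action.
    parent = rng.choice(tree.all_nodes())
    action = rng.choice(sorted(generate_fns))
    new_state, score = generate_fns[action](parent.state)
    parent.children.append(Node(new_state, score))
    return tree


def toy_top_k(tree: ToyTree, k: int) -> list[tuple[State, float]]:
    # Rank every non-root node by score, highest first.
    nodes = [n for n in tree.all_nodes() if n is not tree.root]
    nodes.sort(key=lambda n: n.score, reverse=True)
    return [(n.state, n.score) for n in nodes[:k]]


def generate(parent_state: Optional[State]) -> tuple[State, float]:
    # Same contract as the Quick Start: parent_state is None when expanding the root.
    new_state = "Initial state" if parent_state is None else f"State after {parent_state}"
    return new_state, random.random()


rng = random.Random(0)
tree = ToyTree()
for _ in range(10):
    tree = toy_step(tree, {"Action A": generate}, rng)

best_state, best_score = toy_top_k(tree, k=1)[0]
print(f"Best state: {best_state}, Score: {best_score}")
```

The loop never inspects states, only scores. This is why any object can serve as a state, as long as `generate` returns it together with a numeric score.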
Features
- Easy-to-use API with customizable node generation and node scoring logic.
- AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS support (see our paper for algorithm details).
- Checkpointing and resuming searches.
Installation
uv
First, install uv. Then you can install TreeQuest with the following command:
```shell
uv add "treequest[abmcts-m]"
```
pip
Alternatively, you can use pip to install TreeQuest:
```shell
pip install "treequest[abmcts-m]"
```
Usage
Using an LLM as a Node Generator
You can use any object as a node state. You only need to define a generating function that takes the parent state as its argument (None when the root is being expanded) and returns a (state, score) tuple:
```python
import dataclasses

import treequest as tq


@dataclasses.dataclass
class State:
    llm_answer: str
    score: float


def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)
    return state, state.score


def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...


def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")
```
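In the example above, initial_generation and refine_answer are left as stubs. The sketch below exercises the same generate contract without any LLM so you can check it on its own. `score_answer` is a hypothetical evaluator: its scoring rule, the fraction of target keywords present, is arbitrary. The refinement step is a deterministic stand-in for a model call. It drives generate by hand along one chain, with no TreeQuest involved.

```python
import dataclasses
from typing import Optional

TARGET_KEYWORDS = ("tree", "search", "llm")  # hypothetical evaluation criteria


@dataclasses.dataclass
class State:
    llm_answer: str
    score: float


def score_answer(answer: str) -> float:
    # Hypothetical evaluator: fraction of target keywords present, in [0, 1].
    words = answer.lower().split()
    return sum(k in words for k in TARGET_KEYWORDS) / len(TARGET_KEYWORDS)


def initial_generation() -> State:
    answer = "draft"  # a real implementation would call an LLM here
    return State(answer, score_answer(answer))


def refine_answer(llm_answer: str, score: float) -> State:
    # A real implementation would prompt an LLM with the answer and its score;
    # this stand-in just appends the first keyword the answer is still missing.
    words = llm_answer.split()
    missing = [k for k in TARGET_KEYWORDS if k not in words]
    refined = llm_answer if not missing else f"{llm_answer} {missing[0]}"
    return State(refined, score_answer(refined))


def generate(parent_state: Optional[State]) -> tuple[State, float]:
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)
    return state, state.score


# Follow a single chain of refinements (no branching, no search).
state, score = generate(None)
chain = [(state.llm_answer, score)]
for _ in range(3):
    state, score = generate(state)
    chain.append((state.llm_answer, score))
print(chain)
```

Keeping the score inside the state, as State.score does here, lets the refinement step condition on how good its parent was.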
Using Multiple LLMs (and Beyond)
TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:
```python
from functools import partial

import treequest as tq


def generate(parent_state, llm_name: str):
    """
    Call LLM API using LiteLLM, vLLM, etc., to generate a new node.

    parent_state comes first so that partial(generate, llm_name=...) can
    still be called with the parent state as a positional argument.
    """
    ...
    return new_state, new_score


llm_names = ["o4-mini", "gemini-2.5-pro"]
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)
```
The variation is not limited to LLM types: the entries of generate_fns can differ in prompts, actions, scoring logic, and so on.
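When binding per-model arguments with `functools.partial`, parameter order matters. The sketch below assumes the search calls each generation function with the parent state as a single positional argument. Under that assumption, binding `llm_name` by keyword only works if the parent state is the first parameter. The `generate` function here is a hypothetical fake, with no LLM and no TreeQuest. It also varies the prompt style, one of the other axes mentioned above.

```python
from functools import partial
from typing import Optional


def generate(parent_state: Optional[str], llm_name: str, prompt_style: str = "plain"):
    # Hypothetical stand-in for an LLM call: records which model/prompt made the node.
    base = parent_state or "root"
    return f"{base} -> {llm_name}/{prompt_style}", 0.5


llm_names = ["o4-mini", "gemini-2.5-pro"]
prompt_styles = ["plain", "step-by-step"]

# One action per (model, prompt) pair; the dict keys are free-form action labels.
generate_fns = {
    f"{name}:{style}": partial(generate, llm_name=name, prompt_style=style)
    for name in llm_names
    for style in prompt_styles
}
state, score = generate_fns["o4-mini:step-by-step"](None)
print(state, score)


# Pitfall: with llm_name first, the positional parent state collides with the
# keyword-bound llm_name and Python raises a TypeError.
def generate_llm_first(llm_name: str, parent_state=None):
    return f"{parent_state} via {llm_name}", 0.5


caught = None
bad = partial(generate_llm_first, llm_name="o4-mini")
try:
    bad("some parent")
except TypeError as err:
    caught = err
    print(f"TypeError: {err}")
```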
Algorithms
AB-MCTS-A: AB-MCTS with Node Aggregation
AB-MCTS-A uses node aggregation for adaptive branching:
```python
import treequest as tq

ab_mcts_a = tq.ABMCTSA()
search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)
```
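The next sketch gives some intuition for the wider-versus-deeper decision at a single node. It is heavily simplified, library-free, and not TreeQuest's implementation. A GEN arm (create a brand-new child) competes with a CONT arm that pools the scores of all existing children. If CONT wins, a second draw picks which child to go deeper into. Both choices use Thompson sampling over Beta posteriors. Several details are illustrative assumptions: the Beta(1, 1) prior, treating a score in [0, 1] as a fractional success, the update rules, and the names `BetaArm`, `NodeStats`, `choose_action`, and `record`. See the paper for the actual AB-MCTS-A model.

```python
import random
from dataclasses import dataclass, field


@dataclass
class BetaArm:
    # Beta(alpha, beta) posterior over an arm's expected score; Beta(1, 1) prior assumed.
    alpha: float = 1.0
    beta: float = 1.0

    def update(self, score: float) -> None:
        # Heuristic: treat a score in [0, 1] as a fractional success.
        self.alpha += score
        self.beta += 1.0 - score

    def sample(self, rng: random.Random) -> float:
        return rng.betavariate(self.alpha, self.beta)


@dataclass
class NodeStats:
    gen: BetaArm = field(default_factory=BetaArm)   # scores of freshly generated children
    cont: BetaArm = field(default_factory=BetaArm)  # pooled scores over all children
    children: list[BetaArm] = field(default_factory=list)


def choose_action(stats: NodeStats, rng: random.Random) -> tuple[str, int]:
    """Return ("gen", -1) to go wider, or ("child", i) to go deeper into child i."""
    if not stats.children or stats.gen.sample(rng) >= stats.cont.sample(rng):
        return ("gen", -1)
    draws = [arm.sample(rng) for arm in stats.children]
    return ("child", max(range(len(draws)), key=draws.__getitem__))


def record(stats: NodeStats, action: tuple[str, int], score: float) -> None:
    kind, idx = action
    if kind == "gen":
        stats.gen.update(score)
        stats.children.append(BetaArm())  # the new child starts from the prior
        stats.children[-1].update(score)
    else:
        stats.children[idx].update(score)
    stats.cont.update(score)  # every child score feeds the aggregated CONT arm


# Hypothetical environment: new children score 0.2, refining an existing child scores 0.9.
rng = random.Random(0)
stats = NodeStats()
counts = {"gen": 0, "child": 0}
for _ in range(200):
    action = choose_action(stats, rng)
    counts[action[0]] += 1
    record(stats, action, 0.2 if action[0] == "gen" else 0.9)
print(counts)
```

In this toy environment going deeper pays off, so the sampler should shift most of its budget toward existing children once the CONT posterior rises above GEN. With the score values reversed it would keep branching wider instead.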
AB-MCTS-M: AB-MCTS with Mixed Models
AB-MCTS-M uses Bayesian mixed-effects models, fitted with PyMC, to guide adaptive branching:
```python
import treequest as tq

ab_mcts_m = tq.ABMCTSM()
search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)
```
NOTE: To run AB-MCTS-M, you need to install extra dependencies with the treequest[abmcts-m] option.
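The following sketch gives a rough, dependency-free intuition for why a mixed (random-intercept) model helps with scores. It is not TreeQuest's model. It computes the textbook partial-pooling estimate of each child's mean score. Children with few observations are shrunk toward a shared mean, while well-sampled children stay close to their own average. A brand-new child (going wider) has no data, so it simply gets the prior. This sketch assumes fixed variances: `SIGMA2` for within-child noise and `TAU2` for between-child spread. A full Bayesian fit, as done with PyMC, would infer them instead. `partial_pooling` and `thompson_pick` are hypothetical names.

```python
import random
from statistics import fmean

SIGMA2 = 0.04  # assumed within-child score variance
TAU2 = 0.01    # assumed between-child variance of mean scores


def partial_pooling(groups: dict[str, list[float]], sigma2: float, tau2: float):
    """Random-intercept posterior for each child, with known variances (an assumption)."""
    # Precision-weighted grand mean: the GLS estimate of the shared intercept.
    weights = {g: 1.0 / (tau2 + sigma2 / len(ys)) for g, ys in groups.items()}
    mu = sum(w * fmean(groups[g]) for g, w in weights.items()) / sum(weights.values())
    posterior = {}
    for g, ys in groups.items():
        n = len(ys)
        shrink = n * tau2 / (n * tau2 + sigma2)  # weight on the child's own mean
        mean = mu + shrink * (fmean(ys) - mu)
        var = 1.0 / (1.0 / tau2 + n / sigma2)    # posterior variance given mu
        posterior[g] = (mean, var)
    return mu, posterior


def thompson_pick(candidates: dict[str, tuple[float, float]], rng: random.Random) -> str:
    # Draw one sample per candidate from N(mean, var) and keep the best.
    draws = {g: rng.gauss(m, v ** 0.5) for g, (m, v) in candidates.items()}
    return max(draws, key=draws.get)


scores = {
    "child_a": [0.9],                                            # one lucky high score
    "child_b": [0.6, 0.65, 0.55, 0.6, 0.62, 0.58, 0.6, 0.61],    # many consistent scores
    "child_c": [0.3, 0.35],
}
mu, post = partial_pooling(scores, SIGMA2, TAU2)
candidates = {**post, "GEN": (mu, TAU2)}  # a new child has only the prior
pick = thompson_pick(candidates, random.Random(0))
print(round(mu, 3), {g: round(m, 3) for g, (m, _) in post.items()}, pick)
```

The single 0.9 of child_a is pulled strongly toward the shared mean, while child_b, with eight observations, barely moves. This sharing of information across siblings is what a mixed model adds over scoring each child in isolation.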
Requirements
Contributing
Contributions are welcome! Please see CONTRIBUTING.md for development tips.
Citation
```bibtex
@article{inoue2025wider,
  title={Wider or Deeper? Scaling LLM Inference-Time Compute with Adaptive Branching Tree Search},
  author={Inoue, Yuichi and Misaki, Kou and Imajuku, Yuki and Kuroki, So and Nakamura, Taishi and Akiba, Takuya},
  journal={arXiv preprint arXiv:2503.04412},
  year={2025}
}
```
License
Apache 2.0