
Setup model routing config and plan routing to o1 #6189

Open · wants to merge 25 commits into main
20 changes: 20 additions & 0 deletions config.template.toml
@@ -234,6 +234,9 @@ codeact_enable_jupyter = true
# List of microagents to disable
#disabled_microagents = []

# Whether to enable plan routing to reasoning models
#enable_plan_routing = false

[agent.RepoExplorerAgent]
# Example: use a cheaper model for RepoExplorerAgent to reduce cost, especially
# useful when an agent doesn't demand high quality but uses a lot of tokens
@@ -284,6 +287,23 @@ llm_config = 'gpt3'
# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
#security_analyzer = ""

################################ Model Routing ###############################
# Configuration for model routing features
##############################################################################
[model_routing]

# The reasoning model to use for plan generation
reasoning_llm_config_name = 'reasoning_model'
judge_llm_config_name = 'judge_model'

[llm.judge_model]
model = "gpt-4o"
api_key = ""

[llm.reasoning_model]
model = "o1"
api_key = ""
Collaborator:
I think we might have here a little too much configurability 😅

It's perfectly fine if we reserve some names for our own features. So the names (of the llm configs) don't need to be configurable themselves, they mean what we say they mean.

We did that with draft_llm.

We can reserve the names reasoning_model and reasoning_judge_model for the reasoning model routing feature, and use them freely as necessary in the code. So we don't need these lines:

reasoning_llm_config_name = 'reasoning_model'
judge_llm_config_name = 'judge_model'

That will also simplify the code below, starting from reading these configs in llm_config.py: I think we don't need to do anything there? They'll be read like any other named configs. And it will save us quite a bit of code complexity elsewhere too?
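
For illustration, a minimal sketch of what the reserved-name approach could look like at agent-setup time; get_llm_config is the existing AppConfig accessor that appears later in this diff, while the reserved names themselves are the proposal above, not code from this PR:

# Hypothetical: routing LLMs are plain named [llm.*] configs, fetched by
# reserved name, with no extra plumbing in llm_config.py or elsewhere.
reasoning_llm = LLM(config=config.get_llm_config('reasoning_model'))
judge_llm = LLM(config=config.get_llm_config('reasoning_judge_model'))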

Contributor (author):

> I think we might have here a little too much configurability 😅

Yeah, I'm feeling this too. While implementing this I was also thinking about how to support the Notdiamond router, where we can train a custom router on a set of selected LLMs, so the llm config names aren't fixed like the two above. But in this case a reserved name does make more sense; we probably don't need it to be configurable. I'll try to change that.

Collaborator (@enyst, Feb 4, 2025):

Yes, I hear you. I'm thinking about Notdiamond too, and about the litellm option: litellm has a routing feature that we've tried twice in the past to adopt, and it proved too much complexity given that it doesn't actually support one of the most important things we wanted from it (fallbacks/retries driven by the providers' retry-after headers). Maybe we'll look at it again (third time's the charm?) or not.

Anyway, there seem to be two ways we can take on the idea of future routing:

  • Ignore it. We have a full feature here; we implement it as necessary. It's nice enough, and we don't necessarily need all the building blocks we're guessing the most generalizable thing will require. Cross that bridge when we come to it. (We don't even know exactly what future routers will need, do we?)
  • Keep an abstract class for routing config, and not much more (see the sketch after this comment). Maybe the approach we took with condensers is a relevant example here: the configs for those do share an ABC, but the subclasses don't have the same attributes, and that's fine. Each is configured as it needs to be, maybe with nothing in its config (the NoOpCondenser), or maybe with a bunch of attributes (max_size, whatever) for specialized condensers that really need a config of their own. Again, cross that last bridge when we come to it.

OK, I started by saying there are two ways, but idk, maybe they're almost the same today. 😅

This routing feature, in this PR, does need to be enabled, so as long as it's enabled, it can do its thing IMHO.
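
A rough sketch of that second option, with all names hypothetical and modeled on the condenser configs mentioned above:

from pydantic import BaseModel, Field

# Hypothetical shared base, analogous to the condenser config ABC;
# each subclass carries only what its router actually needs.
class RoutingConfig(BaseModel):
    pass

class LLMBasedPlanRoutingConfig(RoutingConfig):
    judge_llm_config_name: str = Field(default='judge_model')

class NotDiamondRoutingConfig(RoutingConfig):
    # a trained router picks among an arbitrary set of named LLM configs
    candidate_llm_config_names: list[str] = Field(default_factory=list)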


#################################### Eval ####################################
# Configuration for the evaluation, please refer to the specific evaluation
# plugin for the available options
7 changes: 7 additions & 0 deletions evaluation/benchmarks/swe_bench/run_infer.py
@@ -1,4 +1,5 @@
import asyncio
import copy
import json
import os
import tempfile
@@ -33,6 +34,7 @@
SandboxConfig,
get_llm_config_arg,
get_parser,
load_from_toml,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
@@ -155,14 +157,19 @@ def get_config(
metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
)
)
config_copy = copy.deepcopy(config)
load_from_toml(config_copy)
agent_config = AgentConfig(
codeact_enable_jupyter=False,
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
condenser=metadata.condenser_config,
enable_prompt_extensions=False,
enable_plan_routing=config_copy.get_agent_config().enable_plan_routing,
)
config.set_agent_config(agent_config)
config.routing_llms = config_copy.routing_llms
config.model_routing = config_copy.model_routing
return config


63 changes: 53 additions & 10 deletions openhands/agenthub/codeact_agent/codeact_agent.py
@@ -8,7 +8,7 @@
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.config import AgentConfig, ModelRoutingConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.schema import ActionType
@@ -39,12 +39,14 @@
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
from openhands.router import BaseRouter, LLMBasedPlanRouter
from openhands.runtime.plugins import (
AgentSkillsRequirement,
JupyterRequirement,
PluginRequirement,
)
from openhands.utils.prompt import PromptManager
from openhands.utils.trajectory import format_trajectory


class CodeActAgent(Agent):
@@ -80,11 +82,14 @@ def __init__(
self,
llm: LLM,
config: AgentConfig,
model_routing_config: ModelRoutingConfig | None = None,
routing_llms: dict[str, LLM] | None = None,
) -> None:
"""Initializes a new instance of the CodeActAgent class.

Parameters:
- llm (LLM): The llm to be used by this agent
- routing_llms (dict[str, LLM]): The LLMs available for the router to select from
"""
super().__init__(llm, config)
self.pending_actions: deque[Action] = deque()
@@ -113,6 +118,18 @@ def __init__(
self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {self.condenser}')

self.router: BaseRouter | None = None

if config.enable_plan_routing:
assert model_routing_config is not None and routing_llms is not None
self.router = LLMBasedPlanRouter(
llm=self.llm,
routing_llms=routing_llms or dict(),
model_routing_config=model_routing_config,
)

self.active_llm: LLM | None = None # The LLM chosen by the router

def get_action_message(
self,
action: Action,
@@ -148,6 +165,9 @@ def get_action_message(
rather than being returned immediately. They will be processed later when all corresponding
tool call results are available.
"""
# Handle the case where self.active_llm is None
active_llm_ = self.active_llm or self.llm

# create a regular message from an event
if isinstance(
action,
@@ -213,7 +233,7 @@
elif isinstance(action, MessageAction):
role = 'user' if action.source == 'user' else 'assistant'
content = [TextContent(text=action.content or '')]
if self.llm.vision_is_active() and action.image_urls:
if active_llm_.vision_is_active() and action.image_urls:
content.append(ImageContent(image_urls=action.image_urls))
return [
Message(
@@ -264,8 +284,11 @@ def get_observation_message(
Raises:
ValueError: If the observation type is unknown
"""
# Handle the case where self.active_llm is None
active_llm_ = self.active_llm or self.llm

message: Message
max_message_chars = self.llm.config.max_message_chars
max_message_chars = active_llm_.config.max_message_chars
if isinstance(obs, CmdOutputObservation):
# if it doesn't have tool call metadata, it was triggered by a user action
if obs.tool_call_metadata is None:
@@ -383,13 +406,30 @@ def step(self, state: State) -> Action:
if latest_user_message and latest_user_message.content.strip() == '/exit':
return AgentFinishAction()

params: dict = {}

# check if model routing is needed
if self.router:
messages = self._get_messages(state)
formatted_trajectory = format_trajectory(messages)
self.active_llm = self.router.should_route_to(formatted_trajectory)

if self.active_llm != self.llm:
logger.warning(f'🧭 Routing to custom model: {self.active_llm}')
else:
self.active_llm = self.llm

params['tools'] = self.tools
if not self.active_llm.is_function_calling_active():
params['mock_function_calling'] = True

# prepare what we want to send to the LLM
# NOTE: this must be called here, after self.active_llm has been set
messages = self._get_messages(state)
params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
}
params['tools'] = self.tools
response = self.llm.completion(**params)
params['messages'] = self.active_llm.format_messages_for_llm(messages)

response = self.active_llm.completion(**params)

actions = codeact_function_calling.response_to_actions(response)
for action in actions:
self.pending_actions.append(action)
@@ -430,13 +470,16 @@ def _get_messages(self, state: State) -> list[Message]:
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')

# Handle the case where self.active_llm is None
active_llm_ = self.active_llm or self.llm

messages: list[Message] = [
Message(
role='system',
content=[
TextContent(
text=self.prompt_manager.get_system_message(),
cache_prompt=self.llm.is_caching_prompt_active(),
cache_prompt=active_llm_.is_caching_prompt_active(),
)
],
)
Expand Down Expand Up @@ -507,7 +550,7 @@ def _get_messages(self, state: State) -> list[Message]:

messages.append(msg)

if self.llm.is_caching_prompt_active():
if active_llm_.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
# following logic here:
# https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
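Condensed, the per-step control flow introduced by the step() hunk above is roughly the following (a paraphrase of the diff, not new code):

if self.router:
    trajectory = format_trajectory(self._get_messages(state))
    self.active_llm = self.router.should_route_to(trajectory)
else:
    self.active_llm = self.llm

params: dict = {'tools': self.tools}
if not self.active_llm.is_function_calling_active():
    params['mock_function_calling'] = True
params['messages'] = self.active_llm.format_messages_for_llm(self._get_messages(state))
response = self.active_llm.completion(**params)
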
2 changes: 1 addition & 1 deletion openhands/agenthub/dummy_agent/agent.py
@@ -46,7 +46,7 @@ class DummyAgent(Agent):
without making any LLM calls.
"""

def __init__(self, llm: LLM, config: AgentConfig):
def __init__(self, llm: LLM, config: AgentConfig, **kwargs):
super().__init__(llm, config)
self.steps: list[ActionObs] = [
{
1 change: 1 addition & 0 deletions openhands/controller/agent.py
@@ -32,6 +32,7 @@ def __init__(
self,
llm: LLM,
config: 'AgentConfig',
**kwargs,
):
self.llm = llm
self.config = config
2 changes: 2 additions & 0 deletions openhands/core/config/__init__.py
@@ -6,6 +6,7 @@
get_field_info,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.core.config.utils import (
@@ -27,6 +28,7 @@
'LLMConfig',
'SandboxConfig',
'SecurityConfig',
'ModelRoutingConfig',
'load_app_config',
'load_from_env',
'load_from_toml',
1 change: 1 addition & 0 deletions openhands/core/config/agent_config.py
@@ -31,3 +31,4 @@ class AgentConfig(BaseModel):
enable_prompt_extensions: bool = Field(default=True)
disabled_microagents: list[str] | None = Field(default=None)
condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
enable_plan_routing: bool = Field(default=False)
9 changes: 8 additions & 1 deletion openhands/core/config/app_config.py
@@ -10,6 +10,7 @@
model_defaults_to_dict,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig

@@ -20,6 +21,7 @@ class AppConfig(BaseModel):
Attributes:
llms: Dictionary mapping LLM names to their configurations.
The default configuration is stored under the 'llm' key.
routing_llms: Dictionary mapping routing LLM names to their configurations.
agents: Dictionary mapping agent names to their configurations.
The default configuration is stored under the 'agent' key.
default_agent: Name of the default agent to use.
@@ -48,10 +50,12 @@
"""

llms: dict[str, LLMConfig] = Field(default_factory=dict)
routing_llms: dict[str, LLMConfig] = Field(default_factory=dict)
agents: dict = Field(default_factory=dict)
default_agent: str = Field(default=OH_DEFAULT_AGENT)
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
runtime: str = Field(default='docker')
file_store: str = Field(default='local')
file_store_path: str = Field(default='/tmp/openhands_file_store')
@@ -94,7 +98,10 @@ def get_llm_config(self, name='llm') -> LLMConfig:
return self.llms['llm']

def set_llm_config(self, value: LLMConfig, name='llm') -> None:
self.llms[name] = value
if value.for_routing:
self.routing_llms[name] = value
else:
self.llms[name] = value

def get_agent_config(self, name='agent') -> AgentConfig:
"""'agent' is the name for default config (for backward compatibility prior to 0.8)."""
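The net effect of the set_llm_config change above, combined with the for_routing flag added in llm_config.py below, can be sketched as follows (a hedged example, not a test from this PR):

from openhands.core.config import AppConfig, LLMConfig

cfg = AppConfig()
# A config marked for_routing lands in routing_llms, not in llms.
cfg.set_llm_config(LLMConfig(model='o1', for_routing=True), 'reasoning_model')
assert 'reasoning_model' in cfg.routing_llms
assert 'reasoning_model' not in cfg.llms
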
1 change: 1 addition & 0 deletions openhands/core/config/llm_config.py
@@ -86,6 +86,7 @@ class LLMConfig(BaseModel):
custom_tokenizer: str | None = Field(default=None)
native_tool_calling: bool | None = Field(default=None)
reasoning_effort: str | None = Field(default='high')
for_routing: bool = Field(default=False)

model_config = {'extra': 'forbid'}

6 changes: 6 additions & 0 deletions openhands/core/config/model_routing_config.py
@@ -0,0 +1,6 @@
from pydantic import BaseModel, Field


class ModelRoutingConfig(BaseModel):
reasoning_llm_config_name: str = Field(default='reasoning_model')
judge_llm_config_name: str = Field(default='judge_model')
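
As the utils.py hunk below shows, a [model_routing] TOML table maps directly onto this model; a minimal usage sketch:

from openhands.core.config.model_routing_config import ModelRoutingConfig

# Mirrors what load_from_toml does with the [model_routing] table.
section = {'reasoning_llm_config_name': 'reasoning_model'}
mr = ModelRoutingConfig(**section)
assert mr.judge_llm_config_name == 'judge_model'  # pydantic default applies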
14 changes: 8 additions & 6 deletions openhands/core/config/utils.py
@@ -14,11 +14,9 @@
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.app_config import AppConfig
from openhands.core.config.config_utils import (
OH_DEFAULT_AGENT,
OH_MAX_ITERATIONS,
)
from openhands.core.config.config_utils import OH_DEFAULT_AGENT, OH_MAX_ITERATIONS
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.storage import get_file_store
@@ -164,7 +162,6 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
)

# Extract generic LLM fields, which are not nested LLM configs
generic_llm_fields = {}
for k, v in value.items():
@@ -195,13 +192,18 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):

custom_llm_config = LLMConfig(**merged_llm_dict)
cfg.set_llm_config(custom_llm_config, nested_key)

elif key is not None and key.lower() == 'security':
logger.openhands_logger.debug(
'Attempt to load security config from config toml'
)
security_config = SecurityConfig(**value)
cfg.security = security_config
elif key is not None and key.lower() == 'model_routing':
logger.openhands_logger.debug(
'Attempt to load model routing config from config toml'
)
model_routing_config = ModelRoutingConfig(**value)
cfg.model_routing = model_routing_config
elif not key.startswith('sandbox') and key.lower() != 'core':
logger.openhands_logger.warning(
f'Unknown key in {toml_file}: "{key}"'
13 changes: 10 additions & 3 deletions openhands/core/setup.py
@@ -6,9 +6,7 @@
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import (
AppConfig,
)
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
@@ -62,9 +60,18 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
routing_llms_config = config.routing_llms
model_routing_config = config.model_routing
routing_llms = {}
for config_name, routing_llm_config in routing_llms_config.items():
routing_llms[config_name] = LLM(
config=routing_llm_config,
)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,
model_routing_config=model_routing_config,
routing_llms=routing_llms,
)
if agent.prompt_manager:
microagents = runtime.get_microagents_from_selected_repo(None)
1 change: 1 addition & 0 deletions openhands/llm/llm.py
@@ -71,6 +71,7 @@
'gpt-4o-mini',
'gpt-4o',
'o1-2024-12-17',
'o1',
'o3-mini-2025-01-31',
'o3-mini',
]
4 changes: 4 additions & 0 deletions openhands/router/__init__.py
@@ -0,0 +1,4 @@
from openhands.router.base import BaseRouter
from openhands.router.plan.llm_based import LLMBasedPlanRouter

__all__ = ['BaseRouter', 'LLMBasedPlanRouter']
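
base.py and plan/llm_based.py are not shown in this diff; judging from the call sites in codeact_agent.py, the router interface is roughly the following (a sketch inferred from usage, not the actual file):

from abc import ABC, abstractmethod

from openhands.llm.llm import LLM

class BaseRouter(ABC):
    # Given the formatted trajectory so far, return the LLM to use
    # for the next completion (the main llm or one of the routing llms).
    @abstractmethod
    def should_route_to(self, prompt: str) -> LLM: ...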