Implement model routing support (#9738)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Ryan H. Tran 2025-09-08 16:19:34 +07:00 committed by GitHub
parent af0ab5a9f2
commit df9320f8ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 515 additions and 7 deletions

View File

@ -219,6 +219,14 @@ correct_num = 5
api_key = ""
model = "gpt-4o"
# Example routing LLM configuration for multimodal model routing
# Uncomment and configure to enable model routing with a secondary model
#[llm.secondary_model]
#model = "kimi-k2"
#api_key = ""
#for_routing = true
#max_input_tokens = 128000
#################################### Agent ###################################
# Configuration for agents (group name starts with 'agent')
@ -480,3 +488,14 @@ type = "noop"
# Run the runtime sandbox container in privileged mode for use with docker-in-docker
#privileged = false
#################################### Model Routing ############################
# Configuration for experimental model routing feature
# Enables intelligent switching between different LLM models for specific purposes
##############################################################################
[model_routing]
# Router to use for model selection
# Available options:
# - "noop_router" (default): No routing, always uses primary LLM
# - "multimodal_router": A router that switches between primary and secondary models, depending on whether the input is multimodal or not
#router_name = "noop_router"

View File

@ -28,6 +28,7 @@ from evaluation.utils.shared import (
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
@ -36,7 +37,11 @@ from openhands.core.config import (
get_llm_config_arg,
load_from_toml,
)
from openhands.core.config.utils import get_agent_config_arg
from openhands.core.config.utils import (
get_agent_config_arg,
get_llms_for_routing_config,
get_model_routing_config_arg,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
@ -57,6 +62,7 @@ AGENT_CLS_TO_INST_SUFFIX = {
def get_config(
instance: pd.Series,
metadata: EvalMetadata,
) -> OpenHandsConfig:
sandbox_config = get_default_sandbox_config_for_eval()
@ -66,13 +72,24 @@ def get_config(
sandbox_config=sandbox_config,
runtime='docker',
)
config.set_llm_config(metadata.llm_config)
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
)
)
model_routing_config = get_model_routing_config_arg()
model_routing_config.llms_for_routing = (
get_llms_for_routing_config()
) # Populate with LLMs for routing from config.toml file
if metadata.agent_config:
metadata.agent_config.model_routing = model_routing_config
config.set_agent_config(metadata.agent_config, metadata.agent_class)
else:
logger.info('Agent config not provided, using default settings')
agent_config = config.get_agent_config(metadata.agent_class)
agent_config.enable_prompt_extensions = False
agent_config.model_routing = model_routing_config
config_copy = copy.deepcopy(config)
load_from_toml(config_copy)
@ -145,7 +162,7 @@ def process_instance(
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(metadata)
config = get_config(instance, metadata)
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:

View File

@ -47,6 +47,8 @@ from openhands.core.config import (
get_agent_config_arg,
get_evaluation_parser,
get_llm_config_arg,
get_llms_for_routing_config,
get_model_routing_config_arg,
)
from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.core.config.utils import get_condenser_config_arg
@ -244,6 +246,11 @@ def get_config(
# get 'draft_editor' config if exists
config.set_llm_config(get_llm_config_arg('draft_editor'), 'draft_editor')
model_routing_config = get_model_routing_config_arg()
model_routing_config.llms_for_routing = (
get_llms_for_routing_config()
) # Populate with LLMs for routing from config.toml file
agent_config = AgentConfig(
enable_jupyter=False,
enable_browsing=RUN_WITH_BROWSING,
@ -251,8 +258,10 @@ def get_config(
enable_mcp=False,
condenser=metadata.condenser_config,
enable_prompt_extensions=False,
model_routing=model_routing_config,
)
config.set_agent_config(agent_config)
return config

View File

@ -92,6 +92,9 @@ class CodeActAgent(Agent):
self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
logger.debug(f'Using condenser: {type(self.condenser)}')
# Override with router if needed
self.llm = self.llm_registry.get_router(self.config)
@property
def prompt_manager(self) -> PromptManager:
if self._prompt_manager is None:

View File

@ -13,6 +13,7 @@ from openhands.core.config.config_utils import (
from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
@ -20,6 +21,8 @@ from openhands.core.config.utils import (
finalize_config,
get_agent_config_arg,
get_llm_config_arg,
get_llms_for_routing_config,
get_model_routing_config_arg,
load_from_env,
load_from_toml,
load_openhands_config,
@ -37,6 +40,7 @@ __all__ = [
'LLMConfig',
'SandboxConfig',
'SecurityConfig',
'ModelRoutingConfig',
'ExtendedConfig',
'load_openhands_config',
'load_from_env',
@ -50,4 +54,6 @@ __all__ = [
'get_evaluation_parser',
'parse_arguments',
'setup_config_from_args',
'get_model_routing_config_arg',
'get_llms_for_routing_config',
]

View File

@ -7,6 +7,7 @@ from openhands.core.config.condenser_config import (
ConversationWindowCondenserConfig,
)
from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.logger import openhands_logger as logger
from openhands.utils.import_utils import get_impl
@ -57,6 +58,8 @@ class AgentConfig(BaseModel):
# handled.
default_factory=lambda: ConversationWindowCondenserConfig()
)
model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
"""Model routing configuration settings."""
extended: ExtendedConfig = Field(default_factory=lambda: ExtendedConfig({}))
"""Extended configuration for the agent."""

View File

@ -46,6 +46,7 @@ class LLMConfig(BaseModel):
reasoning_effort: The effort to put into reasoning. This is a string that can be one of 'low', 'medium', 'high', or 'none'. Can apply to all reasoning models.
seed: The seed to use for the LLM.
safety_settings: Safety settings for models that support them (like Mistral AI and Gemini).
for_routing: Whether this LLM is used for routing. This is set to True for models used in conjunction with the main LLM in the model routing feature.
"""
model: str = Field(default='claude-sonnet-4-20250514')
@ -92,6 +93,7 @@ class LLMConfig(BaseModel):
default=None,
description='Safety settings for models that support them (like Mistral AI and Gemini)',
)
for_routing: bool = Field(default=False)
model_config = ConfigDict(extra='forbid')

View File

@ -0,0 +1,39 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from openhands.core.config.llm_config import LLMConfig
class ModelRoutingConfig(BaseModel):
"""Configuration for model routing.
Attributes:
router_name (str): The name of the router to use. Default is 'noop_router'.
llms_for_routing (dict[str, LLMConfig]): A dictionary mapping config names of LLMs for routing to their configurations.
"""
router_name: str = Field(default='noop_router')
llms_for_routing: dict[str, LLMConfig] = Field(default_factory=dict)
model_config = ConfigDict(extra='forbid')
@classmethod
def from_toml_section(cls, data: dict) -> dict[str, 'ModelRoutingConfig']:
"""
Create a mapping of ModelRoutingConfig instances from a toml dictionary representing the [model_routing] section.
The configuration is built from all keys in data.
Returns:
dict[str, ModelRoutingConfig]: A mapping where the key "model_routing" corresponds to the [model_routing] configuration
"""
# Initialize the result mapping
model_routing_mapping: dict[str, ModelRoutingConfig] = {}
# Try to create the configuration instance
try:
model_routing_mapping['model_routing'] = cls.model_validate(data)
except ValidationError as e:
raise ValueError(f'Invalid model routing configuration: {e}')
return model_routing_mapping

View File

@ -30,6 +30,7 @@ class OpenHandsConfig(BaseModel):
The default configuration is stored under the 'agent' key.
default_agent: Name of the default agent to use.
sandbox: Sandbox configuration settings.
security: Security configuration settings.
runtime: Runtime environment identifier.
file_store: Type of file store to use.
file_store_path: Path to the file store.

View File

@ -25,6 +25,7 @@ from openhands.core.config.extended_config import ExtendedConfig
from openhands.core.config.kubernetes_config import KubernetesConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
@ -225,6 +226,35 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
# Re-raise ValueError from SecurityConfig.from_toml_section
raise ValueError('Error in [security] section in config.toml')
if 'model_routing' in toml_config:
try:
model_routing_mapping = ModelRoutingConfig.from_toml_section(
toml_config['model_routing']
)
# We only use the base model routing config for now
if 'model_routing' in model_routing_mapping:
default_agent_config = cfg.get_agent_config()
default_agent_config.model_routing = model_routing_mapping[
'model_routing'
]
# Construct the llms_for_routing by filtering llms with for_routing = True
llms_for_routing_dict = {}
for llm_name, llm_config in cfg.llms.items():
if llm_config and llm_config.for_routing:
llms_for_routing_dict[llm_name] = llm_config
default_agent_config.model_routing.llms_for_routing = (
llms_for_routing_dict
)
logger.openhands_logger.debug(
'Default model routing configuration loaded from config toml and assigned to default agent'
)
except (TypeError, KeyError, ValidationError) as e:
logger.openhands_logger.warning(
f'Cannot parse [model_routing] config from toml, values have not been applied.\nError: {e}'
)
# Process sandbox section if present
if 'sandbox' in toml_config:
try:
@ -327,6 +357,7 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
'condenser',
'mcp',
'kubernetes',
'model_routing',
}
for key in toml_config:
if key.lower() not in known_sections:
@ -559,6 +590,41 @@ def get_llm_config_arg(
return None
def get_llms_for_routing_config(toml_file: str = 'config.toml') -> dict[str, LLMConfig]:
"""Get the LLMs that are configured for routing from the config file.
This function will return a dictionary of LLMConfig objects that are configured
for routing, i.e., those with `for_routing` set to True.
Args:
toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
Returns:
dict[str, LLMConfig]: A dictionary of LLMConfig objects for routing.
"""
llms_for_routing: dict[str, LLMConfig] = {}
try:
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
toml_config = toml.load(toml_contents)
except FileNotFoundError:
return llms_for_routing
except toml.TomlDecodeError as e:
logger.openhands_logger.error(
f'Cannot parse LLM configs from {toml_file}. Exception: {e}'
)
return llms_for_routing
llm_configs = LLMConfig.from_toml_section(toml_config.get('llm', {}))
if llm_configs:
for llm_name, llm_config in llm_configs.items():
if llm_config.for_routing:
llms_for_routing[llm_name] = llm_config
return llms_for_routing
def get_condenser_config_arg(
condenser_config_arg: str, toml_file: str = 'config.toml'
) -> CondenserConfig | None:
@ -671,6 +737,50 @@ def get_condenser_config_arg(
return None
def get_model_routing_config_arg(toml_file: str = 'config.toml') -> ModelRoutingConfig:
"""Get the model routing settings from the config file. We only support the default model routing config [model_routing].
Args:
toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
Returns:
ModelRoutingConfig: The ModelRoutingConfig object with the settings from the config file, or the object with default values if not found/error.
"""
logger.openhands_logger.debug(
f"Loading model routing config ['model_routing'] from {toml_file}"
)
default_cfg = ModelRoutingConfig()
# load the toml file
try:
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
toml_config = toml.load(toml_contents)
except FileNotFoundError as e:
logger.openhands_logger.error(f'Config file not found: {toml_file}. Error: {e}')
return default_cfg
except toml.TomlDecodeError as e:
logger.openhands_logger.error(
f'Cannot parse model routing group [model_routing] from {toml_file}. Exception: {e}'
)
return default_cfg
# Update the model routing config with the specified section
if 'model_routing' in toml_config:
try:
model_routing_data = toml_config['model_routing']
return ModelRoutingConfig(**model_routing_data)
except ValidationError as e:
logger.openhands_logger.error(
f'Invalid model routing configuration for [model_routing]: {e}'
)
return default_cfg
logger.openhands_logger.warning(
f'Model routing config section [model_routing] not found in {toml_file}'
)
return default_cfg
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = get_headless_parser()

View File

@ -108,10 +108,24 @@ class LLMRegistry:
def get_active_llm(self) -> LLM:
return self.active_agent_llm
def _set_active_llm(self, service_id) -> None:
if service_id not in self.service_to_llm:
raise ValueError(f'Unrecognized service ID: {service_id}')
self.active_agent_llm = self.service_to_llm[service_id]
def get_router(self, agent_config: AgentConfig) -> 'LLM':
"""
Get a router instance that inherits from LLM.
"""
# Import here to avoid circular imports
from openhands.llm.router import RouterLLM
router_name = agent_config.model_routing.router_name
if router_name == 'noop_router':
# Return the main LLM directly (no routing)
return self.get_llm_from_agent_config('agent', agent_config)
return RouterLLM.from_config(
agent_config=agent_config,
llm_registry=self,
retry_listener=self.retry_listner,
)
def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
self.subscriber = callback

View File

@ -0,0 +1,39 @@
# Model Routing Module
**⚠️ Experimental Feature**: This module is experimental and under active development.
## Overview
Model routing enables OpenHands to switch between different LLM models during a conversation. An example use case is routing between a primary (expensive, multimodal) model and a secondary (cheaper, text-only) model.
## Available Routers
- **`noop_router`** (default): No routing, always uses primary LLM
- **`multimodal_router`**: A router that switches based on:
- Routes to primary model for images or when secondary model's context limit is exceeded
- Uses secondary model for text-only requests within its context limit
## Configuration
Add to your `config.toml`:
```toml
# Main LLM (primary model)
[llm]
model = "claude-sonnet-4"
api_key = "your-api-key"
# Secondary model for routing
[llm.secondary_model]
model = "kimi-k2"
api_key = "your-api-key"
for_routing = true
# Enable routing
[model_routing]
router_name = "multimodal_router"
```
## Extending
Create custom routers by inheriting from `BaseRouter` and implementing `set_active_llm()`. Register in `ROUTER_REGISTRY`.

View File

@ -0,0 +1,8 @@
from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
from openhands.llm.router.rule_based.impl import MultimodalRouter
__all__ = [
'RouterLLM',
'ROUTER_LLM_REGISTRY',
'MultimodalRouter',
]

View File

@ -0,0 +1,164 @@
import copy
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
if TYPE_CHECKING:
from openhands.llm.llm_registry import LLMRegistry
ROUTER_LLM_REGISTRY: dict[str, type['RouterLLM']] = {}
class RouterLLM(LLM):
"""
Base class for multiple LLM acting as a unified LLM.
This class provides a foundation for implementing model routing by inheriting from LLM,
allowing routers to work with multiple underlying LLM models while presenting a unified
LLM interface to consumers.
Key features:
- Works with multiple LLMs configured via llms_for_routing
- Delegates all other operations/properties to the selected LLM
- Provides routing interface through _select_llm() method
"""
def __init__(
self,
agent_config: AgentConfig,
llm_registry: 'LLMRegistry',
service_id: str = 'router_llm',
metrics: Metrics | None = None,
retry_listener: Callable[[int, int], None] | None = None,
):
"""
Initialize RouterLLM with multiple LLM support.
"""
self.llm_registry = llm_registry
self.model_routing_config = agent_config.model_routing
# Get the primary agent LLM
self.primary_llm = llm_registry.get_llm_from_agent_config('agent', agent_config)
# Instantiate all the LLM instances for routing
llms_for_routing_config = self.model_routing_config.llms_for_routing
self.llms_for_routing = {
config_name: self.llm_registry.get_llm(
f'llm_for_routing.{config_name}', config=llm_config
)
for config_name, llm_config in llms_for_routing_config.items()
}
# All available LLMs for routing (set this BEFORE calling super().__init__)
self.available_llms = {'primary': self.primary_llm, **self.llms_for_routing}
# Create router config based on primary LLM
router_config = copy.deepcopy(self.primary_llm.config)
# Update model name to indicate this is a router
llm_names = [self.primary_llm.config.model]
if self.model_routing_config.llms_for_routing:
llm_names.extend(
config.model
for config in self.model_routing_config.llms_for_routing.values()
)
router_config.model = f'router({",".join(llm_names)})'
# Initialize parent LLM class
super().__init__(
config=router_config,
service_id=service_id,
metrics=metrics,
retry_listener=retry_listener,
)
# Current LLM state
self._current_llm = self.primary_llm # Default to primary LLM
self._last_routing_decision = 'primary'
logger.info(
f'RouterLLM initialized with {len(self.available_llms)} LLMs: {list(self.available_llms.keys())}'
)
@abstractmethod
def _select_llm(self, messages: list[Message]) -> str:
"""
Select which LLM to use based on messages and events.
"""
pass
def _get_llm_by_key(self, llm_key: str) -> LLM:
"""
Get LLM instance by key.
"""
if llm_key not in self.available_llms:
raise ValueError(
f'Unknown LLM key: {llm_key}. Available: {list(self.available_llms.keys())}'
)
return self.available_llms[llm_key]
@property
def completion(self) -> Callable:
"""
Override completion to route to appropriate LLM.
This method intercepts completion calls and routes them to the appropriate
underlying LLM based on the routing logic implemented in _select_llm().
"""
def router_completion(*args: Any, **kwargs: Any) -> Any:
# Extract messages for routing decision
messages = kwargs.get('messages', [])
if args and not messages:
messages = args[0] if args else []
# Select appropriate LLM
selected_llm_key = self._select_llm(messages)
selected_llm = self._get_llm_by_key(selected_llm_key)
# Update current state
self._current_llm = selected_llm
self._last_routing_decision = selected_llm_key
logger.debug(
f'RouterLLM routing to {selected_llm_key} ({selected_llm.config.model})'
)
# Delegate to selected LLM
return selected_llm.completion(*args, **kwargs)
return router_completion
def __str__(self) -> str:
"""String representation of the router."""
return f'{self.__class__.__name__}(llms={list(self.available_llms.keys())})'
def __repr__(self) -> str:
"""Detailed string representation of the router."""
return (
f'{self.__class__.__name__}('
f'primary={self.primary_llm.config.model}, '
f'routing_llms={[llm.config.model for llm in self.llms_for_routing.values()]}, '
f'current={self._last_routing_decision})'
)
def __getattr__(self, name):
"""Delegate other attributes/methods to the active LLM."""
return getattr(self._current_llm, name)
@classmethod
def from_config(
cls, llm_registry: 'LLMRegistry', agent_config: AgentConfig, **kwargs
) -> 'RouterLLM':
"""Factory method to create a RouterLLM instance from configuration."""
router_cls = ROUTER_LLM_REGISTRY.get(agent_config.model_routing.router_name)
if not router_cls:
raise ValueError(
f'Router LLM {agent_config.model_routing.router_name} not found.'
)
return router_cls(agent_config, llm_registry, **kwargs)

View File

@ -0,0 +1,74 @@
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
class MultimodalRouter(RouterLLM):
SECONDARY_MODEL_CONFIG_NAME = 'secondary_model'
ROUTER_NAME = 'multimodal_router'
def __init__(
self,
agent_config: AgentConfig,
llm_registry: LLMRegistry,
**kwargs,
):
super().__init__(agent_config, llm_registry, **kwargs)
self._validate_model_routing_config(self.llms_for_routing)
# States
self.max_token_exceeded = False
def _select_llm(self, messages: list[Message]) -> str:
"""Select LLM based on multimodal content and token limits."""
route_to_primary = False
# Check for multimodal content in messages
for message in messages:
if message.contains_image:
logger.info(
'Multimodal content detected in messages. Routing to the primary model.'
)
route_to_primary = True
if not route_to_primary and self.max_token_exceeded:
route_to_primary = True
# Check if `messages` exceeds context window of the secondary model
# Assuming the secondary model has a lower context window limit compared to the primary model
secondary_llm = self.available_llms.get(self.SECONDARY_MODEL_CONFIG_NAME)
if secondary_llm and (
secondary_llm.config.max_input_tokens
and secondary_llm.get_token_count(messages)
> secondary_llm.config.max_input_tokens
):
logger.warning(
f"Messages having {secondary_llm.get_token_count(messages)} tokens, exceed secondary model's max input tokens ({secondary_llm.config.max_input_tokens} tokens). "
'Routing to the primary model.'
)
self.max_token_exceeded = True
route_to_primary = True
if route_to_primary:
logger.info('Routing to the primary model...')
return 'primary'
else:
logger.info('Routing to the secondary model...')
return self.SECONDARY_MODEL_CONFIG_NAME
def vision_is_active(self):
return self.primary_llm.vision_is_active()
def _validate_model_routing_config(self, llms_for_routing: dict[str, LLM]):
if self.SECONDARY_MODEL_CONFIG_NAME not in llms_for_routing:
raise ValueError(
f'Secondary LLM config {self.SECONDARY_MODEL_CONFIG_NAME} not found.'
)
# Register the router
ROUTER_LLM_REGISTRY[MultimodalRouter.ROUTER_NAME] = MultimodalRouter