Implement model routing support (#9738)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Ryan H. Tran
2025-09-08 16:19:34 +07:00
committed by GitHub
parent af0ab5a9f2
commit df9320f8ab
15 changed files with 515 additions and 7 deletions

View File

@@ -108,10 +108,24 @@ class LLMRegistry:
    def get_active_llm(self) -> LLM:
        """Return the LLM currently selected as the active agent LLM."""
        return self.active_agent_llm
def _set_active_llm(self, service_id) -> None:
if service_id not in self.service_to_llm:
raise ValueError(f'Unrecognized service ID: {service_id}')
self.active_agent_llm = self.service_to_llm[service_id]
    def get_router(self, agent_config: AgentConfig) -> 'LLM':
        """
        Get a router instance that inherits from LLM.

        When the configured router name is 'noop_router', no router is built
        and the plain agent LLM is returned; otherwise a RouterLLM subclass is
        constructed via RouterLLM.from_config.
        """
        # Import here to avoid circular imports
        from openhands.llm.router import RouterLLM

        router_name = agent_config.model_routing.router_name
        if router_name == 'noop_router':
            # Return the main LLM directly (no routing)
            return self.get_llm_from_agent_config('agent', agent_config)
        return RouterLLM.from_config(
            agent_config=agent_config,
            llm_registry=self,
            # NOTE(review): 'retry_listner' (sic) — presumably matches the
            # misspelled attribute defined on this registry class; confirm the
            # spelling against the attribute declaration before renaming.
            retry_listener=self.retry_listner,
        )
def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
self.subscriber = callback

View File

@@ -0,0 +1,39 @@
# Model Routing Module
**⚠️ Experimental Feature**: This module is experimental and under active development.
## Overview
Model routing enables OpenHands to switch between different LLM models during a conversation. An example use case is routing between a primary (expensive, multimodal) model and a secondary (cheaper, text-only) model.
## Available Routers
- **`noop_router`** (default): No routing, always uses primary LLM
- **`multimodal_router`**: Switches between the primary and secondary models based on message content and size:
- Routes to primary model for images or when secondary model's context limit is exceeded
- Uses secondary model for text-only requests within its context limit
## Configuration
Add to your `config.toml`:
```toml
# Main LLM (primary model)
[llm]
model = "claude-sonnet-4"
api_key = "your-api-key"
# Secondary model for routing
[llm.secondary_model]
model = "kimi-k2"
api_key = "your-api-key"
for_routing = true
# Enable routing
[model_routing]
router_name = "multimodal_router"
```
## Extending
Create custom routers by inheriting from `BaseRouter` and implementing `set_active_llm()`. Register in `ROUTER_REGISTRY`.

View File

@@ -0,0 +1,8 @@
from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
from openhands.llm.router.rule_based.impl import MultimodalRouter
# Public API of the router package.
__all__ = [
    'RouterLLM',
    'ROUTER_LLM_REGISTRY',
    'MultimodalRouter',
]

View File

@@ -0,0 +1,164 @@
import copy
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
if TYPE_CHECKING:
from openhands.llm.llm_registry import LLMRegistry
# Global registry mapping router_name strings to RouterLLM subclasses.
# Concrete routers add themselves here at import time.
ROUTER_LLM_REGISTRY: dict[str, type['RouterLLM']] = {}


class RouterLLM(LLM):
    """
    Base class for multiple LLMs acting as a unified LLM.

    This class provides a foundation for implementing model routing by inheriting from LLM,
    allowing routers to work with multiple underlying LLM models while presenting a unified
    LLM interface to consumers.

    Key features:
    - Works with multiple LLMs configured via llms_for_routing
    - Delegates all other operations/properties to the selected LLM
    - Provides routing interface through _select_llm() method
    """

    def __init__(
        self,
        agent_config: AgentConfig,
        llm_registry: 'LLMRegistry',
        service_id: str = 'router_llm',
        metrics: Metrics | None = None,
        retry_listener: Callable[[int, int], None] | None = None,
    ):
        """
        Initialize RouterLLM with multiple LLM support.

        Args:
            agent_config: Agent configuration; its ``model_routing`` section
                supplies the configs of the LLMs available for routing.
            llm_registry: Registry used to construct/fetch the underlying LLMs.
            service_id: Service identifier passed to the parent LLM.
            metrics: Optional metrics object forwarded to the parent LLM.
            retry_listener: Optional retry callback forwarded to the parent LLM.
        """
        self.llm_registry = llm_registry
        self.model_routing_config = agent_config.model_routing
        # Get the primary agent LLM
        self.primary_llm = llm_registry.get_llm_from_agent_config('agent', agent_config)
        # Instantiate all the LLM instances for routing
        llms_for_routing_config = self.model_routing_config.llms_for_routing
        self.llms_for_routing = {
            config_name: self.llm_registry.get_llm(
                f'llm_for_routing.{config_name}', config=llm_config
            )
            for config_name, llm_config in llms_for_routing_config.items()
        }
        # All available LLMs for routing (set this BEFORE calling super().__init__)
        # NOTE(review): __getattr__ below delegates unknown attributes to
        # self._current_llm, which is only assigned AFTER super().__init__()
        # returns; an attribute miss during the parent constructor would
        # recurse into __getattr__ — confirm the parent never hits that path.
        self.available_llms = {'primary': self.primary_llm, **self.llms_for_routing}
        # Create router config based on primary LLM
        router_config = copy.deepcopy(self.primary_llm.config)
        # Update model name to indicate this is a router
        llm_names = [self.primary_llm.config.model]
        if self.model_routing_config.llms_for_routing:
            llm_names.extend(
                config.model
                for config in self.model_routing_config.llms_for_routing.values()
            )
        router_config.model = f'router({",".join(llm_names)})'
        # Initialize parent LLM class
        super().__init__(
            config=router_config,
            service_id=service_id,
            metrics=metrics,
            retry_listener=retry_listener,
        )
        # Current LLM state
        self._current_llm = self.primary_llm  # Default to primary LLM
        self._last_routing_decision = 'primary'
        logger.info(
            f'RouterLLM initialized with {len(self.available_llms)} LLMs: {list(self.available_llms.keys())}'
        )

    @abstractmethod
    def _select_llm(self, messages: list[Message]) -> str:
        """
        Select which LLM to use based on messages and events.

        Returns:
            A key into ``self.available_llms`` — 'primary' or one of the
            llms_for_routing config names.
        """
        pass

    def _get_llm_by_key(self, llm_key: str) -> LLM:
        """
        Get LLM instance by key.

        Raises:
            ValueError: If ``llm_key`` is not in ``self.available_llms``.
        """
        if llm_key not in self.available_llms:
            raise ValueError(
                f'Unknown LLM key: {llm_key}. Available: {list(self.available_llms.keys())}'
            )
        return self.available_llms[llm_key]

    @property
    def completion(self) -> Callable:
        """
        Override completion to route to appropriate LLM.

        This method intercepts completion calls and routes them to the appropriate
        underlying LLM based on the routing logic implemented in _select_llm().
        """

        def router_completion(*args: Any, **kwargs: Any) -> Any:
            # Extract messages for routing decision
            messages = kwargs.get('messages', [])
            if args and not messages:
                # Fall back to the first positional argument when messages
                # were not passed by keyword.
                messages = args[0] if args else []
            # Select appropriate LLM
            selected_llm_key = self._select_llm(messages)
            selected_llm = self._get_llm_by_key(selected_llm_key)
            # Update current state so __getattr__ delegates to the chosen LLM
            self._current_llm = selected_llm
            self._last_routing_decision = selected_llm_key
            logger.debug(
                f'RouterLLM routing to {selected_llm_key} ({selected_llm.config.model})'
            )
            # Delegate to selected LLM
            return selected_llm.completion(*args, **kwargs)

        return router_completion

    def __str__(self) -> str:
        """String representation of the router."""
        return f'{self.__class__.__name__}(llms={list(self.available_llms.keys())})'

    def __repr__(self) -> str:
        """Detailed string representation of the router."""
        return (
            f'{self.__class__.__name__}('
            f'primary={self.primary_llm.config.model}, '
            f'routing_llms={[llm.config.model for llm in self.llms_for_routing.values()]}, '
            f'current={self._last_routing_decision})'
        )

    def __getattr__(self, name: str) -> Any:
        """Delegate other attributes/methods to the active LLM."""
        # Only invoked for attributes not found on the router itself.
        return getattr(self._current_llm, name)

    @classmethod
    def from_config(
        cls, llm_registry: 'LLMRegistry', agent_config: AgentConfig, **kwargs
    ) -> 'RouterLLM':
        """Factory method to create a RouterLLM instance from configuration.

        Looks up the configured ``router_name`` in ROUTER_LLM_REGISTRY.

        Raises:
            ValueError: If the configured router name is not registered.
        """
        router_cls = ROUTER_LLM_REGISTRY.get(agent_config.model_routing.router_name)
        if not router_cls:
            raise ValueError(
                f'Router LLM {agent_config.model_routing.router_name} not found.'
            )
        return router_cls(agent_config, llm_registry, **kwargs)

View File

@@ -0,0 +1,74 @@
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
class MultimodalRouter(RouterLLM):
    """Router between a multimodal primary model and a cheaper text-only
    secondary model.

    Routes to the primary model when a message contains an image, or once the
    conversation exceeds the secondary model's context window; otherwise
    routes to the secondary model.
    """

    # Name of the LLM config ([llm.secondary_model]) that supplies the
    # secondary model; its presence is validated at construction time.
    SECONDARY_MODEL_CONFIG_NAME = 'secondary_model'
    ROUTER_NAME = 'multimodal_router'

    def __init__(
        self,
        agent_config: AgentConfig,
        llm_registry: LLMRegistry,
        **kwargs,
    ):
        super().__init__(agent_config, llm_registry, **kwargs)
        self._validate_model_routing_config(self.llms_for_routing)
        # Sticky state: once the secondary model's context limit is exceeded,
        # keep routing to the primary model for the rest of the conversation.
        self.max_token_exceeded = False

    def _select_llm(self, messages: list[Message]) -> str:
        """Select LLM based on multimodal content and token limits.

        Returns:
            'primary' or SECONDARY_MODEL_CONFIG_NAME (keys into available_llms).
        """
        route_to_primary = False
        # Check for multimodal content in messages; a single image is enough,
        # so stop scanning (and log only once) on the first hit.
        for message in messages:
            if message.contains_image:
                logger.info(
                    'Multimodal content detected in messages. Routing to the primary model.'
                )
                route_to_primary = True
                break
        if not route_to_primary and self.max_token_exceeded:
            route_to_primary = True
        # Check if `messages` exceeds context window of the secondary model
        # Assuming the secondary model has a lower context window limit compared to the primary model
        secondary_llm = self.available_llms.get(self.SECONDARY_MODEL_CONFIG_NAME)
        if secondary_llm and secondary_llm.config.max_input_tokens:
            # Count tokens once; the original counted twice (condition + log).
            token_count = secondary_llm.get_token_count(messages)
            if token_count > secondary_llm.config.max_input_tokens:
                logger.warning(
                    f"Messages having {token_count} tokens, exceed secondary model's max input tokens ({secondary_llm.config.max_input_tokens} tokens). "
                    'Routing to the primary model.'
                )
                self.max_token_exceeded = True
                route_to_primary = True
        if route_to_primary:
            logger.info('Routing to the primary model...')
            return 'primary'
        else:
            logger.info('Routing to the secondary model...')
            return self.SECONDARY_MODEL_CONFIG_NAME

    def vision_is_active(self) -> bool:
        """Vision support is determined by the primary (multimodal) model."""
        return self.primary_llm.vision_is_active()

    def _validate_model_routing_config(self, llms_for_routing: dict[str, LLM]) -> None:
        """Ensure the required secondary model config is present.

        Raises:
            ValueError: If no config named SECONDARY_MODEL_CONFIG_NAME exists.
        """
        if self.SECONDARY_MODEL_CONFIG_NAME not in llms_for_routing:
            raise ValueError(
                f'Secondary LLM config {self.SECONDARY_MODEL_CONFIG_NAME} not found.'
            )


# Register the router so RouterLLM.from_config can resolve it by name.
ROUTER_LLM_REGISTRY[MultimodalRouter.ROUTER_NAME] = MultimodalRouter