Implement model routing support (#9738)
Co-authored-by: openhands <openhands@all-hands.dev>
parent af0ab5a9f2 · commit df9320f8ab
@@ -219,6 +219,14 @@ correct_num = 5
 api_key = ""
 model = "gpt-4o"
 
+# Example routing LLM configuration for multimodal model routing
+# Uncomment and configure to enable model routing with a secondary model
+#[llm.secondary_model]
+#model = "kimi-k2"
+#api_key = ""
+#for_routing = true
+#max_input_tokens = 128000
+
 
 #################################### Agent ###################################
 # Configuration for agents (group name starts with 'agent')
@@ -480,3 +488,14 @@ type = "noop"
 
 # Run the runtime sandbox container in privileged mode for use with docker-in-docker
 #privileged = false
+
+#################################### Model Routing ############################
+# Configuration for the experimental model routing feature
+# Enables intelligent switching between different LLM models for specific purposes
+##############################################################################
+[model_routing]
+# Router to use for model selection
+# Available options:
+# - "noop_router" (default): no routing, always uses the primary LLM
+# - "multimodal_router": switches between the primary and secondary models depending on whether the input is multimodal
+#router_name = "noop_router"
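For orientation, a sketch of what the template above looks like once uncommented and routing is enabled (the same shape the router README added later in this commit documents; the model names are the template's placeholders):

```toml
[llm]
model = "gpt-4o"
api_key = ""

# Secondary model used by the multimodal router
[llm.secondary_model]
model = "kimi-k2"
api_key = ""
for_routing = true
max_input_tokens = 128000

[model_routing]
router_name = "multimodal_router"
```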
@@ -28,6 +28,7 @@ from evaluation.utils.shared import (
     prepare_dataset,
     reset_logger_for_multiprocessing,
     run_evaluation,
+    update_llm_config_for_completions_logging,
 )
 from openhands.controller.state.state import State
 from openhands.core.config import (
@@ -36,7 +37,11 @@ from openhands.core.config import (
     get_llm_config_arg,
     load_from_toml,
 )
-from openhands.core.config.utils import get_agent_config_arg
+from openhands.core.config.utils import (
+    get_agent_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
+)
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
@@ -57,6 +62,7 @@ AGENT_CLS_TO_INST_SUFFIX = {
 
 
 def get_config(
+    instance: pd.Series,
     metadata: EvalMetadata,
 ) -> OpenHandsConfig:
     sandbox_config = get_default_sandbox_config_for_eval()
@@ -66,13 +72,24 @@ def get_config(
         sandbox_config=sandbox_config,
         runtime='docker',
     )
-    config.set_llm_config(metadata.llm_config)
+    config.set_llm_config(
+        update_llm_config_for_completions_logging(
+            metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
+        )
+    )
+    model_routing_config = get_model_routing_config_arg()
+    model_routing_config.llms_for_routing = (
+        get_llms_for_routing_config()
+    )  # Populate with LLMs for routing from config.toml file
+
     if metadata.agent_config:
+        metadata.agent_config.model_routing = model_routing_config
         config.set_agent_config(metadata.agent_config, metadata.agent_class)
     else:
         logger.info('Agent config not provided, using default settings')
         agent_config = config.get_agent_config(metadata.agent_class)
         agent_config.enable_prompt_extensions = False
+        agent_config.model_routing = model_routing_config
 
     config_copy = copy.deepcopy(config)
     load_from_toml(config_copy)
@@ -145,7 +162,7 @@ def process_instance(
     metadata: EvalMetadata,
     reset_logger: bool = True,
 ) -> EvalOutput:
-    config = get_config(metadata)
+    config = get_config(instance, metadata)
 
     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:
@@ -47,6 +47,8 @@ from openhands.core.config import (
     get_agent_config_arg,
     get_evaluation_parser,
     get_llm_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
 )
 from openhands.core.config.condenser_config import NoOpCondenserConfig
 from openhands.core.config.utils import get_condenser_config_arg
@@ -244,6 +246,11 @@ def get_config(
     # get 'draft_editor' config if exists
     config.set_llm_config(get_llm_config_arg('draft_editor'), 'draft_editor')
 
+    model_routing_config = get_model_routing_config_arg()
+    model_routing_config.llms_for_routing = (
+        get_llms_for_routing_config()
+    )  # Populate with LLMs for routing from config.toml file
+
     agent_config = AgentConfig(
         enable_jupyter=False,
         enable_browsing=RUN_WITH_BROWSING,
@@ -251,8 +258,10 @@ def get_config(
         enable_mcp=False,
         condenser=metadata.condenser_config,
         enable_prompt_extensions=False,
+        model_routing=model_routing_config,
     )
     config.set_agent_config(agent_config)
 
     return config
 
 
@@ -92,6 +92,9 @@ class CodeActAgent(Agent):
         self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
         logger.debug(f'Using condenser: {type(self.condenser)}')
 
+        # Override with router if needed
+        self.llm = self.llm_registry.get_router(self.config)
+
     @property
     def prompt_manager(self) -> PromptManager:
         if self._prompt_manager is None:
@@ -13,6 +13,7 @@ from openhands.core.config.config_utils import (
 from openhands.core.config.extended_config import ExtendedConfig
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.mcp_config import MCPConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
@@ -20,6 +21,8 @@ from openhands.core.config.utils import (
     finalize_config,
     get_agent_config_arg,
     get_llm_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
     load_from_env,
     load_from_toml,
     load_openhands_config,
@@ -37,6 +40,7 @@ __all__ = [
     'LLMConfig',
     'SandboxConfig',
     'SecurityConfig',
+    'ModelRoutingConfig',
     'ExtendedConfig',
     'load_openhands_config',
     'load_from_env',
@@ -50,4 +54,6 @@ __all__ = [
     'get_evaluation_parser',
     'parse_arguments',
     'setup_config_from_args',
+    'get_model_routing_config_arg',
+    'get_llms_for_routing_config',
 ]
@@ -7,6 +7,7 @@ from openhands.core.config.condenser_config import (
     ConversationWindowCondenserConfig,
 )
 from openhands.core.config.extended_config import ExtendedConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.utils.import_utils import get_impl
 
@@ -57,6 +58,8 @@ class AgentConfig(BaseModel):
         # handled.
         default_factory=lambda: ConversationWindowCondenserConfig()
     )
+    model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
+    """Model routing configuration settings."""
     extended: ExtendedConfig = Field(default_factory=lambda: ExtendedConfig({}))
     """Extended configuration for the agent."""
 
 
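As a quick sketch of how the new field composes (`ModelRoutingConfig` is defined in a new file later in this commit; the values here are illustrative):

```python
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig

# The field defaults to a no-op router; overriding it opts the agent in.
agent_config = AgentConfig(
    model_routing=ModelRoutingConfig(router_name='multimodal_router'),
)
print(agent_config.model_routing.router_name)  # multimodal_router
```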
@@ -46,6 +46,7 @@ class LLMConfig(BaseModel):
         reasoning_effort: The effort to put into reasoning. This is a string that can be one of 'low', 'medium', 'high', or 'none'. Can apply to all reasoning models.
         seed: The seed to use for the LLM.
         safety_settings: Safety settings for models that support them (like Mistral AI and Gemini).
+        for_routing: Whether this LLM is used for routing. This is set to True for models used in conjunction with the main LLM in the model routing feature.
     """
 
     model: str = Field(default='claude-sonnet-4-20250514')
@@ -92,6 +93,7 @@ class LLMConfig(BaseModel):
         default=None,
         description='Safety settings for models that support them (like Mistral AI and Gemini)',
     )
+    for_routing: bool = Field(default=False)
 
     model_config = ConfigDict(extra='forbid')
 
 
openhands/core/config/model_routing_config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+
+from openhands.core.config.llm_config import LLMConfig
+
+
+class ModelRoutingConfig(BaseModel):
+    """Configuration for model routing.
+
+    Attributes:
+        router_name (str): The name of the router to use. Default is 'noop_router'.
+        llms_for_routing (dict[str, LLMConfig]): A dictionary mapping config names of LLMs for routing to their configurations.
+    """
+
+    router_name: str = Field(default='noop_router')
+    llms_for_routing: dict[str, LLMConfig] = Field(default_factory=dict)
+
+    model_config = ConfigDict(extra='forbid')
+
+    @classmethod
+    def from_toml_section(cls, data: dict) -> dict[str, 'ModelRoutingConfig']:
+        """
+        Create a mapping of ModelRoutingConfig instances from a toml dictionary representing the [model_routing] section.
+
+        The configuration is built from all keys in data.
+
+        Returns:
+            dict[str, ModelRoutingConfig]: A mapping where the key "model_routing" corresponds to the [model_routing] configuration.
+        """
+
+        # Initialize the result mapping
+        model_routing_mapping: dict[str, ModelRoutingConfig] = {}
+
+        # Try to create the configuration instance
+        try:
+            model_routing_mapping['model_routing'] = cls.model_validate(data)
+        except ValidationError as e:
+            raise ValueError(f'Invalid model routing configuration: {e}')
+
+        return model_routing_mapping
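A small sketch of the classmethod in use, with an inline dict standing in for a parsed `[model_routing]` table:

```python
from openhands.core.config.model_routing_config import ModelRoutingConfig

mapping = ModelRoutingConfig.from_toml_section({'router_name': 'multimodal_router'})
assert mapping['model_routing'].router_name == 'multimodal_router'

# Unknown keys are rejected because of extra='forbid':
try:
    ModelRoutingConfig.from_toml_section({'router': 'typo'})
except ValueError as e:
    print(e)  # Invalid model routing configuration: ...
```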
@@ -30,6 +30,7 @@ class OpenHandsConfig(BaseModel):
         The default configuration is stored under the 'agent' key.
         default_agent: Name of the default agent to use.
         sandbox: Sandbox configuration settings.
         security: Security configuration settings.
         runtime: Runtime environment identifier.
         file_store: Type of file store to use.
         file_store_path: Path to the file store.
@@ -25,6 +25,7 @@ from openhands.core.config.extended_config import ExtendedConfig
 from openhands.core.config.kubernetes_config import KubernetesConfig
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.mcp_config import MCPConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
@@ -225,6 +226,35 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None:
             # Re-raise ValueError from SecurityConfig.from_toml_section
             raise ValueError('Error in [security] section in config.toml')
 
+    if 'model_routing' in toml_config:
+        try:
+            model_routing_mapping = ModelRoutingConfig.from_toml_section(
+                toml_config['model_routing']
+            )
+            # We only use the base model routing config for now
+            if 'model_routing' in model_routing_mapping:
+                default_agent_config = cfg.get_agent_config()
+                default_agent_config.model_routing = model_routing_mapping[
+                    'model_routing'
+                ]
+
+                # Construct the llms_for_routing by filtering llms with for_routing = True
+                llms_for_routing_dict = {}
+                for llm_name, llm_config in cfg.llms.items():
+                    if llm_config and llm_config.for_routing:
+                        llms_for_routing_dict[llm_name] = llm_config
+                default_agent_config.model_routing.llms_for_routing = (
+                    llms_for_routing_dict
+                )
+
+                logger.openhands_logger.debug(
+                    'Default model routing configuration loaded from config toml and assigned to default agent'
+                )
+        except (TypeError, KeyError, ValidationError) as e:
+            logger.openhands_logger.warning(
+                f'Cannot parse [model_routing] config from toml, values have not been applied.\nError: {e}'
+            )
+
     # Process sandbox section if present
     if 'sandbox' in toml_config:
         try:
@@ -327,6 +357,7 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None:
         'condenser',
         'mcp',
         'kubernetes',
+        'model_routing',
     }
     for key in toml_config:
         if key.lower() not in known_sections:
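Putting the new section handling together, a sketch of the end-to-end effect (assuming a `config.toml` shaped like the template above, with `[llm.secondary_model]` marked `for_routing = true`):

```python
from openhands.core.config import OpenHandsConfig, load_from_toml

cfg = OpenHandsConfig()
load_from_toml(cfg, 'config.toml')

# The [model_routing] table and the for_routing-flagged LLMs both land on
# the default agent config.
agent_cfg = cfg.get_agent_config()
print(agent_cfg.model_routing.router_name)             # e.g. multimodal_router
print(list(agent_cfg.model_routing.llms_for_routing))  # e.g. ['secondary_model']
```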
@@ -559,6 +590,41 @@ def get_llm_config_arg(
     return None
 
 
+def get_llms_for_routing_config(toml_file: str = 'config.toml') -> dict[str, LLMConfig]:
+    """Get the LLMs that are configured for routing from the config file.
+
+    This function will return a dictionary of LLMConfig objects that are configured
+    for routing, i.e., those with `for_routing` set to True.
+
+    Args:
+        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
+
+    Returns:
+        dict[str, LLMConfig]: A dictionary of LLMConfig objects for routing.
+    """
+    llms_for_routing: dict[str, LLMConfig] = {}
+
+    try:
+        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
+            toml_config = toml.load(toml_contents)
+    except FileNotFoundError:
+        return llms_for_routing
+    except toml.TomlDecodeError as e:
+        logger.openhands_logger.error(
+            f'Cannot parse LLM configs from {toml_file}. Exception: {e}'
+        )
+        return llms_for_routing
+
+    llm_configs = LLMConfig.from_toml_section(toml_config.get('llm', {}))
+
+    if llm_configs:
+        for llm_name, llm_config in llm_configs.items():
+            if llm_config.for_routing:
+                llms_for_routing[llm_name] = llm_config
+
+    return llms_for_routing
+
+
 def get_condenser_config_arg(
     condenser_config_arg: str, toml_file: str = 'config.toml'
 ) -> CondenserConfig | None:
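Usage is symmetric with `get_llm_config_arg` above; a sketch:

```python
from openhands.core.config import get_llms_for_routing_config

llms = get_llms_for_routing_config('config.toml')
for name, llm_config in llms.items():
    # Only [llm.<name>] sections with for_routing = true show up here.
    print(name, llm_config.model)
```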
@@ -671,6 +737,50 @@ def get_condenser_config_arg(
     return None
 
 
+def get_model_routing_config_arg(toml_file: str = 'config.toml') -> ModelRoutingConfig:
+    """Get the model routing settings from the config file. We only support the default model routing config [model_routing].
+
+    Args:
+        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
+
+    Returns:
+        ModelRoutingConfig: The ModelRoutingConfig object with the settings from the config file, or the object with default values if not found/error.
+    """
+    logger.openhands_logger.debug(
+        f"Loading model routing config ['model_routing'] from {toml_file}"
+    )
+    default_cfg = ModelRoutingConfig()
+
+    # load the toml file
+    try:
+        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
+            toml_config = toml.load(toml_contents)
+    except FileNotFoundError as e:
+        logger.openhands_logger.error(f'Config file not found: {toml_file}. Error: {e}')
+        return default_cfg
+    except toml.TomlDecodeError as e:
+        logger.openhands_logger.error(
+            f'Cannot parse model routing group [model_routing] from {toml_file}. Exception: {e}'
+        )
+        return default_cfg
+
+    # Update the model routing config with the specified section
+    if 'model_routing' in toml_config:
+        try:
+            model_routing_data = toml_config['model_routing']
+            return ModelRoutingConfig(**model_routing_data)
+        except ValidationError as e:
+            logger.openhands_logger.error(
+                f'Invalid model routing configuration for [model_routing]: {e}'
+            )
+            return default_cfg
+
+    logger.openhands_logger.warning(
+        f'Model routing config section [model_routing] not found in {toml_file}'
+    )
+    return default_cfg
+
+
 def parse_arguments() -> argparse.Namespace:
     """Parse command line arguments."""
     parser = get_headless_parser()
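This is the helper pair the evaluation scripts above lean on; their wiring reduces to:

```python
from openhands.core.config import (
    get_llms_for_routing_config,
    get_model_routing_config_arg,
)

model_routing_config = get_model_routing_config_arg()
# Attach the routing LLMs read from config.toml; this falls back to an
# empty dict (and a default noop_router config) when the file or section
# is absent.
model_routing_config.llms_for_routing = get_llms_for_routing_config()
```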
@@ -108,10 +108,24 @@ class LLMRegistry:
     def get_active_llm(self) -> LLM:
         return self.active_agent_llm
 
-    def _set_active_llm(self, service_id) -> None:
-        if service_id not in self.service_to_llm:
-            raise ValueError(f'Unrecognized service ID: {service_id}')
-        self.active_agent_llm = self.service_to_llm[service_id]
+    def get_router(self, agent_config: AgentConfig) -> 'LLM':
+        """
+        Get a router instance that inherits from LLM.
+        """
+        # Import here to avoid circular imports
+        from openhands.llm.router import RouterLLM
+
+        router_name = agent_config.model_routing.router_name
+
+        if router_name == 'noop_router':
+            # Return the main LLM directly (no routing)
+            return self.get_llm_from_agent_config('agent', agent_config)
+
+        return RouterLLM.from_config(
+            agent_config=agent_config,
+            llm_registry=self,
+            retry_listener=self.retry_listner,
+        )
 
     def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
         self.subscriber = callback
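Mirroring the CodeActAgent change earlier in this commit, a sketch of the call site (the registry and agent config are assumed to already exist):

```python
# With router_name == 'noop_router' this returns the plain agent LLM;
# for any registered router name it returns a RouterLLM subclass that
# still behaves like an LLM.
llm = llm_registry.get_router(agent_config)
```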
openhands/llm/router/README.md (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Model Routing Module
+
+**⚠️ Experimental Feature**: This module is experimental and under active development.
+
+## Overview
+
+Model routing enables OpenHands to switch between different LLM models during a conversation. An example use case is routing between a primary (expensive, multimodal) model and a secondary (cheaper, text-only) model.
+
+## Available Routers
+
+- **`noop_router`** (default): No routing; always uses the primary LLM
+- **`multimodal_router`**: Switches between the primary and secondary models:
+  - Routes to the primary model for images, or when the secondary model's context limit has been exceeded
+  - Uses the secondary model for text-only requests within its context limit
+
+## Configuration
+
+Add to your `config.toml`:
+
+```toml
+# Main LLM (primary model)
+[llm]
+model = "claude-sonnet-4"
+api_key = "your-api-key"
+
+# Secondary model for routing
+[llm.secondary_model]
+model = "kimi-k2"
+api_key = "your-api-key"
+for_routing = true
+
+# Enable routing
+[model_routing]
+router_name = "multimodal_router"
+```
+
+## Extending
+
+Create custom routers by inheriting from `RouterLLM` and implementing `_select_llm()`. Register them in `ROUTER_LLM_REGISTRY`.
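As a sketch of that extension point (the class and its length threshold are hypothetical; the base class and registry are the ones defined in `base.py` below):

```python
from openhands.core.message import Message
from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM


class LongConversationRouter(RouterLLM):
    """Hypothetical router: hand long conversations to the first routing LLM."""

    ROUTER_NAME = 'long_conversation_router'

    def _select_llm(self, messages: list[Message]) -> str:
        # Keys returned here must exist in available_llms: 'primary' always
        # does, plus one key per entry in llms_for_routing.
        if len(messages) > 50 and self.llms_for_routing:
            return next(iter(self.llms_for_routing))
        return 'primary'


ROUTER_LLM_REGISTRY[LongConversationRouter.ROUTER_NAME] = LongConversationRouter
```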
openhands/llm/router/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
+from openhands.llm.router.rule_based.impl import MultimodalRouter
+
+__all__ = [
+    'RouterLLM',
+    'ROUTER_LLM_REGISTRY',
+    'MultimodalRouter',
+]
openhands/llm/router/base.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+import copy
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Any, Callable
+
+from openhands.core.config import AgentConfig
+from openhands.core.logger import openhands_logger as logger
+from openhands.core.message import Message
+from openhands.llm.llm import LLM
+from openhands.llm.metrics import Metrics
+
+if TYPE_CHECKING:
+    from openhands.llm.llm_registry import LLMRegistry
+
+ROUTER_LLM_REGISTRY: dict[str, type['RouterLLM']] = {}
+
+
+class RouterLLM(LLM):
+    """
+    Base class for multiple LLMs acting as a single, unified LLM.
+
+    This class provides a foundation for implementing model routing by inheriting from LLM,
+    allowing routers to work with multiple underlying LLM models while presenting a unified
+    LLM interface to consumers.
+
+    Key features:
+    - Works with multiple LLMs configured via llms_for_routing
+    - Delegates all other operations/properties to the selected LLM
+    - Provides the routing interface through the _select_llm() method
+    """
+
+    def __init__(
+        self,
+        agent_config: AgentConfig,
+        llm_registry: 'LLMRegistry',
+        service_id: str = 'router_llm',
+        metrics: Metrics | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
+    ):
+        """
+        Initialize RouterLLM with multiple LLM support.
+        """
+        self.llm_registry = llm_registry
+        self.model_routing_config = agent_config.model_routing
+
+        # Get the primary agent LLM
+        self.primary_llm = llm_registry.get_llm_from_agent_config('agent', agent_config)
+
+        # Instantiate all the LLM instances for routing
+        llms_for_routing_config = self.model_routing_config.llms_for_routing
+        self.llms_for_routing = {
+            config_name: self.llm_registry.get_llm(
+                f'llm_for_routing.{config_name}', config=llm_config
+            )
+            for config_name, llm_config in llms_for_routing_config.items()
+        }
+
+        # All available LLMs for routing (set this BEFORE calling super().__init__)
+        self.available_llms = {'primary': self.primary_llm, **self.llms_for_routing}
+
+        # Create the router config based on the primary LLM
+        router_config = copy.deepcopy(self.primary_llm.config)
+
+        # Update the model name to indicate this is a router
+        llm_names = [self.primary_llm.config.model]
+        if self.model_routing_config.llms_for_routing:
+            llm_names.extend(
+                config.model
+                for config in self.model_routing_config.llms_for_routing.values()
+            )
+        router_config.model = f'router({",".join(llm_names)})'
+
+        # Initialize the parent LLM class
+        super().__init__(
+            config=router_config,
+            service_id=service_id,
+            metrics=metrics,
+            retry_listener=retry_listener,
+        )
+
+        # Current LLM state
+        self._current_llm = self.primary_llm  # Default to the primary LLM
+        self._last_routing_decision = 'primary'
+
+        logger.info(
+            f'RouterLLM initialized with {len(self.available_llms)} LLMs: {list(self.available_llms.keys())}'
+        )
+
+    @abstractmethod
+    def _select_llm(self, messages: list[Message]) -> str:
+        """
+        Select which LLM to use based on the messages and events.
+        """
+        pass
+
+    def _get_llm_by_key(self, llm_key: str) -> LLM:
+        """
+        Get an LLM instance by key.
+        """
+        if llm_key not in self.available_llms:
+            raise ValueError(
+                f'Unknown LLM key: {llm_key}. Available: {list(self.available_llms.keys())}'
+            )
+        return self.available_llms[llm_key]
+
+    @property
+    def completion(self) -> Callable:
+        """
+        Override completion to route to the appropriate LLM.
+
+        This method intercepts completion calls and routes them to the appropriate
+        underlying LLM based on the routing logic implemented in _select_llm().
+        """
+
+        def router_completion(*args: Any, **kwargs: Any) -> Any:
+            # Extract messages for the routing decision
+            messages = kwargs.get('messages', [])
+            if args and not messages:
+                messages = args[0] if args else []
+
+            # Select the appropriate LLM
+            selected_llm_key = self._select_llm(messages)
+            selected_llm = self._get_llm_by_key(selected_llm_key)
+
+            # Update the current state
+            self._current_llm = selected_llm
+            self._last_routing_decision = selected_llm_key
+
+            logger.debug(
+                f'RouterLLM routing to {selected_llm_key} ({selected_llm.config.model})'
+            )
+
+            # Delegate to the selected LLM
+            return selected_llm.completion(*args, **kwargs)
+
+        return router_completion
+
+    def __str__(self) -> str:
+        """String representation of the router."""
+        return f'{self.__class__.__name__}(llms={list(self.available_llms.keys())})'
+
+    def __repr__(self) -> str:
+        """Detailed string representation of the router."""
+        return (
+            f'{self.__class__.__name__}('
+            f'primary={self.primary_llm.config.model}, '
+            f'routing_llms={[llm.config.model for llm in self.llms_for_routing.values()]}, '
+            f'current={self._last_routing_decision})'
+        )
+
+    def __getattr__(self, name):
+        """Delegate other attributes/methods to the active LLM."""
+        return getattr(self._current_llm, name)
+
+    @classmethod
+    def from_config(
+        cls, llm_registry: 'LLMRegistry', agent_config: AgentConfig, **kwargs
+    ) -> 'RouterLLM':
+        """Factory method to create a RouterLLM instance from configuration."""
+        router_cls = ROUTER_LLM_REGISTRY.get(agent_config.model_routing.router_name)
+        if not router_cls:
+            raise ValueError(
+                f'Router LLM {agent_config.model_routing.router_name} not found.'
+            )
+        return router_cls(agent_config, llm_registry, **kwargs)
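A sketch of the resulting behavior from a caller's point of view (construction normally goes through `LLMRegistry.get_router`, shown earlier; `registry`, `agent_config`, and `messages` are assumed to exist):

```python
router = RouterLLM.from_config(llm_registry=registry, agent_config=agent_config)

# Each call re-runs _select_llm() and delegates to the chosen model;
# unknown attributes fall through to the most recently selected LLM.
response = router.completion(messages=messages)
print(router._last_routing_decision)  # 'primary' or a routing-LLM key
```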
openhands/llm/router/rule_based/impl.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from openhands.core.config import AgentConfig
+from openhands.core.logger import openhands_logger as logger
+from openhands.core.message import Message
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
+from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM
+
+
+class MultimodalRouter(RouterLLM):
+    SECONDARY_MODEL_CONFIG_NAME = 'secondary_model'
+    ROUTER_NAME = 'multimodal_router'
+
+    def __init__(
+        self,
+        agent_config: AgentConfig,
+        llm_registry: LLMRegistry,
+        **kwargs,
+    ):
+        super().__init__(agent_config, llm_registry, **kwargs)
+
+        self._validate_model_routing_config(self.llms_for_routing)
+
+        # States
+        self.max_token_exceeded = False
+
+    def _select_llm(self, messages: list[Message]) -> str:
+        """Select an LLM based on multimodal content and token limits."""
+        route_to_primary = False
+
+        # Check for multimodal content in the messages
+        for message in messages:
+            if message.contains_image:
+                logger.info(
+                    'Multimodal content detected in messages. Routing to the primary model.'
+                )
+                route_to_primary = True
+
+        if not route_to_primary and self.max_token_exceeded:
+            route_to_primary = True
+
+        # Check whether `messages` exceeds the context window of the secondary model,
+        # assuming the secondary model has a lower context window limit than the primary model
+        secondary_llm = self.available_llms.get(self.SECONDARY_MODEL_CONFIG_NAME)
+        if secondary_llm and (
+            secondary_llm.config.max_input_tokens
+            and secondary_llm.get_token_count(messages)
+            > secondary_llm.config.max_input_tokens
+        ):
+            logger.warning(
+                f"Messages with {secondary_llm.get_token_count(messages)} tokens exceed the secondary model's max input tokens ({secondary_llm.config.max_input_tokens} tokens). "
+                'Routing to the primary model.'
+            )
+            self.max_token_exceeded = True
+            route_to_primary = True
+
+        if route_to_primary:
+            logger.info('Routing to the primary model...')
+            return 'primary'
+        else:
+            logger.info('Routing to the secondary model...')
+            return self.SECONDARY_MODEL_CONFIG_NAME
+
+    def vision_is_active(self):
+        return self.primary_llm.vision_is_active()
+
+    def _validate_model_routing_config(self, llms_for_routing: dict[str, LLM]):
+        if self.SECONDARY_MODEL_CONFIG_NAME not in llms_for_routing:
+            raise ValueError(
+                f'Secondary LLM config {self.SECONDARY_MODEL_CONFIG_NAME} not found.'
+            )
+
+
+# Register the router
+ROUTER_LLM_REGISTRY[MultimodalRouter.ROUTER_NAME] = MultimodalRouter
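The routing rule above collapses to a small decision table; a sketch of the observable behavior (the message lists are hypothetical stand-ins, built with the `Message.contains_image` API `_select_llm` relies on):

```python
# Text-only request within the secondary model's window -> secondary model
assert router._select_llm(short_text_messages) == 'secondary_model'

# Any image in the conversation -> primary model
assert router._select_llm(messages_with_image) == 'primary'

# Once the secondary window is exceeded, the router stays on primary:
# max_token_exceeded is sticky for the rest of the conversation.
assert router._select_llm(very_long_messages) == 'primary'
assert router._select_llm(short_text_messages) == 'primary'
```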