mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Implement model routing support (#9738)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -13,6 +13,7 @@ from openhands.core.config.config_utils import (
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.model_routing_config import ModelRoutingConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
@@ -20,6 +21,8 @@ from openhands.core.config.utils import (
|
||||
finalize_config,
|
||||
get_agent_config_arg,
|
||||
get_llm_config_arg,
|
||||
get_llms_for_routing_config,
|
||||
get_model_routing_config_arg,
|
||||
load_from_env,
|
||||
load_from_toml,
|
||||
load_openhands_config,
|
||||
@@ -37,6 +40,7 @@ __all__ = [
|
||||
'LLMConfig',
|
||||
'SandboxConfig',
|
||||
'SecurityConfig',
|
||||
'ModelRoutingConfig',
|
||||
'ExtendedConfig',
|
||||
'load_openhands_config',
|
||||
'load_from_env',
|
||||
@@ -50,4 +54,6 @@ __all__ = [
|
||||
'get_evaluation_parser',
|
||||
'parse_arguments',
|
||||
'setup_config_from_args',
|
||||
'get_model_routing_config_arg',
|
||||
'get_llms_for_routing_config',
|
||||
]
|
||||
|
||||
@@ -7,6 +7,7 @@ from openhands.core.config.condenser_config import (
|
||||
ConversationWindowCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.model_routing_config import ModelRoutingConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
@@ -57,6 +58,8 @@ class AgentConfig(BaseModel):
|
||||
# handled.
|
||||
default_factory=lambda: ConversationWindowCondenserConfig()
|
||||
)
|
||||
model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
|
||||
"""Model routing configuration settings."""
|
||||
extended: ExtendedConfig = Field(default_factory=lambda: ExtendedConfig({}))
|
||||
"""Extended configuration for the agent."""
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class LLMConfig(BaseModel):
|
||||
reasoning_effort: The effort to put into reasoning. This is a string that can be one of 'low', 'medium', 'high', or 'none'. Can apply to all reasoning models.
|
||||
seed: The seed to use for the LLM.
|
||||
safety_settings: Safety settings for models that support them (like Mistral AI and Gemini).
|
||||
for_routing: Whether this LLM is used for routing. This is set to True for models used in conjunction with the main LLM in the model routing feature.
|
||||
"""
|
||||
|
||||
model: str = Field(default='claude-sonnet-4-20250514')
|
||||
@@ -92,6 +93,7 @@ class LLMConfig(BaseModel):
|
||||
default=None,
|
||||
description='Safety settings for models that support them (like Mistral AI and Gemini)',
|
||||
)
|
||||
for_routing: bool = Field(default=False)
|
||||
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
|
||||
39
openhands/core/config/model_routing_config.py
Normal file
39
openhands/core/config/model_routing_config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
|
||||
|
||||
class ModelRoutingConfig(BaseModel):
    """Configuration for model routing.

    Attributes:
        router_name (str): The name of the router to use. Default is 'noop_router'.
        llms_for_routing (dict[str, LLMConfig]): A dictionary mapping config names of LLMs for routing to their configurations.
    """

    router_name: str = Field(default='noop_router')
    llms_for_routing: dict[str, LLMConfig] = Field(default_factory=dict)

    # extra='forbid': keys that are not declared as fields above fail validation.
    model_config = ConfigDict(extra='forbid')

    @classmethod
    def from_toml_section(cls, data: dict) -> dict[str, 'ModelRoutingConfig']:
        """Build the model routing configuration from the [model_routing] toml section.

        The configuration is validated from all keys in `data`.

        Args:
            data: The dictionary parsed from the [model_routing] section.

        Returns:
            dict[str, ModelRoutingConfig]: A mapping with the single key
            "model_routing" holding the validated configuration.

        Raises:
            ValueError: If `data` does not validate against this model. The
                original pydantic ValidationError is attached as the cause.
        """
        try:
            routing_config = cls.model_validate(data)
        except ValidationError as e:
            # Chain explicitly so the underlying validation error is kept as
            # __cause__ instead of an implicit "during handling" context.
            raise ValueError(f'Invalid model routing configuration: {e}') from e

        return {'model_routing': routing_config}
|
||||
@@ -30,6 +30,7 @@ class OpenHandsConfig(BaseModel):
|
||||
The default configuration is stored under the 'agent' key.
|
||||
default_agent: Name of the default agent to use.
|
||||
sandbox: Sandbox configuration settings.
|
||||
security: Security configuration settings.
|
||||
runtime: Runtime environment identifier.
|
||||
file_store: Type of file store to use.
|
||||
file_store_path: Path to the file store.
|
||||
|
||||
@@ -25,6 +25,7 @@ from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.kubernetes_config import KubernetesConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.model_routing_config import ModelRoutingConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
@@ -225,6 +226,35 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
|
||||
# Re-raise ValueError from SecurityConfig.from_toml_section
|
||||
raise ValueError('Error in [security] section in config.toml')
|
||||
|
||||
if 'model_routing' in toml_config:
|
||||
try:
|
||||
model_routing_mapping = ModelRoutingConfig.from_toml_section(
|
||||
toml_config['model_routing']
|
||||
)
|
||||
# We only use the base model routing config for now
|
||||
if 'model_routing' in model_routing_mapping:
|
||||
default_agent_config = cfg.get_agent_config()
|
||||
default_agent_config.model_routing = model_routing_mapping[
|
||||
'model_routing'
|
||||
]
|
||||
|
||||
# Construct the llms_for_routing by filtering llms with for_routing = True
|
||||
llms_for_routing_dict = {}
|
||||
for llm_name, llm_config in cfg.llms.items():
|
||||
if llm_config and llm_config.for_routing:
|
||||
llms_for_routing_dict[llm_name] = llm_config
|
||||
default_agent_config.model_routing.llms_for_routing = (
|
||||
llms_for_routing_dict
|
||||
)
|
||||
|
||||
logger.openhands_logger.debug(
|
||||
'Default model routing configuration loaded from config toml and assigned to default agent'
|
||||
)
|
||||
except (TypeError, KeyError, ValidationError) as e:
|
||||
logger.openhands_logger.warning(
|
||||
f'Cannot parse [model_routing] config from toml, values have not been applied.\nError: {e}'
|
||||
)
|
||||
|
||||
# Process sandbox section if present
|
||||
if 'sandbox' in toml_config:
|
||||
try:
|
||||
@@ -327,6 +357,7 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None
|
||||
'condenser',
|
||||
'mcp',
|
||||
'kubernetes',
|
||||
'model_routing',
|
||||
}
|
||||
for key in toml_config:
|
||||
if key.lower() not in known_sections:
|
||||
@@ -559,6 +590,41 @@ def get_llm_config_arg(
|
||||
return None
|
||||
|
||||
|
||||
def get_llms_for_routing_config(toml_file: str = 'config.toml') -> dict[str, LLMConfig]:
    """Collect the LLM configs flagged for routing in a toml config file.

    Only the entries of the `llm` section whose `for_routing` flag is truthy
    are returned.

    Args:
        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.

    Returns:
        dict[str, LLMConfig]: The routing LLM configs keyed by config name. Empty
        when the file is missing, cannot be decoded (an error is logged in that
        case), or contains no routing LLMs.
    """
    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            parsed_toml = toml.load(toml_contents)
    except FileNotFoundError:
        # A missing config file simply means there is nothing to route to.
        return {}
    except toml.TomlDecodeError as e:
        logger.openhands_logger.error(
            f'Cannot parse LLM configs from {toml_file}. Exception: {e}'
        )
        return {}

    llm_section = parsed_toml.get('llm', {})
    all_llm_configs = LLMConfig.from_toml_section(llm_section)
    if not all_llm_configs:
        return {}

    return {
        name: candidate
        for name, candidate in all_llm_configs.items()
        if candidate.for_routing
    }
|
||||
|
||||
|
||||
def get_condenser_config_arg(
|
||||
condenser_config_arg: str, toml_file: str = 'config.toml'
|
||||
) -> CondenserConfig | None:
|
||||
@@ -671,6 +737,50 @@ def get_condenser_config_arg(
|
||||
return None
|
||||
|
||||
|
||||
def get_model_routing_config_arg(toml_file: str = 'config.toml') -> ModelRoutingConfig:
    """Get the model routing settings from the config file. We only support the default model routing config [model_routing].

    Every failure mode (missing file, undecodable toml, missing section,
    malformed or invalid section) is logged and answered with a default
    ModelRoutingConfig; the failures handled here are not propagated to the
    caller.

    Args:
        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.

    Returns:
        ModelRoutingConfig: The ModelRoutingConfig object with the settings from the config file, or the object with default values if not found/error.
    """
    logger.openhands_logger.debug(
        f"Loading model routing config ['model_routing'] from {toml_file}"
    )
    default_cfg = ModelRoutingConfig()

    # load the toml file
    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            toml_config = toml.load(toml_contents)
    except FileNotFoundError as e:
        logger.openhands_logger.error(f'Config file not found: {toml_file}. Error: {e}')
        return default_cfg
    except toml.TomlDecodeError as e:
        logger.openhands_logger.error(
            f'Cannot parse model routing group [model_routing] from {toml_file}. Exception: {e}'
        )
        return default_cfg

    if 'model_routing' not in toml_config:
        logger.openhands_logger.warning(
            f'Model routing config section [model_routing] not found in {toml_file}'
        )
        return default_cfg

    # Update the model routing config with the specified section
    model_routing_data = toml_config['model_routing']
    try:
        return ModelRoutingConfig(**model_routing_data)
    except (TypeError, ValidationError) as e:
        # TypeError covers a [model_routing] value that is not a table (e.g. a
        # scalar `model_routing = "x"`), which cannot be unpacked with `**`;
        # ValidationError covers a table with invalid or unknown keys.
        logger.openhands_logger.error(
            f'Invalid model routing configuration for [model_routing]: {e}'
        )
        return default_cfg
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""Parse command line arguments."""
|
||||
parser = get_headless_parser()
|
||||
|
||||
Reference in New Issue
Block a user