diff --git a/config.template.toml b/config.template.toml
index 23d3775f9a..542a3c2e71 100644
--- a/config.template.toml
+++ b/config.template.toml
@@ -219,6 +219,14 @@ correct_num = 5
 api_key = ""
 model = "gpt-4o"
 
+# Example routing LLM configuration for multimodal model routing
+# Uncomment and configure to enable model routing with a secondary model
+#[llm.secondary_model]
+#model = "kimi-k2"
+#api_key = ""
+#for_routing = true
+#max_input_tokens = 128000
+
 #################################### Agent ###################################
 # Configuration for agents (group name starts with 'agent')
 ##############################################################################
@@ -480,3 +488,14 @@ type = "noop"
 
 # Run the runtime sandbox container in privileged mode for use with docker-in-docker
 #privileged = false
+
+#################################### Model Routing ############################
+# Configuration for the experimental model routing feature
+# Enables switching between different LLMs for specific purposes
+##############################################################################
+[model_routing]
+# Router to use for model selection
+# Available options:
+# - "noop_router" (default): no routing, always uses the primary LLM
+# - "multimodal_router": switches between the primary and secondary models depending on whether the input is multimodal
+#router_name = "noop_router"
diff --git a/evaluation/benchmarks/gaia/run_infer.py b/evaluation/benchmarks/gaia/run_infer.py
index 480df59bd2..d8aee7ec34 100644
--- a/evaluation/benchmarks/gaia/run_infer.py
+++ b/evaluation/benchmarks/gaia/run_infer.py
@@ -28,6 +28,7 @@ from evaluation.utils.shared import (
     prepare_dataset,
     reset_logger_for_multiprocessing,
     run_evaluation,
+    update_llm_config_for_completions_logging,
 )
 from openhands.controller.state.state import State
 from openhands.core.config import (
@@ -36,7 +37,11 @@ from openhands.core.config import (
     get_llm_config_arg,
     load_from_toml,
 )
-from openhands.core.config.utils import get_agent_config_arg
+from openhands.core.config.utils import (
+    get_agent_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
+)
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
@@ -57,6 +62,7 @@ AGENT_CLS_TO_INST_SUFFIX = {
 
 
 def get_config(
+    instance: pd.Series,
     metadata: EvalMetadata,
 ) -> OpenHandsConfig:
     sandbox_config = get_default_sandbox_config_for_eval()
@@ -66,13 +72,24 @@ def get_config(
         sandbox_config=sandbox_config,
         runtime='docker',
     )
-    config.set_llm_config(metadata.llm_config)
+    config.set_llm_config(
+        update_llm_config_for_completions_logging(
+            metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
+        )
+    )
+    model_routing_config = get_model_routing_config_arg()
+    model_routing_config.llms_for_routing = (
+        get_llms_for_routing_config()
+    )  # Populate with LLMs for routing from the config.toml file
+
     if metadata.agent_config:
+        metadata.agent_config.model_routing = model_routing_config
         config.set_agent_config(metadata.agent_config, metadata.agent_class)
     else:
         logger.info('Agent config not provided, using default settings')
         agent_config = config.get_agent_config(metadata.agent_class)
         agent_config.enable_prompt_extensions = False
+        agent_config.model_routing = model_routing_config
 
     config_copy = copy.deepcopy(config)
     load_from_toml(config_copy)
@@ -145,7 +162,7 @@ def process_instance(
     metadata: EvalMetadata,
     reset_logger: bool = True,
 ) -> EvalOutput:
-    config = get_config(metadata)
+    config = get_config(instance, metadata)
 
     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:
diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py
index c3e9764152..2b86cc3baa 100644
--- a/evaluation/benchmarks/swe_bench/run_infer.py
+++ b/evaluation/benchmarks/swe_bench/run_infer.py
@@ -47,6 +47,8 @@ from openhands.core.config import (
     get_agent_config_arg,
     get_evaluation_parser,
     get_llm_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
 )
 from openhands.core.config.condenser_config import NoOpCondenserConfig
 from openhands.core.config.utils import get_condenser_config_arg
@@ -244,6 +246,11 @@ def get_config(
     # get 'draft_editor' config if exists
     config.set_llm_config(get_llm_config_arg('draft_editor'), 'draft_editor')
 
+    model_routing_config = get_model_routing_config_arg()
+    model_routing_config.llms_for_routing = (
+        get_llms_for_routing_config()
+    )  # Populate with LLMs for routing from the config.toml file
+
     agent_config = AgentConfig(
         enable_jupyter=False,
         enable_browsing=RUN_WITH_BROWSING,
@@ -251,8 +258,10 @@ def get_config(
         enable_mcp=False,
         condenser=metadata.condenser_config,
         enable_prompt_extensions=False,
+        model_routing=model_routing_config,
     )
     config.set_agent_config(agent_config)
+
     return config
diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py
index 1f4686c0f0..f73ff2b601 100644
--- a/openhands/agenthub/codeact_agent/codeact_agent.py
+++ b/openhands/agenthub/codeact_agent/codeact_agent.py
@@ -92,6 +92,9 @@ class CodeActAgent(Agent):
         self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
         logger.debug(f'Using condenser: {type(self.condenser)}')
 
+        # Replace the agent's LLM with a router when model routing is enabled
+        self.llm = self.llm_registry.get_router(self.config)
+
     @property
     def prompt_manager(self) -> PromptManager:
         if self._prompt_manager is None:
diff --git a/openhands/core/config/__init__.py b/openhands/core/config/__init__.py
index 97d71bd5f1..df7f745bef 100644
--- a/openhands/core/config/__init__.py
+++ b/openhands/core/config/__init__.py
@@ -13,6 +13,7 @@ from openhands.core.config.config_utils import (
 from openhands.core.config.extended_config import ExtendedConfig
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.mcp_config import MCPConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
@@ -20,6 +21,8 @@ from openhands.core.config.utils import (
     finalize_config,
     get_agent_config_arg,
     get_llm_config_arg,
+    get_llms_for_routing_config,
+    get_model_routing_config_arg,
     load_from_env,
     load_from_toml,
     load_openhands_config,
@@ -37,6 +40,7 @@ __all__ = [
     'LLMConfig',
     'SandboxConfig',
     'SecurityConfig',
+    'ModelRoutingConfig',
     'ExtendedConfig',
     'load_openhands_config',
     'load_from_env',
@@ -50,4 +54,6 @@ __all__ = [
     'get_evaluation_parser',
     'parse_arguments',
     'setup_config_from_args',
+    'get_model_routing_config_arg',
+    'get_llms_for_routing_config',
 ]
diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py
index 12f64bc9f3..fac57ce1fb 100644
--- a/openhands/core/config/agent_config.py
+++ b/openhands/core/config/agent_config.py
@@ -7,6 +7,7 @@ from openhands.core.config.condenser_config import (
     ConversationWindowCondenserConfig,
 )
 from openhands.core.config.extended_config import ExtendedConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.utils.import_utils import get_impl
 
@@ -57,6 +58,8 @@ class AgentConfig(BaseModel):
         # handled.
         default_factory=lambda: ConversationWindowCondenserConfig()
     )
+    model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
+    """Model routing configuration settings."""
     extended: ExtendedConfig = Field(default_factory=lambda: ExtendedConfig({}))
     """Extended configuration for the agent."""
 
diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index 40d27ea78b..f8f1bce726 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -46,6 +46,7 @@ class LLMConfig(BaseModel):
         reasoning_effort: The effort to put into reasoning. This is a string that can be one of 'low', 'medium', 'high', or 'none'. Can apply to all reasoning models.
         seed: The seed to use for the LLM.
         safety_settings: Safety settings for models that support them (like Mistral AI and Gemini).
+        for_routing: Whether this LLM is used for routing. Set to True for models used alongside the main LLM by the model routing feature.
     """
 
     model: str = Field(default='claude-sonnet-4-20250514')
@@ -92,6 +93,7 @@ class LLMConfig(BaseModel):
         default=None,
         description='Safety settings for models that support them (like Mistral AI and Gemini)',
     )
+    for_routing: bool = Field(default=False)
 
     model_config = ConfigDict(extra='forbid')
diff --git a/openhands/core/config/model_routing_config.py b/openhands/core/config/model_routing_config.py
new file mode 100644
index 0000000000..9377a5b097
--- /dev/null
+++ b/openhands/core/config/model_routing_config.py
@@ -0,0 +1,43 @@
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+
+from openhands.core.config.llm_config import LLMConfig
+
+
+class ModelRoutingConfig(BaseModel):
+    """Configuration for model routing.
+
+    Attributes:
+        router_name (str): The name of the router to use. Default is 'noop_router'.
+        llms_for_routing (dict[str, LLMConfig]): A dictionary mapping config names of LLMs for routing to their configurations.
+    """
+
+    router_name: str = Field(default='noop_router')
+    llms_for_routing: dict[str, LLMConfig] = Field(default_factory=dict)
+
+    model_config = ConfigDict(extra='forbid')
+
+    @classmethod
+    def from_toml_section(cls, data: dict) -> dict[str, 'ModelRoutingConfig']:
+        """
+        Create a mapping of ModelRoutingConfig instances from a toml dictionary representing the [model_routing] section.
+
+        The configuration is built from all keys in data.
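+
+        Example:
+            >>> ModelRoutingConfig.from_toml_section({'router_name': 'multimodal_router'})
+            {'model_routing': ModelRoutingConfig(router_name='multimodal_router', llms_for_routing={})}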
+
+        Returns:
+            dict[str, ModelRoutingConfig]: A mapping where the key "model_routing" corresponds to the [model_routing] configuration.
+        """
+
+        # Initialize the result mapping
+        model_routing_mapping: dict[str, ModelRoutingConfig] = {}
+
+        # Try to create the configuration instance
+        try:
+            model_routing_mapping['model_routing'] = cls.model_validate(data)
+        except ValidationError as e:
+            raise ValueError(f'Invalid model routing configuration: {e}') from e
+
+        return model_routing_mapping
diff --git a/openhands/core/config/openhands_config.py b/openhands/core/config/openhands_config.py
index a57053f36a..24990bc770 100644
--- a/openhands/core/config/openhands_config.py
+++ b/openhands/core/config/openhands_config.py
@@ -30,6 +30,7 @@ class OpenHandsConfig(BaseModel):
         The default configuration is stored under the 'agent' key.
         default_agent: Name of the default agent to use.
         sandbox: Sandbox configuration settings.
+        security: Security configuration settings.
         runtime: Runtime environment identifier.
         file_store: Type of file store to use.
         file_store_path: Path to the file store.
diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py
index cf95e75d7d..59ded7d598 100644
--- a/openhands/core/config/utils.py
+++ b/openhands/core/config/utils.py
@@ -25,6 +25,7 @@ from openhands.core.config.extended_config import ExtendedConfig
 from openhands.core.config.kubernetes_config import KubernetesConfig
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.mcp_config import MCPConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
@@ -225,6 +226,35 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None:
         # Re-raise ValueError from SecurityConfig.from_toml_section
         raise ValueError('Error in [security] section in config.toml')
 
+    if 'model_routing' in toml_config:
+        try:
+            model_routing_mapping = ModelRoutingConfig.from_toml_section(
+                toml_config['model_routing']
+            )
+            # We only use the base model routing config for now
+            if 'model_routing' in model_routing_mapping:
+                default_agent_config = cfg.get_agent_config()
+                default_agent_config.model_routing = model_routing_mapping[
+                    'model_routing'
+                ]
+
+                # Construct llms_for_routing by keeping only LLMs with for_routing = True
+                llms_for_routing_dict = {}
+                for llm_name, llm_config in cfg.llms.items():
+                    if llm_config and llm_config.for_routing:
+                        llms_for_routing_dict[llm_name] = llm_config
+                default_agent_config.model_routing.llms_for_routing = (
+                    llms_for_routing_dict
+                )
+
+                logger.openhands_logger.debug(
+                    'Default model routing configuration loaded from config toml and assigned to default agent'
+                )
+        except (TypeError, KeyError, ValidationError) as e:
+            logger.openhands_logger.warning(
+                f'Cannot parse [model_routing] config from toml, values have not been applied.\nError: {e}'
+            )
+
     # Process sandbox section if present
     if 'sandbox' in toml_config:
         try:
@@ -327,6 +357,7 @@ def load_from_toml(cfg: OpenHandsConfig, toml_file: str = 'config.toml') -> None:
         'condenser',
         'mcp',
         'kubernetes',
+        'model_routing',
     }
     for key in toml_config:
         if key.lower() not in known_sections:
@@ -559,6 +590,41 @@ def get_llm_config_arg(
     return None
 
 
+def get_llms_for_routing_config(toml_file: str = 'config.toml') -> dict[str, LLMConfig]:
+    """Get the LLMs that are configured for routing from the config file.
+
+    This function returns a dictionary of LLMConfig objects that are configured
+    for routing, i.e., those with `for_routing` set to True.
+
+    Args:
+        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
+
+    Returns:
+        dict[str, LLMConfig]: A dictionary of LLMConfig objects for routing.
+    """
+    llms_for_routing: dict[str, LLMConfig] = {}
+
+    try:
+        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
+            toml_config = toml.load(toml_contents)
+    except FileNotFoundError:
+        return llms_for_routing
+    except toml.TomlDecodeError as e:
+        logger.openhands_logger.error(
+            f'Cannot parse LLM configs from {toml_file}. Exception: {e}'
+        )
+        return llms_for_routing
+
+    llm_configs = LLMConfig.from_toml_section(toml_config.get('llm', {}))
+
+    if llm_configs:
+        for llm_name, llm_config in llm_configs.items():
+            if llm_config.for_routing:
+                llms_for_routing[llm_name] = llm_config
+
+    return llms_for_routing
+
+
 def get_condenser_config_arg(
     condenser_config_arg: str, toml_file: str = 'config.toml'
 ) -> CondenserConfig | None:
@@ -671,6 +737,50 @@ def get_condenser_config_arg(
     return None
 
 
+def get_model_routing_config_arg(toml_file: str = 'config.toml') -> ModelRoutingConfig:
+    """Get the model routing settings from the config file. Only the default [model_routing] section is supported.
+
+    Args:
+        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
+
+    Returns:
+        ModelRoutingConfig: The settings read from the config file, or an object with default values if the section is missing or invalid.
+    """
+    logger.openhands_logger.debug(
+        f"Loading model routing config ['model_routing'] from {toml_file}"
+    )
+    default_cfg = ModelRoutingConfig()
+
+    # Load the toml file
+    try:
+        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
+            toml_config = toml.load(toml_contents)
+    except FileNotFoundError as e:
+        logger.openhands_logger.error(f'Config file not found: {toml_file}. Error: {e}')
+        return default_cfg
+    except toml.TomlDecodeError as e:
+        logger.openhands_logger.error(
+            f'Cannot parse model routing group [model_routing] from {toml_file}. Exception: {e}'
+        )
+        return default_cfg
+
+    # Update the model routing config with the specified section
+    if 'model_routing' in toml_config:
+        try:
+            model_routing_data = toml_config['model_routing']
+            return ModelRoutingConfig(**model_routing_data)
+        except ValidationError as e:
+            logger.openhands_logger.error(
+                f'Invalid model routing configuration for [model_routing]: {e}'
+            )
+            return default_cfg
+
+    logger.openhands_logger.warning(
+        f'Model routing config section [model_routing] not found in {toml_file}'
+    )
+    return default_cfg
+
+
 def parse_arguments() -> argparse.Namespace:
     """Parse command line arguments."""
     parser = get_headless_parser()
diff --git a/openhands/llm/llm_registry.py b/openhands/llm/llm_registry.py
index f329ae7f1a..941f80b9b3 100644
--- a/openhands/llm/llm_registry.py
+++ b/openhands/llm/llm_registry.py
@@ -108,10 +108,24 @@ class LLMRegistry:
     def get_active_llm(self) -> LLM:
         return self.active_agent_llm
 
-    def _set_active_llm(self, service_id) -> None:
-        if service_id not in self.service_to_llm:
-            raise ValueError(f'Unrecognized service ID: {service_id}')
-        self.active_agent_llm = self.service_to_llm[service_id]
+    def get_router(self, agent_config: AgentConfig) -> 'LLM':
+        """
+        Return the agent's LLM, wrapped in a RouterLLM when model routing is enabled.
+        """
+        # Import here to avoid circular imports
+        from openhands.llm.router import RouterLLM
+
+        router_name = agent_config.model_routing.router_name
+
+        if router_name == 'noop_router':
+            # Return the main LLM directly (no routing)
+            return self.get_llm_from_agent_config('agent', agent_config)
+
+        return RouterLLM.from_config(
+            agent_config=agent_config,
+            llm_registry=self,
+            retry_listener=self.retry_listner,
+        )
 
     def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
         self.subscriber = callback
+ """ + # Import here to avoid circular imports + from openhands.llm.router import RouterLLM + + router_name = agent_config.model_routing.router_name + + if router_name == 'noop_router': + # Return the main LLM directly (no routing) + return self.get_llm_from_agent_config('agent', agent_config) + + return RouterLLM.from_config( + agent_config=agent_config, + llm_registry=self, + retry_listener=self.retry_listner, + ) def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None: self.subscriber = callback diff --git a/openhands/llm/router/README.md b/openhands/llm/router/README.md new file mode 100644 index 0000000000..1ed6223612 --- /dev/null +++ b/openhands/llm/router/README.md @@ -0,0 +1,39 @@ +# Model Routing Module + +**⚠️ Experimental Feature**: This module is experimental and under active development. + +## Overview + +Model routing enables OpenHands to switch between different LLM models during a conversation. An example use case is routing between a primary (expensive, multimodal) model and a secondary (cheaper, text-only) model. + +## Available Routers + +- **`noop_router`** (default): No routing, always uses primary LLM +- **`multimodal_router`**: A router that switches based on: + - Routes to primary model for images or when secondary model's context limit is exceeded + - Uses secondary model for text-only requests within its context limit + +## Configuration + +Add to your `config.toml`: + +```toml +# Main LLM (primary model) +[llm] +model = "claude-sonnet-4" +api_key = "your-api-key" + +# Secondary model for routing +[llm.secondary_model] +model = "kimi-k2" +api_key = "your-api-key" +for_routing = true + +# Enable routing +[model_routing] +router_name = "multimodal_router" +``` + +## Extending + +Create custom routers by inheriting from `BaseRouter` and implementing `set_active_llm()`. Register in `ROUTER_REGISTRY`. diff --git a/openhands/llm/router/__init__.py b/openhands/llm/router/__init__.py new file mode 100644 index 0000000000..877bbbb477 --- /dev/null +++ b/openhands/llm/router/__init__.py @@ -0,0 +1,8 @@ +from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM +from openhands.llm.router.rule_based.impl import MultimodalRouter + +__all__ = [ + 'RouterLLM', + 'ROUTER_LLM_REGISTRY', + 'MultimodalRouter', +] diff --git a/openhands/llm/router/base.py b/openhands/llm/router/base.py new file mode 100644 index 0000000000..84143b3785 --- /dev/null +++ b/openhands/llm/router/base.py @@ -0,0 +1,164 @@ +import copy +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Callable + +from openhands.core.config import AgentConfig +from openhands.core.logger import openhands_logger as logger +from openhands.core.message import Message +from openhands.llm.llm import LLM +from openhands.llm.metrics import Metrics + +if TYPE_CHECKING: + from openhands.llm.llm_registry import LLMRegistry + +ROUTER_LLM_REGISTRY: dict[str, type['RouterLLM']] = {} + + +class RouterLLM(LLM): + """ + Base class for multiple LLM acting as a unified LLM. + + This class provides a foundation for implementing model routing by inheriting from LLM, + allowing routers to work with multiple underlying LLM models while presenting a unified + LLM interface to consumers. 
+ + Key features: + - Works with multiple LLMs configured via llms_for_routing + - Delegates all other operations/properties to the selected LLM + - Provides routing interface through _select_llm() method + """ + + def __init__( + self, + agent_config: AgentConfig, + llm_registry: 'LLMRegistry', + service_id: str = 'router_llm', + metrics: Metrics | None = None, + retry_listener: Callable[[int, int], None] | None = None, + ): + """ + Initialize RouterLLM with multiple LLM support. + """ + self.llm_registry = llm_registry + self.model_routing_config = agent_config.model_routing + + # Get the primary agent LLM + self.primary_llm = llm_registry.get_llm_from_agent_config('agent', agent_config) + + # Instantiate all the LLM instances for routing + llms_for_routing_config = self.model_routing_config.llms_for_routing + self.llms_for_routing = { + config_name: self.llm_registry.get_llm( + f'llm_for_routing.{config_name}', config=llm_config + ) + for config_name, llm_config in llms_for_routing_config.items() + } + + # All available LLMs for routing (set this BEFORE calling super().__init__) + self.available_llms = {'primary': self.primary_llm, **self.llms_for_routing} + + # Create router config based on primary LLM + router_config = copy.deepcopy(self.primary_llm.config) + + # Update model name to indicate this is a router + llm_names = [self.primary_llm.config.model] + if self.model_routing_config.llms_for_routing: + llm_names.extend( + config.model + for config in self.model_routing_config.llms_for_routing.values() + ) + router_config.model = f'router({",".join(llm_names)})' + + # Initialize parent LLM class + super().__init__( + config=router_config, + service_id=service_id, + metrics=metrics, + retry_listener=retry_listener, + ) + + # Current LLM state + self._current_llm = self.primary_llm # Default to primary LLM + self._last_routing_decision = 'primary' + + logger.info( + f'RouterLLM initialized with {len(self.available_llms)} LLMs: {list(self.available_llms.keys())}' + ) + + @abstractmethod + def _select_llm(self, messages: list[Message]) -> str: + """ + Select which LLM to use based on messages and events. + """ + pass + + def _get_llm_by_key(self, llm_key: str) -> LLM: + """ + Get LLM instance by key. + """ + if llm_key not in self.available_llms: + raise ValueError( + f'Unknown LLM key: {llm_key}. Available: {list(self.available_llms.keys())}' + ) + return self.available_llms[llm_key] + + @property + def completion(self) -> Callable: + """ + Override completion to route to appropriate LLM. + + This method intercepts completion calls and routes them to the appropriate + underlying LLM based on the routing logic implemented in _select_llm(). 
+ """ + + def router_completion(*args: Any, **kwargs: Any) -> Any: + # Extract messages for routing decision + messages = kwargs.get('messages', []) + if args and not messages: + messages = args[0] if args else [] + + # Select appropriate LLM + selected_llm_key = self._select_llm(messages) + selected_llm = self._get_llm_by_key(selected_llm_key) + + # Update current state + self._current_llm = selected_llm + self._last_routing_decision = selected_llm_key + + logger.debug( + f'RouterLLM routing to {selected_llm_key} ({selected_llm.config.model})' + ) + + # Delegate to selected LLM + return selected_llm.completion(*args, **kwargs) + + return router_completion + + def __str__(self) -> str: + """String representation of the router.""" + return f'{self.__class__.__name__}(llms={list(self.available_llms.keys())})' + + def __repr__(self) -> str: + """Detailed string representation of the router.""" + return ( + f'{self.__class__.__name__}(' + f'primary={self.primary_llm.config.model}, ' + f'routing_llms={[llm.config.model for llm in self.llms_for_routing.values()]}, ' + f'current={self._last_routing_decision})' + ) + + def __getattr__(self, name): + """Delegate other attributes/methods to the active LLM.""" + return getattr(self._current_llm, name) + + @classmethod + def from_config( + cls, llm_registry: 'LLMRegistry', agent_config: AgentConfig, **kwargs + ) -> 'RouterLLM': + """Factory method to create a RouterLLM instance from configuration.""" + router_cls = ROUTER_LLM_REGISTRY.get(agent_config.model_routing.router_name) + if not router_cls: + raise ValueError( + f'Router LLM {agent_config.model_routing.router_name} not found.' + ) + return router_cls(agent_config, llm_registry, **kwargs) diff --git a/openhands/llm/router/rule_based/impl.py b/openhands/llm/router/rule_based/impl.py new file mode 100644 index 0000000000..81cb5bb3a4 --- /dev/null +++ b/openhands/llm/router/rule_based/impl.py @@ -0,0 +1,74 @@ +from openhands.core.config import AgentConfig +from openhands.core.logger import openhands_logger as logger +from openhands.core.message import Message +from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry +from openhands.llm.router.base import ROUTER_LLM_REGISTRY, RouterLLM + + +class MultimodalRouter(RouterLLM): + SECONDARY_MODEL_CONFIG_NAME = 'secondary_model' + ROUTER_NAME = 'multimodal_router' + + def __init__( + self, + agent_config: AgentConfig, + llm_registry: LLMRegistry, + **kwargs, + ): + super().__init__(agent_config, llm_registry, **kwargs) + + self._validate_model_routing_config(self.llms_for_routing) + + # States + self.max_token_exceeded = False + + def _select_llm(self, messages: list[Message]) -> str: + """Select LLM based on multimodal content and token limits.""" + route_to_primary = False + + # Check for multimodal content in messages + for message in messages: + if message.contains_image: + logger.info( + 'Multimodal content detected in messages. Routing to the primary model.' 
+ ) + route_to_primary = True + + if not route_to_primary and self.max_token_exceeded: + route_to_primary = True + + # Check if `messages` exceeds context window of the secondary model + # Assuming the secondary model has a lower context window limit compared to the primary model + secondary_llm = self.available_llms.get(self.SECONDARY_MODEL_CONFIG_NAME) + if secondary_llm and ( + secondary_llm.config.max_input_tokens + and secondary_llm.get_token_count(messages) + > secondary_llm.config.max_input_tokens + ): + logger.warning( + f"Messages having {secondary_llm.get_token_count(messages)} tokens, exceed secondary model's max input tokens ({secondary_llm.config.max_input_tokens} tokens). " + 'Routing to the primary model.' + ) + self.max_token_exceeded = True + route_to_primary = True + + if route_to_primary: + logger.info('Routing to the primary model...') + return 'primary' + else: + logger.info('Routing to the secondary model...') + return self.SECONDARY_MODEL_CONFIG_NAME + + def vision_is_active(self): + return self.primary_llm.vision_is_active() + + def _validate_model_routing_config(self, llms_for_routing: dict[str, LLM]): + if self.SECONDARY_MODEL_CONFIG_NAME not in llms_for_routing: + raise ValueError( + f'Secondary LLM config {self.SECONDARY_MODEL_CONFIG_NAME} not found.' + ) + + +# Register the router +ROUTER_LLM_REGISTRY[MultimodalRouter.ROUTER_NAME] = MultimodalRouter