From 25d9cf28905c2f5ece4e2a5fb6256952b27ad620 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Mon, 18 Aug 2025 02:11:20 -0400 Subject: [PATCH] [Refactor]: Add LLMRegistry for llm services (#9589) Co-authored-by: openhands Co-authored-by: Graham Neubig Co-authored-by: Engel Nyst --- .../agenthub/browsing_agent/browsing_agent.py | 6 +- .../agenthub/codeact_agent/codeact_agent.py | 14 +- openhands/agenthub/dummy_agent/agent.py | 10 +- openhands/agenthub/loc_agent/loc_agent.py | 7 +- .../agenthub/readonly_agent/readonly_agent.py | 8 +- .../visualbrowsing_agent.py | 10 +- openhands/cli/main.py | 12 +- openhands/controller/agent.py | 8 +- openhands/controller/agent_controller.py | 66 +-- openhands/controller/state/state.py | 12 +- openhands/controller/state/state_tracker.py | 39 +- openhands/core/config/openhands_config.py | 7 +- openhands/core/main.py | 18 +- openhands/core/setup.py | 18 +- openhands/llm/llm.py | 11 +- openhands/llm/llm_registry.py | 132 +++++ openhands/memory/condenser/condenser.py | 7 +- .../impl/amortized_forgetting_condenser.py | 5 +- .../impl/browser_output_condenser.py | 3 +- .../impl/conversation_window_condenser.py | 5 +- .../condenser/impl/llm_attention_condenser.py | 16 +- .../impl/llm_summarizing_condenser.py | 8 +- .../memory/condenser/impl/no_op_condenser.py | 5 +- .../impl/observation_masking_condenser.py | 5 +- openhands/memory/condenser/impl/pipeline.py | 7 +- .../condenser/impl/recent_events_condenser.py | 5 +- .../impl/structured_summary_condenser.py | 17 +- .../resolver/interfaces/issue_definitions.py | 2 +- openhands/resolver/issue_resolver.py | 4 +- openhands/resolver/send_pull_request.py | 2 +- openhands/runtime/base.py | 6 +- .../action_execution_client.py | 3 + openhands/runtime/impl/cli/cli_runtime.py | 3 + .../runtime/impl/docker/docker_runtime.py | 3 + .../impl/kubernetes/kubernetes_runtime.py | 3 + openhands/runtime/impl/local/local_runtime.py | 11 +- .../runtime/impl/remote/remote_runtime.py | 3 + openhands/runtime/utils/edit.py | 15 +- .../conversation_manager.py | 11 + .../docker_nested_conversation_manager.py | 30 +- .../standalone_conversation_manager.py | 34 +- .../server/conversation_manager/utils.py | 0 .../server/routes/manage_conversations.py | 14 +- .../server/services/conversation_service.py | 144 +++-- .../server/services/conversation_stats.py | 77 +++ openhands/server/session/agent_session.py | 28 +- openhands/server/session/conversation.py | 2 + openhands/server/session/session.py | 45 +- openhands/storage/locations.py | 8 + openhands/utils/conversation_summary.py | 22 +- openhands/utils/utils.py | 37 ++ pytest.ini | 1 + tests/__init__.py | 0 tests/runtime/conftest.py | 5 + tests/unit/__init__.py | 0 tests/unit/llm/test_acompletion.py | 2 +- tests/unit/llm/test_llm.py | 92 ++-- .../test_issue_handler_error_handling.py | 4 +- .../resolver/github/test_resolve_issues.py | 2 +- ...est_gitlab_issue_handler_error_handling.py | 4 +- .../gitlab/test_gitlab_resolve_issues.py | 2 +- tests/unit/test_agent_controller.py | 432 ++++++++++----- tests/unit/test_agent_delegation.py | 118 ++++- tests/unit/test_agent_session.py | 186 ++++--- tests/unit/test_agents.py | 62 ++- tests/unit/test_api_connection_error_retry.py | 4 +- tests/unit/test_auto_generate_title.py | 106 ++-- tests/unit/test_cli.py | 9 +- .../test_cli_openhands_provider_auth_error.py | 9 +- tests/unit/test_cli_runtime_mcp.py | 7 +- tests/unit/test_cli_workspace.py | 11 +- tests/unit/test_condenser.py | 216 +++++--- tests/unit/test_conversation_stats.py | 490 ++++++++++++++++++ tests/unit/test_conversation_summary.py | 74 ++- tests/unit/test_docker_runtime.py | 14 +- tests/unit/test_llm_registry.py | 178 +++++++ tests/unit/test_mcp_config.py | 4 + tests/unit/test_mcp_tool_timeout_stall.py | 28 +- tests/unit/test_memory.py | 40 +- tests/unit/test_prompt_caching.py | 26 +- tests/unit/test_runtime_git_tokens.py | 38 +- tests/unit/test_runtime_gitlab_microagents.py | 15 +- tests/unit/test_security.py | 4 +- tests/unit/test_session.py | 42 +- 84 files changed, 2376 insertions(+), 817 deletions(-) create mode 100644 openhands/llm/llm_registry.py create mode 100644 openhands/server/conversation_manager/utils.py create mode 100644 openhands/server/services/conversation_stats.py create mode 100644 openhands/utils/utils.py create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_conversation_stats.py create mode 100644 tests/unit/test_llm_registry.py diff --git a/openhands/agenthub/browsing_agent/browsing_agent.py b/openhands/agenthub/browsing_agent/browsing_agent.py index 9fb4441e19..bf2c0960a0 100644 --- a/openhands/agenthub/browsing_agent/browsing_agent.py +++ b/openhands/agenthub/browsing_agent/browsing_agent.py @@ -18,7 +18,7 @@ from openhands.events.action import ( from openhands.events.event import EventSource from openhands.events.observation import BrowserOutputObservation from openhands.events.observation.observation import Observation -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.plugins import ( PluginRequirement, ) @@ -102,15 +102,15 @@ class BrowsingAgent(Agent): def __init__( self, - llm: LLM, config: AgentConfig, + llm_registry: LLMRegistry, ) -> None: """Initializes a new instance of the BrowsingAgent class. Parameters: - llm (LLM): The llm to be used by this agent """ - super().__init__(llm, config) + super().__init__(config, llm_registry) # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML. # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details action_subsets = ['chat', 'bid'] diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index a32cd680da..402d4ba1d1 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -3,6 +3,8 @@ import sys from collections import deque from typing import TYPE_CHECKING +from openhands.llm.llm_registry import LLMRegistry + if TYPE_CHECKING: from litellm import ChatCompletionToolParam @@ -32,7 +34,6 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.message import Message from openhands.events.action import AgentFinishAction, MessageAction from openhands.events.event import Event -from openhands.llm.llm import LLM from openhands.llm.llm_utils import check_tools from openhands.memory.condenser import Condenser from openhands.memory.condenser.condenser import Condensation, View @@ -74,18 +75,13 @@ class CodeActAgent(Agent): JupyterRequirement(), ] - def __init__( - self, - llm: LLM, - config: AgentConfig, - ) -> None: + def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None: """Initializes a new instance of the CodeActAgent class. Parameters: - - llm (LLM): The llm to be used by this agent - config (AgentConfig): The configuration for this agent """ - super().__init__(llm, config) + super().__init__(config, llm_registry) self.pending_actions: deque['Action'] = deque() self.reset() self.tools = self._get_tools() @@ -93,7 +89,7 @@ class CodeActAgent(Agent): # Create a ConversationMemory instance self.conversation_memory = ConversationMemory(self.config, self.prompt_manager) - self.condenser = Condenser.from_config(self.config.condenser) + self.condenser = Condenser.from_config(self.config.condenser, llm_registry) logger.debug(f'Using condenser: {type(self.condenser)}') @property diff --git a/openhands/agenthub/dummy_agent/agent.py b/openhands/agenthub/dummy_agent/agent.py index 0d644a60cd..4d2531e03a 100644 --- a/openhands/agenthub/dummy_agent/agent.py +++ b/openhands/agenthub/dummy_agent/agent.py @@ -22,7 +22,7 @@ from openhands.events.observation import ( Observation, ) from openhands.events.serialization.event import event_to_dict -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry """ FIXME: There are a few problems this surfaced @@ -42,8 +42,12 @@ class DummyAgent(Agent): without making any LLM calls. """ - def __init__(self, llm: LLM, config: AgentConfig): - super().__init__(llm, config) + def __init__( + self, + config: AgentConfig, + llm_registry: LLMRegistry, + ): + super().__init__(config, llm_registry) self.steps: list[ActionObs] = [ { 'action': MessageAction('Time to get started!'), diff --git a/openhands/agenthub/loc_agent/loc_agent.py b/openhands/agenthub/loc_agent/loc_agent.py index 9fbc4c6150..516323d106 100644 --- a/openhands/agenthub/loc_agent/loc_agent.py +++ b/openhands/agenthub/loc_agent/loc_agent.py @@ -4,7 +4,7 @@ import openhands.agenthub.loc_agent.function_calling as locagent_function_callin from openhands.agenthub.codeact_agent import CodeActAgent from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry if TYPE_CHECKING: from openhands.events.action import Action @@ -16,8 +16,8 @@ class LocAgent(CodeActAgent): def __init__( self, - llm: LLM, config: AgentConfig, + llm_registry: LLMRegistry, ) -> None: """Initializes a new instance of the LocAgent class. @@ -25,7 +25,8 @@ class LocAgent(CodeActAgent): - llm (LLM): The llm to be used by this agent - config (AgentConfig): The configuration for the agent """ - super().__init__(llm, config) + + super().__init__(config, llm_registry) self.tools = locagent_function_calling.get_tools() logger.debug( diff --git a/openhands/agenthub/readonly_agent/readonly_agent.py b/openhands/agenthub/readonly_agent/readonly_agent.py index efccdbf43c..2e9a54b416 100644 --- a/openhands/agenthub/readonly_agent/readonly_agent.py +++ b/openhands/agenthub/readonly_agent/readonly_agent.py @@ -3,6 +3,8 @@ import os from typing import TYPE_CHECKING +from openhands.llm.llm_registry import LLMRegistry + if TYPE_CHECKING: from litellm import ChatCompletionToolParam @@ -15,7 +17,6 @@ from openhands.agenthub.readonly_agent import ( ) from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger -from openhands.llm.llm import LLM from openhands.utils.prompt import PromptManager @@ -37,17 +38,16 @@ class ReadOnlyAgent(CodeActAgent): def __init__( self, - llm: LLM, config: AgentConfig, + llm_registry: LLMRegistry, ) -> None: """Initializes a new instance of the ReadOnlyAgent class. Parameters: - - llm (LLM): The llm to be used by this agent - config (AgentConfig): The configuration for this agent """ # Initialize the CodeActAgent class; some of it is overridden with class methods - super().__init__(llm, config) + super().__init__(config, llm_registry) logger.debug( f'TOOLS loaded for ReadOnlyAgent: {", ".join([tool.get("function").get("name") for tool in self.tools])}' diff --git a/openhands/agenthub/visualbrowsing_agent/visualbrowsing_agent.py b/openhands/agenthub/visualbrowsing_agent/visualbrowsing_agent.py index 7ae484f018..322629d3a3 100644 --- a/openhands/agenthub/visualbrowsing_agent/visualbrowsing_agent.py +++ b/openhands/agenthub/visualbrowsing_agent/visualbrowsing_agent.py @@ -16,7 +16,7 @@ from openhands.events.action import ( from openhands.events.event import EventSource from openhands.events.observation import BrowserOutputObservation from openhands.events.observation.observation import Observation -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.plugins import ( PluginRequirement, ) @@ -127,17 +127,13 @@ class VisualBrowsingAgent(Agent): sandbox_plugins: list[PluginRequirement] = [] response_parser = BrowsingResponseParser() - def __init__( - self, - llm: LLM, - config: AgentConfig, - ) -> None: + def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None: """Initializes a new instance of the VisualBrowsingAgent class. Parameters: - llm (LLM): The llm to be used by this agent """ - super().__init__(llm, config) + super().__init__(config, llm_registry) # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML. # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details action_subsets = [ diff --git a/openhands/cli/main.py b/openhands/cli/main.py index 642dd6e05f..0c7cf49621 100644 --- a/openhands/cli/main.py +++ b/openhands/cli/main.py @@ -83,6 +83,7 @@ from openhands.microagent.microagent import BaseMicroagent from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime from openhands.storage.settings.file_settings_store import FileSettingsStore +from openhands.utils.utils import create_registry_and_convo_stats async def cleanup_session( @@ -147,9 +148,16 @@ async def run_session( None, display_initialization_animation, 'Initializing...', is_loaded ) - agent = create_agent(config) + llm_registry, convo_stats, config = create_registry_and_convo_stats( + config, + sid, + None, + ) + + agent = create_agent(config, llm_registry) runtime = create_runtime( config, + llm_registry, sid=sid, headless_mode=True, agent=agent, @@ -161,7 +169,7 @@ async def run_session( runtime.subscribe_to_shell_stream(stream_to_console) - controller, initial_state = create_controller(agent, runtime, config) + controller, initial_state = create_controller(agent, runtime, config, convo_stats) event_stream = runtime.event_stream diff --git a/openhands/controller/agent.py b/openhands/controller/agent.py index 6c49b58b32..da8007b4a0 100644 --- a/openhands/controller/agent.py +++ b/openhands/controller/agent.py @@ -3,6 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING +from openhands.llm.llm_registry import LLMRegistry + if TYPE_CHECKING: from openhands.controller.state.state import State from openhands.events.action import Action @@ -17,7 +19,6 @@ from openhands.core.exceptions import ( ) from openhands.core.logger import openhands_logger as logger from openhands.events.event import EventSource -from openhands.llm.llm import LLM from openhands.runtime.plugins import PluginRequirement @@ -38,10 +39,11 @@ class Agent(ABC): def __init__( self, - llm: LLM, config: AgentConfig, + llm_registry: LLMRegistry, ): - self.llm = llm + self.llm = llm_registry.get_llm_from_agent_config('agent', config) + self.llm_registry = llm_registry self.config = config self._complete = False self._prompt_manager: 'PromptManager' | None = None diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 2952cfcb1c..b6c52708c9 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -73,9 +73,9 @@ from openhands.events.observation import ( Observation, ) from openhands.events.serialization.event import truncate_content -from openhands.llm.llm import LLM from openhands.llm.metrics import Metrics from openhands.runtime.runtime_status import RuntimeStatus +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.files import FileStore # note: RESUME is only available on web GUI @@ -109,6 +109,7 @@ class AgentController: self, agent: Agent, event_stream: EventStream, + convo_stats: ConversationStats, iteration_delta: int, budget_per_task_delta: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, @@ -148,6 +149,7 @@ class AgentController: self.agent = agent self.headless_mode = headless_mode self.is_delegate = is_delegate + self.convo_stats = convo_stats # the event stream must be set before maybe subscribing to it self.event_stream = event_stream @@ -163,6 +165,7 @@ class AgentController: # state from the previous session, state from a parent agent, or a fresh state self.set_initial_state( state=initial_state, + convo_stats=convo_stats, max_iterations=iteration_delta, max_budget_per_task=budget_per_task_delta, confirmation_mode=confirmation_mode, @@ -477,11 +480,6 @@ class AgentController: log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'} ) - # TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics - # In the future, we should have a more principled way to sharing metrics across all LLM instances for a given conversation - if observation.llm_metrics is not None: - self.state_tracker.merge_metrics(observation.llm_metrics) - # this happens for runnable actions and microagent actions if self._pending_action and self._pending_action.id == observation.cause: if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION: @@ -657,14 +655,10 @@ class AgentController: """ agent_cls: type[Agent] = Agent.get_cls(action.agent) agent_config = self.agent_configs.get(action.agent, self.agent.config) - llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config) # Make sure metrics are shared between parent and child for global accumulation - llm = LLM( - config=llm_config, - retry_listener=self.agent.llm.retry_listener, - metrics=self.state.metrics, + delegate_agent = agent_cls( + config=agent_config, llm_registry=self.agent.llm_registry ) - delegate_agent = agent_cls(llm=llm, config=agent_config) # Take a snapshot of the current metrics before starting the delegate state = State( @@ -683,7 +677,7 @@ class AgentController: ) self.log( 'debug', - f'start delegate, creating agent {delegate_agent.name} using LLM {llm}', + f'start delegate, creating agent {delegate_agent.name}', ) # Create the delegate with is_delegate=True so it does NOT subscribe directly @@ -693,6 +687,7 @@ class AgentController: user_id=self.user_id, agent=delegate_agent, event_stream=self.event_stream, + convo_stats=self.convo_stats, iteration_delta=self._initial_max_iterations, budget_per_task_delta=self._initial_max_budget_per_task, agent_to_llm_config=self.agent_to_llm_config, @@ -795,13 +790,8 @@ class AgentController: extra={'msg_type': 'STEP'}, ) - # Ensure budget control flag is synchronized with the latest metrics. - # In the future, we should centralized the use of one LLM object per conversation. - # This will help us unify the cost for auto generating titles, running the condensor, etc. - # Before many microservices will touh the same llm cost field, we should sync with the budget flag for the controller - # and check that we haven't exceeded budget BEFORE executing an agent step. + # Synchronize spend across all llm services with the budget flag self.state_tracker.sync_budget_flag_with_metrics() - if self._is_stuck(): await self._react_to_exception( AgentStuckInLoopError('Agent got stuck in a loop') @@ -961,14 +951,15 @@ class AgentController: def set_initial_state( self, state: State | None, + convo_stats: ConversationStats, max_iterations: int, max_budget_per_task: float | None, confirmation_mode: bool = False, ): self.state_tracker.set_initial_state( self.id, - self.agent, state, + convo_stats, max_iterations, max_budget_per_task, confirmation_mode, @@ -1009,37 +1000,20 @@ class AgentController: action: The action to attach metrics to """ # Get metrics from agent LLM - agent_metrics = self.state.metrics + metrics = self.convo_stats.get_combined_metrics() - # Get metrics from condenser LLM if it exists - condenser_metrics: Metrics | None = None - if hasattr(self.agent, 'condenser') and hasattr(self.agent.condenser, 'llm'): - condenser_metrics = self.agent.condenser.llm.metrics - - # Create a new minimal metrics object with just what the frontend needs - metrics = Metrics(model_name=agent_metrics.model_name) - - # Set accumulated cost (sum of agent and condenser costs) - metrics.accumulated_cost = agent_metrics.accumulated_cost - if condenser_metrics: - metrics.accumulated_cost += condenser_metrics.accumulated_cost + # Create a clean copy with only the fields we want to keep + clean_metrics = Metrics() + clean_metrics.accumulated_cost = metrics.accumulated_cost + clean_metrics._accumulated_token_usage = copy.deepcopy( + metrics.accumulated_token_usage + ) # Add max_budget_per_task to metrics if self.state.budget_flag: - metrics.max_budget_per_task = self.state.budget_flag.max_value + clean_metrics.max_budget_per_task = self.state.budget_flag.max_value - # Set accumulated token usage (sum of agent and condenser token usage) - # Use a deep copy to ensure we don't modify the original object - metrics._accumulated_token_usage = ( - agent_metrics.accumulated_token_usage.model_copy(deep=True) - ) - if condenser_metrics: - metrics._accumulated_token_usage = ( - metrics._accumulated_token_usage - + condenser_metrics.accumulated_token_usage - ) - - action.llm_metrics = metrics + action.llm_metrics = clean_metrics # Log the metrics information for debugging # Get the latest usage directly from the agent's metrics diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index 3af407d896..4f9ca24a1f 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -21,6 +21,7 @@ from openhands.events.action.agent import AgentFinishAction from openhands.events.event import Event, EventSource from openhands.llm.metrics import Metrics from openhands.memory.view import View +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.files import FileStore from openhands.storage.locations import get_conversation_agent_state_filename @@ -84,6 +85,7 @@ class State: limit_increase_amount=100, current_value=0, max_value=100 ) ) + convo_stats: ConversationStats | None = None budget_flag: BudgetControlFlag | None = None confirmation_mode: bool = False history: list[Event] = field(default_factory=list) @@ -91,8 +93,7 @@ class State: outputs: dict = field(default_factory=dict) agent_state: AgentState = AgentState.LOADING resume_state: AgentState | None = None - # global metrics for the current task - metrics: Metrics = field(default_factory=Metrics) + # root agent has level 0, and every delegate increases the level by one delegate_level: int = 0 # start_id and end_id track the range of events in history @@ -116,9 +117,14 @@ class State: local_metrics: Metrics | None = None delegates: dict[tuple[int, int], tuple[str, str]] | None = None + metrics: Metrics = field(default_factory=Metrics) + def save_to_session( self, sid: str, file_store: FileStore, user_id: str | None ) -> None: + convo_stats = self.convo_stats + self.convo_stats = None # Don't save convo stats, handles itself + pickled = pickle.dumps(self) logger.debug(f'Saving state to session {sid}:{self.agent_state}') encoded = base64.b64encode(pickled).decode('utf-8') @@ -138,6 +144,8 @@ class State: logger.error(f'Failed to save state to session: {e}') raise e + self.convo_stats = convo_stats # restore reference + @staticmethod def restore_from_session( sid: str, file_store: FileStore, user_id: str | None = None diff --git a/openhands/controller/state/state_tracker.py b/openhands/controller/state/state_tracker.py index aab0a2b07f..c3aad9effc 100644 --- a/openhands/controller/state/state_tracker.py +++ b/openhands/controller/state/state_tracker.py @@ -1,4 +1,3 @@ -from openhands.controller.agent import Agent from openhands.controller.state.control_flags import ( BudgetControlFlag, IterationControlFlag, @@ -14,7 +13,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation from openhands.events.observation.empty import NullObservation from openhands.events.serialization.event import event_to_trajectory from openhands.events.stream import EventStream -from openhands.llm.metrics import Metrics +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.files import FileStore @@ -51,8 +50,8 @@ class StateTracker: def set_initial_state( self, id: str, - agent: Agent, state: State | None, + convo_stats: ConversationStats, max_iterations: int, max_budget_per_task: float | None, confirmation_mode: bool = False, @@ -75,6 +74,7 @@ class StateTracker: session_id=id.removesuffix('-delegate'), user_id=self.user_id, inputs={}, + convo_stats=convo_stats, iteration_flag=IterationControlFlag( limit_increase_amount=max_iterations, current_value=0, @@ -99,13 +99,7 @@ class StateTracker: if self.state.start_id <= -1: self.state.start_id = 0 - logger.info( - f'AgentController {id} initializing history from event {self.state.start_id}', - ) - - # Share the state metrics with the agent's LLM metrics - # This ensures that all accumulated metrics are always in sync between controller and llm - agent.llm.metrics = self.state.metrics + state.convo_stats = convo_stats def _init_history(self, event_stream: EventStream) -> None: """Initializes the agent's history from the event stream. @@ -254,6 +248,9 @@ class StateTracker: if self.sid and self.file_store: self.state.save_to_session(self.sid, self.file_store, self.user_id) + if self.state.convo_stats: + self.state.convo_stats.save_metrics() + def run_control_flags(self): """Performs one step of the control flags""" self.state.iteration_flag.step() @@ -264,20 +261,8 @@ class StateTracker: """Ensures that budget flag is up to date with accumulated costs from llm completions Budget flag will monitor for when budget is exceeded """ - if self.state.budget_flag: - self.state.budget_flag.current_value = self.state.metrics.accumulated_cost - - def merge_metrics(self, metrics: Metrics): - """Merges metrics with the state metrics - - NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc) - use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from - all services - - This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them - if we decide introduce more specialized services that require llm completions - - """ - self.state.metrics.merge(metrics) - if self.state.budget_flag: - self.state.budget_flag.current_value = self.state.metrics.accumulated_cost + # Sync cost across all llm services from llm registry + if self.state.budget_flag and self.state.convo_stats: + self.state.budget_flag.current_value = ( + self.state.convo_stats.get_combined_metrics().accumulated_cost + ) diff --git a/openhands/core/config/openhands_config.py b/openhands/core/config/openhands_config.py index 792aabb7c0..a57053f36a 100644 --- a/openhands/core/config/openhands_config.py +++ b/openhands/core/config/openhands_config.py @@ -157,13 +157,16 @@ class OpenHandsConfig(BaseModel): """Get a map of agent names to llm configs.""" return {name: self.get_llm_config_from_agent(name) for name in self.agents} - def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig: - agent_config: AgentConfig = self.get_agent_config(name) + def get_llm_config_from_agent_config(self, agent_config: AgentConfig): llm_config_name = ( agent_config.llm_config if agent_config.llm_config is not None else 'llm' ) return self.get_llm_config(llm_config_name) + def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig: + agent_config: AgentConfig = self.get_agent_config(name) + return self.get_llm_config_from_agent_config(agent_config) + def get_agent_configs(self) -> dict[str, AgentConfig]: return self.agents diff --git a/openhands/core/main.py b/openhands/core/main.py index 3adb8b6ce1..25f047c771 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -6,7 +6,6 @@ from typing import Callable, Protocol import openhands.agenthub # noqa F401 (we import this to get the agents registered) import openhands.cli.suppress_warnings # noqa: F401 -from openhands.controller.agent import Agent from openhands.controller.replay import ReplayManager from openhands.controller.state.state import State from openhands.core.config import ( @@ -33,10 +32,12 @@ from openhands.events.action.action import Action from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation from openhands.io import read_input, read_task +from openhands.llm.llm_registry import LLMRegistry from openhands.mcp import add_mcp_tools_to_agent from openhands.memory.memory import Memory from openhands.runtime.base import Runtime from openhands.utils.async_utils import call_async_from_sync +from openhands.utils.utils import create_registry_and_convo_stats class FakeUserResponseFunc(Protocol): @@ -53,12 +54,12 @@ async def run_controller( initial_user_action: Action, sid: str | None = None, runtime: Runtime | None = None, - agent: Agent | None = None, exit_on_message: bool = False, fake_user_response_fn: FakeUserResponseFunc | None = None, headless_mode: bool = True, memory: Memory | None = None, conversation_instructions: str | None = None, + llm_registry: LLMRegistry | None = None, ) -> State | None: """Main coroutine to run the agent controller with task input flexibility. @@ -70,7 +71,6 @@ async def run_controller( sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing. Set it to incompatible value will cause unexpected behavior on RemoteRuntime. runtime: (optional) A runtime for the agent to run on. - agent: (optional) A agent to run. exit_on_message: quit if agent asks for a message from user (optional) fake_user_response_fn: An optional function that receives the current state (could be None) and returns a fake user response. @@ -98,8 +98,13 @@ async def run_controller( """ sid = sid or generate_sid(config) - if agent is None: - agent = create_agent(config) + llm_registry, convo_stats, config = create_registry_and_convo_stats( + config, + sid, + None, + ) + + agent = create_agent(config, llm_registry) # when the runtime is created, it will be connected and clone the selected repository repo_directory = None @@ -108,6 +113,7 @@ async def run_controller( repo_tokens = get_provider_tokens() runtime = create_runtime( config, + llm_registry, sid=sid, headless_mode=headless_mode, agent=agent, @@ -159,7 +165,7 @@ async def run_controller( ) controller, initial_state = create_controller( - agent, runtime, config, replay_events=replay_events + agent, runtime, config, convo_stats, replay_events=replay_events ) assert isinstance(initial_user_action, Action), ( diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 34c066fb95..648101b6ad 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -21,12 +21,13 @@ from openhands.integrations.provider import ( ProviderToken, ProviderType, ) -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.memory import Memory from openhands.microagent.microagent import BaseMicroagent from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage import get_file_store from openhands.storage.data_models.user_secrets import UserSecrets from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync @@ -34,6 +35,7 @@ from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync def create_runtime( config: OpenHandsConfig, + llm_registry: LLMRegistry, sid: str | None = None, headless_mode: bool = True, agent: Agent | None = None, @@ -82,6 +84,7 @@ def create_runtime( sid=session_id, plugins=agent_cls.sandbox_plugins, headless_mode=headless_mode, + llm_registry=llm_registry, git_provider_tokens=git_provider_tokens, ) @@ -203,16 +206,11 @@ def create_memory( return memory -def create_agent(config: OpenHandsConfig) -> Agent: +def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent: agent_cls: type[Agent] = Agent.get_cls(config.default_agent) agent_config = config.get_agent_config(config.default_agent) - llm_config = config.get_llm_config_from_agent(config.default_agent) - - agent = agent_cls( - llm=LLM(config=llm_config), - config=agent_config, - ) - + config.get_llm_config_from_agent(config.default_agent) + agent = agent_cls(config=agent_config, llm_registry=llm_registry) return agent @@ -220,6 +218,7 @@ def create_controller( agent: Agent, runtime: Runtime, config: OpenHandsConfig, + convo_stats: ConversationStats, headless_mode: bool = True, replay_events: list[Event] | None = None, ) -> tuple[AgentController, State | None]: @@ -237,6 +236,7 @@ def create_controller( controller = AgentController( agent=agent, + convo_stats=convo_stats, iteration_delta=config.max_iterations, budget_per_task_delta=config.max_budget_per_task, agent_to_llm_config=config.get_agent_to_llm_config_map(), diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 62b0f94acc..997f82fecb 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -8,6 +8,7 @@ from typing import Any, Callable import httpx from openhands.core.config import LLMConfig +from openhands.llm.metrics import Metrics with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -34,7 +35,6 @@ from openhands.llm.fn_call_converter import ( convert_fncall_messages_to_non_fncall_messages, convert_non_fncall_messages_to_fncall_messages, ) -from openhands.llm.metrics import Metrics from openhands.llm.retry_mixin import RetryMixin __all__ = ['LLM'] @@ -133,6 +133,7 @@ class LLM(RetryMixin, DebugMixin): def __init__( self, config: LLMConfig, + service_id: str, metrics: Metrics | None = None, retry_listener: Callable[[int, int], None] | None = None, ) -> None: @@ -145,11 +146,12 @@ class LLM(RetryMixin, DebugMixin): metrics: The metrics to use. """ self._tried_model_info = False + self.cost_metric_supported: bool = True + self.config: LLMConfig = copy.deepcopy(config) + self.service_id = service_id self.metrics: Metrics = ( metrics if metrics is not None else Metrics(model_name=config.model) ) - self.cost_metric_supported: bool = True - self.config: LLMConfig = copy.deepcopy(config) self.model_info: ModelInfo | None = None self.retry_listener = retry_listener @@ -408,8 +410,7 @@ class LLM(RetryMixin, DebugMixin): assert self.config.log_completions_folder is not None log_file = os.path.join( self.config.log_completions_folder, - # use the metric model name (for draft editor) - f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json', + f'{self.config.model.replace("/", "__")}-{time.time()}.json', ) # set up the dict to be logged diff --git a/openhands/llm/llm_registry.py b/openhands/llm/llm_registry.py new file mode 100644 index 0000000000..f329ae7f1a --- /dev/null +++ b/openhands/llm/llm_registry.py @@ -0,0 +1,132 @@ +import copy +from typing import Any, Callable +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict + +from openhands.core.config.agent_config import AgentConfig +from openhands.core.config.llm_config import LLMConfig +from openhands.core.config.openhands_config import OpenHandsConfig +from openhands.core.logger import openhands_logger as logger +from openhands.llm.llm import LLM + + +class RegistryEvent(BaseModel): + llm: LLM + service_id: str + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +class LLMRegistry: + def __init__( + self, + config: OpenHandsConfig, + agent_cls: str | None = None, + retry_listener: Callable[[int, int], None] | None = None, + ): + self.registry_id = str(uuid4()) + self.config = copy.deepcopy(config) + self.retry_listner = retry_listener + self.agent_to_llm_config = self.config.get_agent_to_llm_config_map() + self.service_to_llm: dict[str, LLM] = {} + self.subscriber: Callable[[Any], None] | None = None + + selected_agent_cls = self.config.default_agent + if agent_cls: + selected_agent_cls = agent_cls + + agent_name = selected_agent_cls if selected_agent_cls is not None else 'agent' + llm_config = self.config.get_llm_config_from_agent(agent_name) + self.active_agent_llm: LLM = self.get_llm('agent', llm_config) + + def _create_new_llm( + self, service_id: str, config: LLMConfig, with_listener: bool = True + ) -> LLM: + if with_listener: + llm = LLM( + service_id=service_id, config=config, retry_listener=self.retry_listner + ) + else: + llm = LLM(service_id=service_id, config=config) + self.service_to_llm[service_id] = llm + self.notify(RegistryEvent(llm=llm, service_id=service_id)) + return llm + + def request_extraneous_completion( + self, service_id: str, llm_config: LLMConfig, messages: list[dict[str, str]] + ) -> str: + logger.info(f'extraneous completion: {service_id}') + if service_id not in self.service_to_llm: + self._create_new_llm( + config=llm_config, service_id=service_id, with_listener=False + ) + + llm = self.service_to_llm[service_id] + response = llm.completion(messages=messages) + return response.choices[0].message.content.strip() + + def get_llm_from_agent_config(self, service_id: str, agent_config: AgentConfig): + llm_config = self.config.get_llm_config_from_agent_config(agent_config) + if service_id in self.service_to_llm: + if self.service_to_llm[service_id].config != llm_config: + # TODO: update llm config internally + # Done when agent delegates has different config, we should reuse the existing LLM + pass + return self.service_to_llm[service_id] + + return self._create_new_llm(config=llm_config, service_id=service_id) + + def get_llm( + self, + service_id: str, + config: LLMConfig | None = None, + ): + logger.info( + f'[LLM registry {self.registry_id}]: Registering service for {service_id}' + ) + + # Attempting to switch configs for existing LLM + if ( + service_id in self.service_to_llm + and self.service_to_llm[service_id].config != config + ): + raise ValueError( + f'Requesting same service ID {service_id} with different config, use a new service ID' + ) + + if service_id in self.service_to_llm: + return self.service_to_llm[service_id] + + if not config: + raise ValueError('Requesting new LLM without specifying LLM config') + + return self._create_new_llm(config=config, service_id=service_id) + + def get_active_llm(self) -> LLM: + return self.active_agent_llm + + def _set_active_llm(self, service_id) -> None: + if service_id not in self.service_to_llm: + raise ValueError(f'Unrecognized service ID: {service_id}') + self.active_agent_llm = self.service_to_llm[service_id] + + def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None: + self.subscriber = callback + + # Subscriptions happen after default llm is initialized + # Notify service of this llm + self.notify( + RegistryEvent( + llm=self.active_agent_llm, service_id=self.active_agent_llm.service_id + ) + ) + + def notify(self, event: RegistryEvent): + if self.subscriber: + try: + self.subscriber(event) + except Exception as e: + logger.warning(f'Failed to emit event: {e}') diff --git a/openhands/memory/condenser/condenser.py b/openhands/memory/condenser/condenser.py index 6e52780b0b..c4e42277a9 100644 --- a/openhands/memory/condenser/condenser.py +++ b/openhands/memory/condenser/condenser.py @@ -10,6 +10,7 @@ from openhands.controller.state.state import State from openhands.core.config.condenser_config import CondenserConfig from openhands.core.logger import openhands_logger as logger from openhands.events.action.agent import CondensationAction +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.view import View CONDENSER_METADATA_KEY = 'condenser_meta' @@ -144,7 +145,9 @@ class Condenser(ABC): CONDENSER_REGISTRY[configuration_type] = cls @classmethod - def from_config(cls, config: CondenserConfig) -> Condenser: + def from_config( + cls, config: CondenserConfig, llm_registry: LLMRegistry + ) -> Condenser: """Create a condenser from a configuration object. Args: @@ -158,7 +161,7 @@ class Condenser(ABC): """ try: condenser_class = CONDENSER_REGISTRY[type(config)] - return condenser_class.from_config(config) + return condenser_class.from_config(config, llm_registry) except KeyError: raise ValueError(f'Unknown condenser config: {config}') diff --git a/openhands/memory/condenser/impl/amortized_forgetting_condenser.py b/openhands/memory/condenser/impl/amortized_forgetting_condenser.py index 15eb7b76a0..a33455c341 100644 --- a/openhands/memory/condenser/impl/amortized_forgetting_condenser.py +++ b/openhands/memory/condenser/impl/amortized_forgetting_condenser.py @@ -2,6 +2,7 @@ from __future__ import annotations from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig from openhands.events.action.agent import CondensationAction +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import ( Condensation, RollingCondenser, @@ -58,7 +59,9 @@ class AmortizedForgettingCondenser(RollingCondenser): @classmethod def from_config( - cls, config: AmortizedForgettingCondenserConfig + cls, + config: AmortizedForgettingCondenserConfig, + llm_registry: LLMRegistry, ) -> AmortizedForgettingCondenser: return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'})) diff --git a/openhands/memory/condenser/impl/browser_output_condenser.py b/openhands/memory/condenser/impl/browser_output_condenser.py index e7fa456fae..b3e4683cad 100644 --- a/openhands/memory/condenser/impl/browser_output_condenser.py +++ b/openhands/memory/condenser/impl/browser_output_condenser.py @@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig from openhands.events.event import Event from openhands.events.observation import BrowserOutputObservation from openhands.events.observation.agent import AgentCondensationObservation +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, Condenser, View @@ -40,7 +41,7 @@ class BrowserOutputCondenser(Condenser): @classmethod def from_config( - cls, config: BrowserOutputCondenserConfig + cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry ) -> BrowserOutputCondenser: return BrowserOutputCondenser(**config.model_dump(exclude={'type'})) diff --git a/openhands/memory/condenser/impl/conversation_window_condenser.py b/openhands/memory/condenser/impl/conversation_window_condenser.py index 58b9ed8187..9ca9e11255 100644 --- a/openhands/memory/condenser/impl/conversation_window_condenser.py +++ b/openhands/memory/condenser/impl/conversation_window_condenser.py @@ -9,6 +9,7 @@ from openhands.events.action.agent import ( from openhands.events.action.message import MessageAction, SystemMessageAction from openhands.events.event import EventSource from openhands.events.observation import Observation +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View @@ -177,7 +178,9 @@ class ConversationWindowCondenser(RollingCondenser): @classmethod def from_config( - cls, _config: ConversationWindowCondenserConfig + cls, + _config: ConversationWindowCondenserConfig, + llm_registry: LLMRegistry, ) -> ConversationWindowCondenser: return ConversationWindowCondenser() diff --git a/openhands/memory/condenser/impl/llm_attention_condenser.py b/openhands/memory/condenser/impl/llm_attention_condenser.py index 106e5a46b1..81b7fde8dc 100644 --- a/openhands/memory/condenser/impl/llm_attention_condenser.py +++ b/openhands/memory/condenser/impl/llm_attention_condenser.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from openhands.core.config.condenser_config import LLMAttentionCondenserConfig from openhands.events.action.agent import CondensationAction from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import ( Condensation, RollingCondenser, @@ -22,7 +23,12 @@ class ImportantEventSelection(BaseModel): class LLMAttentionCondenser(RollingCondenser): """Rolling condenser strategy that uses an LLM to select the most important events when condensing the history.""" - def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1): + def __init__( + self, + llm: LLM, + max_size: int = 100, + keep_first: int = 1, + ): if keep_first >= max_size // 2: raise ValueError( f'keep_first ({keep_first}) must be less than half of max_size ({max_size})' @@ -113,15 +119,19 @@ class LLMAttentionCondenser(RollingCondenser): return len(view) > self.max_size @classmethod - def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser: + def from_config( + cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry + ) -> LLMAttentionCondenser: # This condenser cannot take advantage of prompt caching. If it happens # to be set, we'll pay for the cache writes but never get a chance to # save on a read. llm_config = config.llm_config.model_copy() llm_config.caching_prompt = False + llm = llm_registry.get_llm('condenser', llm_config) + return LLMAttentionCondenser( - llm=LLM(config=llm_config), + llm=llm, max_size=config.max_size, keep_first=config.keep_first, ) diff --git a/openhands/memory/condenser/impl/llm_summarizing_condenser.py b/openhands/memory/condenser/impl/llm_summarizing_condenser.py index 8ea73a25b2..b78699bb57 100644 --- a/openhands/memory/condenser/impl/llm_summarizing_condenser.py +++ b/openhands/memory/condenser/impl/llm_summarizing_condenser.py @@ -5,7 +5,8 @@ from openhands.core.message import Message, TextContent from openhands.events.action.agent import CondensationAction from openhands.events.observation.agent import AgentCondensationObservation from openhands.events.serialization.event import truncate_content -from openhands.llm import LLM +from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import ( Condensation, RollingCondenser, @@ -154,16 +155,17 @@ CURRENT_STATE: Last flip: Heads, Haiku count: 15/20""" @classmethod def from_config( - cls, config: LLMSummarizingCondenserConfig + cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry ) -> LLMSummarizingCondenser: # This condenser cannot take advantage of prompt caching. If it happens # to be set, we'll pay for the cache writes but never get a chance to # save on a read. llm_config = config.llm_config.model_copy() llm_config.caching_prompt = False + llm = llm_registry.get_llm('condenser', llm_config) return LLMSummarizingCondenser( - llm=LLM(config=llm_config), + llm=llm, max_size=config.max_size, keep_first=config.keep_first, max_event_length=config.max_event_length, diff --git a/openhands/memory/condenser/impl/no_op_condenser.py b/openhands/memory/condenser/impl/no_op_condenser.py index 436cf05813..9f480a129f 100644 --- a/openhands/memory/condenser/impl/no_op_condenser.py +++ b/openhands/memory/condenser/impl/no_op_condenser.py @@ -1,6 +1,7 @@ from __future__ import annotations from openhands.core.config.condenser_config import NoOpCondenserConfig +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, Condenser, View @@ -12,7 +13,9 @@ class NoOpCondenser(Condenser): return view @classmethod - def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser: + def from_config( + cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry + ) -> NoOpCondenser: return NoOpCondenser() diff --git a/openhands/memory/condenser/impl/observation_masking_condenser.py b/openhands/memory/condenser/impl/observation_masking_condenser.py index b5e8de740c..71d691b00f 100644 --- a/openhands/memory/condenser/impl/observation_masking_condenser.py +++ b/openhands/memory/condenser/impl/observation_masking_condenser.py @@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserCo from openhands.events.event import Event from openhands.events.observation import Observation from openhands.events.observation.agent import AgentCondensationObservation +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, Condenser, View @@ -28,7 +29,9 @@ class ObservationMaskingCondenser(Condenser): @classmethod def from_config( - cls, config: ObservationMaskingCondenserConfig + cls, + config: ObservationMaskingCondenserConfig, + llm_registry: LLMRegistry, ) -> ObservationMaskingCondenser: return ObservationMaskingCondenser(**config.model_dump(exclude={'type'})) diff --git a/openhands/memory/condenser/impl/pipeline.py b/openhands/memory/condenser/impl/pipeline.py index cd9f458201..c32aa4b255 100644 --- a/openhands/memory/condenser/impl/pipeline.py +++ b/openhands/memory/condenser/impl/pipeline.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from openhands.controller.state.state import State from openhands.core.config.condenser_config import CondenserPipelineConfig +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, Condenser from openhands.memory.view import View @@ -39,8 +40,10 @@ class CondenserPipeline(Condenser): return result @classmethod - def from_config(cls, config: CondenserPipelineConfig) -> CondenserPipeline: - condensers = [Condenser.from_config(c) for c in config.condensers] + def from_config( + cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry + ) -> CondenserPipeline: + condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers] return CondenserPipeline(*condensers) diff --git a/openhands/memory/condenser/impl/recent_events_condenser.py b/openhands/memory/condenser/impl/recent_events_condenser.py index 099b4846c8..0492532781 100644 --- a/openhands/memory/condenser/impl/recent_events_condenser.py +++ b/openhands/memory/condenser/impl/recent_events_condenser.py @@ -1,6 +1,7 @@ from __future__ import annotations from openhands.core.config.condenser_config import RecentEventsCondenserConfig +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import Condensation, Condenser, View @@ -21,7 +22,9 @@ class RecentEventsCondenser(Condenser): return View(events=head + tail) @classmethod - def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser: + def from_config( + cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry + ) -> RecentEventsCondenser: return RecentEventsCondenser(**config.model_dump(exclude={'type'})) diff --git a/openhands/memory/condenser/impl/structured_summary_condenser.py b/openhands/memory/condenser/impl/structured_summary_condenser.py index 6bfe875d3e..a698e898d8 100644 --- a/openhands/memory/condenser/impl/structured_summary_condenser.py +++ b/openhands/memory/condenser/impl/structured_summary_condenser.py @@ -13,7 +13,8 @@ from openhands.core.message import Message, TextContent from openhands.events.action.agent import CondensationAction from openhands.events.observation.agent import AgentCondensationObservation from openhands.events.serialization.event import truncate_content -from openhands.llm import LLM +from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser.condenser import ( Condensation, RollingCondenser, @@ -180,15 +181,14 @@ class StructuredSummaryCondenser(RollingCondenser): if max_size < 1: raise ValueError(f'max_size ({max_size}) cannot be non-positive') - if not llm.is_function_calling_active(): - raise ValueError( - 'LLM must support function calling to use StructuredSummaryCondenser' - ) - self.max_size = max_size self.keep_first = keep_first self.max_event_length = max_event_length self.llm = llm + if not self.llm.is_function_calling_active(): + raise ValueError( + 'LLM must support function calling to use StructuredSummaryCondenser' + ) super().__init__() @@ -309,16 +309,17 @@ Capture all relevant information, especially: @classmethod def from_config( - cls, config: StructuredSummaryCondenserConfig + cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry ) -> StructuredSummaryCondenser: # This condenser cannot take advantage of prompt caching. If it happens # to be set, we'll pay for the cache writes but never get a chance to # save on a read. llm_config = config.llm_config.model_copy() llm_config.caching_prompt = False + llm = llm_registry.get_llm('condenser', llm_config) return StructuredSummaryCondenser( - llm=LLM(config=llm_config), + llm=llm, max_size=config.max_size, keep_first=config.keep_first, max_event_length=config.max_event_length, diff --git a/openhands/resolver/interfaces/issue_definitions.py b/openhands/resolver/interfaces/issue_definitions.py index f94fe30dc1..a3b5b24c42 100644 --- a/openhands/resolver/interfaces/issue_definitions.py +++ b/openhands/resolver/interfaces/issue_definitions.py @@ -23,7 +23,7 @@ class ServiceContext: def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None): self._strategy = strategy if llm_config is not None: - self.llm = LLM(llm_config) + self.llm = LLM(llm_config, service_id='resolver') def set_strategy(self, strategy: IssueHandlerInterface) -> None: self._strategy = strategy diff --git a/openhands/resolver/issue_resolver.py b/openhands/resolver/issue_resolver.py index 474096baa6..d8953f47ec 100644 --- a/openhands/resolver/issue_resolver.py +++ b/openhands/resolver/issue_resolver.py @@ -28,6 +28,7 @@ from openhands.events.observation import ( ) from openhands.events.stream import EventStreamSubscriber from openhands.integrations.service_types import ProviderType +from openhands.llm.llm_registry import LLMRegistry from openhands.resolver.interfaces.issue import Issue from openhands.resolver.interfaces.issue_definitions import ( ServiceContextIssue, @@ -412,7 +413,8 @@ class IssueResolver: shutil.rmtree(self.workspace_base) shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base) - runtime = create_runtime(self.app_config) + llm_registry = LLMRegistry(self.app_config) + runtime = create_runtime(self.app_config, llm_registry) await runtime.connect() def on_event(evt: Event) -> None: diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index 71a3b9fca1..bf396546a9 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -463,7 +463,7 @@ def update_existing_pull_request( # Summarize with LLM if provided if llm_config is not None: - llm = LLM(llm_config) + llm = LLM(llm_config, service_id='resolver') with open( os.path.join( os.path.dirname(__file__), diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index f355bd90ce..e1cf4c9560 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -54,6 +54,7 @@ from openhands.integrations.provider import ( ProviderType, ) from openhands.integrations.service_types import AuthenticationError +from openhands.llm.llm_registry import LLMRegistry from openhands.microagent import ( BaseMicroagent, load_microagents_from_dir, @@ -125,6 +126,7 @@ class Runtime(FileEditRuntimeMixin): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -178,7 +180,9 @@ class Runtime(FileEditRuntimeMixin): # Load mixins FileEditRuntimeMixin.__init__( - self, enable_llm_editor=config.get_agent_config().enable_llm_editor + self, + enable_llm_editor=config.get_agent_config().enable_llm_editor, + llm_registry=llm_registry, ) self.user_id = user_id diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index bf3515b355..93fee37e8b 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -43,6 +43,7 @@ from openhands.events.observation import ( from openhands.events.serialization import event_to_dict, observation_from_dict from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.base import Runtime from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils.request import send_request @@ -68,6 +69,7 @@ class ActionExecutionClient(Runtime): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -85,6 +87,7 @@ class ActionExecutionClient(Runtime): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, diff --git a/openhands/runtime/impl/cli/cli_runtime.py b/openhands/runtime/impl/cli/cli_runtime.py index bd5b12e284..acac7441f3 100644 --- a/openhands/runtime/impl/cli/cli_runtime.py +++ b/openhands/runtime/impl/cli/cli_runtime.py @@ -46,6 +46,7 @@ from openhands.events.observation import ( Observation, ) from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.base import Runtime from openhands.runtime.plugins import PluginRequirement from openhands.runtime.runtime_status import RuntimeStatus @@ -107,6 +108,7 @@ class CLIRuntime(Runtime): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -119,6 +121,7 @@ class CLIRuntime(Runtime): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, diff --git a/openhands/runtime/impl/docker/docker_runtime.py b/openhands/runtime/impl/docker/docker_runtime.py index 4ce7ef40fa..d3da289b17 100644 --- a/openhands/runtime/impl/docker/docker_runtime.py +++ b/openhands/runtime/impl/docker/docker_runtime.py @@ -20,6 +20,7 @@ from openhands.core.logger import DEBUG, DEBUG_RUNTIME from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.builder import DockerRuntimeBuilder from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, @@ -90,6 +91,7 @@ class DockerRuntime(ActionExecutionClient): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -143,6 +145,7 @@ class DockerRuntime(ActionExecutionClient): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, diff --git a/openhands/runtime/impl/kubernetes/kubernetes_runtime.py b/openhands/runtime/impl/kubernetes/kubernetes_runtime.py index 00e1b0de46..0ad2d4efe6 100644 --- a/openhands/runtime/impl/kubernetes/kubernetes_runtime.py +++ b/openhands/runtime/impl/kubernetes/kubernetes_runtime.py @@ -43,6 +43,7 @@ from openhands.core.logger import DEBUG from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) @@ -81,6 +82,7 @@ class KubernetesRuntime(ActionExecutionClient): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -137,6 +139,7 @@ class KubernetesRuntime(ActionExecutionClient): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, diff --git a/openhands/runtime/impl/local/local_runtime.py b/openhands/runtime/impl/local/local_runtime.py index 54c3289c4a..5b5e354e98 100644 --- a/openhands/runtime/impl/local/local_runtime.py +++ b/openhands/runtime/impl/local/local_runtime.py @@ -26,6 +26,7 @@ from openhands.events.observation import ( ) from openhands.events.serialization import event_to_dict, observation_from_dict from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) @@ -135,6 +136,7 @@ class LocalRuntime(ActionExecutionClient): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -186,6 +188,7 @@ class LocalRuntime(ActionExecutionClient): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, @@ -801,12 +804,6 @@ def _create_warm_server_in_background( def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]: from openhands.controller.agent import Agent - from openhands.llm.llm import LLM - agent_config = config.get_agent_config(config.default_agent) - llm = LLM( - config=config.get_llm_config_from_agent(config.default_agent), - ) - agent = Agent.get_cls(config.default_agent)(llm, agent_config) - plugins = agent.sandbox_plugins + plugins = Agent.get_cls(config.default_agent).sandbox_plugins return plugins diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index 14fc54ff29..ce52a43165 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -19,6 +19,7 @@ from openhands.core.exceptions import ( from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.integrations.provider import PROVIDER_TOKEN_TYPE +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.builder.remote import RemoteRuntimeBuilder from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, @@ -51,6 +52,7 @@ class RemoteRuntime(ActionExecutionClient): self, config: OpenHandsConfig, event_stream: EventStream, + llm_registry: LLMRegistry, sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, @@ -64,6 +66,7 @@ class RemoteRuntime(ActionExecutionClient): super().__init__( config, event_stream, + llm_registry, sid, plugins, env_vars, diff --git a/openhands/runtime/utils/edit.py b/openhands/runtime/utils/edit.py index 520243a4c2..692f0c3ace 100644 --- a/openhands/runtime/utils/edit.py +++ b/openhands/runtime/utils/edit.py @@ -23,7 +23,7 @@ from openhands.events.observation import ( ) from openhands.linter import DefaultLinter from openhands.llm.llm import LLM -from openhands.llm.metrics import Metrics +from openhands.llm.llm_registry import LLMRegistry from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches USER_MSG = """ @@ -128,7 +128,13 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): # This restricts the number of lines we can edit to avoid exceeding the token limit. MAX_LINES_TO_EDIT = 300 - def __init__(self, enable_llm_editor: bool, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + enable_llm_editor: bool, + llm_registry: LLMRegistry, + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.enable_llm_editor = enable_llm_editor @@ -138,7 +144,6 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): draft_editor_config = self.config.get_llm_config('draft_editor') # manually set the model name for the draft editor LLM to distinguish token costs - llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model) if draft_editor_config.caching_prompt: logger.debug( 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. ' @@ -146,7 +151,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): ) draft_editor_config.caching_prompt = False - self.draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics) + self.draft_editor_llm = llm_registry.get_llm( + 'draft_editor_llm', draft_editor_config + ) logger.debug( f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}' ) diff --git a/openhands/server/conversation_manager/conversation_manager.py b/openhands/server/conversation_manager/conversation_manager.py index 330bcc14c8..8a120d8d22 100644 --- a/openhands/server/conversation_manager/conversation_manager.py +++ b/openhands/server/conversation_manager/conversation_manager.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod import socketio from openhands.core.config import OpenHandsConfig +from openhands.core.config.llm_config import LLMConfig from openhands.events.action import MessageAction from openhands.server.config.server_config import ServerConfig from openhands.server.data_models.agent_loop_info import AgentLoopInfo @@ -136,6 +137,16 @@ class ConversationManager(ABC): ) -> list[AgentLoopInfo]: """Get the AgentLoopInfo for conversations.""" + @abstractmethod + async def request_llm_completion( + self, + sid: str, + service_id: str, + llm_config: LLMConfig, + messages: list[dict[str, str]], + ) -> str: + """Request extraneous llm completions for a conversation""" + @classmethod @abstractmethod def get_instance( diff --git a/openhands/server/conversation_manager/docker_nested_conversation_manager.py b/openhands/server/conversation_manager/docker_nested_conversation_manager.py index 82db976a4f..107fd11253 100644 --- a/openhands/server/conversation_manager/docker_nested_conversation_manager.py +++ b/openhands/server/conversation_manager/docker_nested_conversation_manager.py @@ -15,13 +15,13 @@ from docker.models.containers import Container from openhands.controller.agent import Agent from openhands.core.config import OpenHandsConfig +from openhands.core.config.llm_config import LLMConfig from openhands.core.logger import openhands_logger as logger from openhands.events.action import MessageAction from openhands.events.nested_event_store import NestedEventStore from openhands.events.stream import EventStream from openhands.experiments.experiment_manager import ExperimentManagerImpl from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler -from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.impl.docker.docker_runtime import DockerRuntime from openhands.server.config.server_config import ServerConfig @@ -42,6 +42,7 @@ from openhands.storage.files import FileStore from openhands.storage.locations import get_conversation_dir from openhands.utils.async_utils import call_sync_from_async from openhands.utils.import_utils import get_impl +from openhands.utils.utils import create_registry_and_convo_stats @dataclass @@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager): # Not supported - clients should connect directly to the nested server! raise ValueError('unsupported_operation') + async def request_llm_completion( + self, + sid: str, + service_id: str, + llm_config: LLMConfig, + messages: list[dict[str, str]], + ) -> str: + # Not supported - clients should connect directly to the nested server! + raise ValueError('unsupported_operation') + async def send_event_to_conversation(self, sid, data): async with httpx.AsyncClient( headers={ @@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager): # This session is created here only because it is the easiest way to get a runtime, which # is the easiest way to create the needed docker container - # Run experiment manager variant test before creating session config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test( user_id, sid, self.config ) + llm_registry, convo_stats, config = create_registry_and_convo_stats( + config, sid, user_id, settings + ) + session = Session( sid=sid, + llm_registry=llm_registry, + convo_stats=convo_stats, file_store=self.file_store, config=config, sio=self.sio, user_id=user_id, ) + llm_registry.retry_listner = session._notify_on_llm_retry agent_cls = settings.agent or config.default_agent - agent_name = agent_cls if agent_cls is not None else 'agent' - llm = LLM( - config=config.get_llm_config_from_agent(agent_name), - retry_listener=session._notify_on_llm_retry, - ) - llm = session._create_llm(agent_cls) agent_config = config.get_agent_config(agent_cls) - agent = Agent.get_cls(agent_cls)(llm, agent_config) + agent = Agent.get_cls(agent_cls)(agent_config, llm_registry) config = config.model_copy(deep=True) env_vars = config.sandbox.runtime_startup_env_vars @@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager): headless_mode=False, attach_to_existing=False, main_module='openhands.server', + llm_registry=llm_registry, ) # Hack - disable setting initial env. diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index b646d40f70..2e07b187ee 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -6,12 +6,14 @@ from typing import Callable, Iterable import socketio +from openhands.core.config.llm_config import LLMConfig from openhands.core.config.openhands_config import OpenHandsConfig from openhands.core.exceptions import AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger from openhands.core.schema.agent import AgentState from openhands.events.action import MessageAction from openhands.events.stream import EventStreamSubscriber, session_exists +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime import get_runtime_cls from openhands.server.config.server_config import ServerConfig from openhands.server.constants import ROOM_KEY @@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import ( ) from openhands.utils.import_utils import get_impl from openhands.utils.shutdown_listener import should_continue +from openhands.utils.utils import create_registry_and_convo_stats from .conversation_manager import ConversationManager @@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager): ) await self.close_session(oldest_conversation_id) - config = self.config.model_copy(deep=True) - + llm_registry, convo_stats, config = create_registry_and_convo_stats( + self.config, sid, user_id, settings + ) session = Session( sid=sid, file_store=self.file_store, config=config, + llm_registry=llm_registry, + convo_stats=convo_stats, sio=self.sio, user_id=user_id, ) @@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager): try: session.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, - self._create_conversation_update_callback(user_id, sid, settings), + self._create_conversation_update_callback( + user_id, sid, settings, session.llm_registry + ), UPDATED_AT_CALLBACK_ID, ) except ValueError: @@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager): raise RuntimeError(f'no_conversation:{sid}') await session.dispatch(data) + async def request_llm_completion( + self, + sid: str, + service_id: str, + llm_config: LLMConfig, + messages: list[dict[str, str]], + ): + session = self._local_agent_loops_by_sid.get(sid) + if not session: + raise RuntimeError(f'no_conversation:{sid}') + llm_registry = session.llm_registry + return llm_registry.request_extraneous_completion( + service_id, llm_config, messages + ) + async def disconnect_from_session(self, connection_id: str): sid = self._local_connection_id_to_session_id.pop(connection_id, None) logger.info( @@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager): user_id: str | None, conversation_id: str, settings: Settings, + llm_registry: LLMRegistry, ) -> Callable: def callback(event, *args, **kwargs): call_async_from_sync( @@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager): user_id, conversation_id, settings, + llm_registry, event, ) @@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager): user_id: str, conversation_id: str, settings: Settings, + llm_registry: LLMRegistry, event=None, ): conversation_store = await self._get_conversation_store(user_id) @@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager): conversation.title == default_title ): # attempt to autogenerate if default title is in use title = await auto_generate_title( - conversation_id, user_id, self.file_store, settings + conversation_id, user_id, self.file_store, settings, llm_registry ) if title and not title.isspace(): conversation.title = title diff --git a/openhands/server/conversation_manager/utils.py b/openhands/server/conversation_manager/utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 07dead28c4..1db2c42e87 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -33,7 +33,6 @@ from openhands.integrations.service_types import ( ProviderType, SuggestedTask, ) -from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.runtime_status import RuntimeStatus from openhands.server.data_models.agent_loop_info import AgentLoopInfo @@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import ( setup_init_conversation_settings, ) from openhands.server.shared import ( + ConversationManagerImpl, ConversationStoreImpl, config, conversation_manager, @@ -364,7 +364,7 @@ async def get_prompt( ) prompt_template = generate_prompt_template(stringified_events) - prompt = generate_prompt(llm_config, prompt_template) + prompt = generate_prompt(llm_config, prompt_template, conversation_id) return JSONResponse( { @@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str: return template.render(events=events) -def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str: - llm = LLM(llm_config) +def generate_prompt( + llm_config: LLMConfig, prompt_template: str, conversation_id: str +) -> str: messages = [ { 'role': 'system', @@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str: }, ] - response = llm.completion(messages=messages) - raw_prompt = response['choices'][0]['message']['content'].strip() + raw_prompt = ConversationManagerImpl.request_llm_completion( + 'remember_prompt', conversation_id, llm_config, messages + ) prompt = re.search(r'(.*?)', raw_prompt, re.DOTALL) if prompt: diff --git a/openhands/server/services/conversation_service.py b/openhands/server/services/conversation_service.py index 0385168b73..1ade6fad76 100644 --- a/openhands/server/services/conversation_service.py +++ b/openhands/server/services/conversation_service.py @@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets from openhands.utils.conversation_summary import get_default_conversation_title -async def create_new_conversation( +async def initialize_conversation( + user_id: str | None, + conversation_id: str | None, + selected_repository: str | None, + selected_branch: str | None, + conversation_trigger: ConversationTrigger = ConversationTrigger.GUI, + git_provider: ProviderType | None = None, +) -> ConversationMetadata | None: + if conversation_id is None: + conversation_id = uuid.uuid4().hex + + conversation_store = await ConversationStoreImpl.get_instance(config, user_id) + + if not await conversation_store.exists(conversation_id): + logger.info( + f'New conversation ID: {conversation_id}', + extra={'user_id': user_id, 'session_id': conversation_id}, + ) + + conversation_title = get_default_conversation_title(conversation_id) + + logger.info(f'Saving metadata for conversation {conversation_id}') + convo_metadata = ConversationMetadata( + trigger=conversation_trigger, + conversation_id=conversation_id, + title=conversation_title, + user_id=user_id, + selected_repository=selected_repository, + selected_branch=selected_branch, + git_provider=git_provider, + ) + + await conversation_store.save_metadata(convo_metadata) + return convo_metadata + + try: + convo_metadata = await conversation_store.get_metadata(conversation_id) + return convo_metadata + except Exception: + pass + + return None + + +async def start_conversation( user_id: str | None, git_provider_tokens: PROVIDER_TOKEN_TYPE | None, custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None, - selected_repository: str | None, - selected_branch: str | None, initial_user_msg: str | None, image_urls: list[str] | None, replay_json: str | None, - conversation_instructions: str | None = None, - conversation_trigger: ConversationTrigger = ConversationTrigger.GUI, - attach_conversation_id: bool = False, - git_provider: ProviderType | None = None, - conversation_id: str | None = None, + conversation_id: str, + convo_metadata: ConversationMetadata, + conversation_instructions: str | None, mcp_config: MCPConfig | None = None, ) -> AgentLoopInfo: logger.info( @@ -52,7 +92,7 @@ async def create_new_conversation( extra={ 'signal': 'create_conversation', 'user_id': user_id, - 'trigger': conversation_trigger.value, + 'trigger': convo_metadata.trigger, }, ) logger.info('Loading settings') @@ -79,53 +119,25 @@ async def create_new_conversation( raise MissingSettingsError('Settings not found') session_init_args['git_provider_tokens'] = git_provider_tokens - session_init_args['selected_repository'] = selected_repository + session_init_args['selected_repository'] = convo_metadata.selected_repository session_init_args['custom_secrets'] = custom_secrets - session_init_args['selected_branch'] = selected_branch - session_init_args['git_provider'] = git_provider + session_init_args['selected_branch'] = convo_metadata.selected_branch + session_init_args['git_provider'] = convo_metadata.git_provider session_init_args['conversation_instructions'] = conversation_instructions if mcp_config: session_init_args['mcp_config'] = mcp_config conversation_init_data = ConversationInitData(**session_init_args) - logger.info('Loading conversation store') - conversation_store = await ConversationStoreImpl.get_instance(config, user_id) - logger.info('ServerConversation store loaded') - - # For nested runtimes, we allow a single conversation id, passed in on container creation - if conversation_id is None: - conversation_id = uuid.uuid4().hex - - if not await conversation_store.exists(conversation_id): - logger.info( - f'New conversation ID: {conversation_id}', - extra={'user_id': user_id, 'session_id': conversation_id}, - ) - - conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test( - user_id, conversation_id, conversation_init_data - ) - conversation_title = get_default_conversation_title(conversation_id) - - logger.info(f'Saving metadata for conversation {conversation_id}') - await conversation_store.save_metadata( - ConversationMetadata( - trigger=conversation_trigger, - conversation_id=conversation_id, - title=conversation_title, - user_id=user_id, - selected_repository=selected_repository, - selected_branch=selected_branch, - git_provider=git_provider, - llm_model=conversation_init_data.llm_model, - ) - ) + conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test( + user_id, conversation_id, conversation_init_data + ) logger.info( f'Starting agent loop for conversation {conversation_id}', extra={'user_id': user_id, 'session_id': conversation_id}, ) + initial_message_action = None if initial_user_msg or image_urls: initial_message_action = MessageAction( @@ -133,9 +145,6 @@ async def create_new_conversation( image_urls=image_urls or [], ) - if attach_conversation_id: - logger.warning('Attaching conversation ID is deprecated, skipping process') - agent_loop_info = await conversation_manager.maybe_start_agent_loop( conversation_id, conversation_init_data, @@ -147,6 +156,47 @@ async def create_new_conversation( return agent_loop_info +async def create_new_conversation( + user_id: str | None, + git_provider_tokens: PROVIDER_TOKEN_TYPE | None, + custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None, + selected_repository: str | None, + selected_branch: str | None, + initial_user_msg: str | None, + image_urls: list[str] | None, + replay_json: str | None, + conversation_instructions: str | None = None, + conversation_trigger: ConversationTrigger = ConversationTrigger.GUI, + git_provider: ProviderType | None = None, + conversation_id: str | None = None, + mcp_config: MCPConfig | None = None, +) -> AgentLoopInfo: + conversation_metadata = await initialize_conversation( + user_id, + conversation_id, + selected_repository, + selected_branch, + conversation_trigger, + git_provider, + ) + + if not conversation_metadata: + raise Exception('Failed to initialize conversation') + + return await start_conversation( + user_id, + git_provider_tokens, + custom_secrets, + initial_user_msg, + image_urls, + replay_json, + conversation_metadata.conversation_id, + conversation_metadata, + conversation_instructions, + mcp_config, + ) + + def create_provider_tokens_object( providers_set: list[ProviderType], ) -> PROVIDER_TOKEN_TYPE: diff --git a/openhands/server/services/conversation_stats.py b/openhands/server/services/conversation_stats.py new file mode 100644 index 0000000000..0f0354e65e --- /dev/null +++ b/openhands/server/services/conversation_stats.py @@ -0,0 +1,77 @@ +import base64 +import pickle +from threading import Lock + +from openhands.core.logger import openhands_logger as logger +from openhands.llm.llm_registry import RegistryEvent +from openhands.llm.metrics import Metrics +from openhands.storage.files import FileStore +from openhands.storage.locations import get_conversation_stats_filename + + +class ConversationStats: + def __init__( + self, + file_store: FileStore | None, + conversation_id: str, + user_id: str | None, + ): + self.metrics_path = get_conversation_stats_filename(conversation_id, user_id) + self.file_store = file_store + self.conversation_id = conversation_id + self.user_id = user_id + + self._save_lock = Lock() + + self.service_to_metrics: dict[str, Metrics] = {} + self.restored_metrics: dict[str, Metrics] = {} + + # Always attempt to restore registry if it exists + self.maybe_restore_metrics() + + def save_metrics(self): + if not self.file_store: + return + + with self._save_lock: + pickled = pickle.dumps(self.service_to_metrics) + serialized_metrics = base64.b64encode(pickled).decode('utf-8') + self.file_store.write(self.metrics_path, serialized_metrics) + + def maybe_restore_metrics(self): + if not self.file_store or not self.conversation_id: + return + + try: + encoded = self.file_store.read(self.metrics_path) + pickled = base64.b64decode(encoded) + self.restored_metrics = pickle.loads(pickled) + logger.info(f'restored metrics: {self.conversation_id}') + except FileNotFoundError: + pass + + def get_combined_metrics(self) -> Metrics: + total_metrics = Metrics() + for metrics in self.service_to_metrics.values(): + total_metrics.merge(metrics) + + logger.info(f'metrics by all services: {self.service_to_metrics}') + logger.info(f'combined metrics\n\n{total_metrics}') + return total_metrics + + def get_metrics_for_service(self, service_id: str) -> Metrics: + if service_id not in self.service_to_metrics: + raise Exception(f'LLM service does not exist {service_id}') + + return self.service_to_metrics[service_id] + + def register_llm(self, event: RegistryEvent): + # Listen for llm creations and track their metrics + llm = event.llm + service_id = event.service_id + + if service_id in self.restored_metrics: + llm.metrics = self.restored_metrics[service_id].copy() + del self.restored_metrics[service_id] + + self.service_to_metrics[service_id] = llm.metrics diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 8581656b19..0a9acf6033 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -21,6 +21,7 @@ from openhands.integrations.provider import ( PROVIDER_TOKEN_TYPE, ProviderHandler, ) +from openhands.llm.llm_registry import LLMRegistry from openhands.mcp import add_mcp_tools_to_agent from openhands.memory.memory import Memory from openhands.microagent.microagent import BaseMicroagent @@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime from openhands.runtime.runtime_status import RuntimeStatus from openhands.security import SecurityAnalyzer, options +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.data_models.user_secrets import UserSecrets from openhands.storage.files import FileStore from openhands.utils.async_utils import EXECUTOR, call_sync_from_async @@ -48,6 +50,7 @@ class AgentSession: sid: str user_id: str | None event_stream: EventStream + llm_registry: LLMRegistry file_store: FileStore controller: AgentController | None = None runtime: Runtime | None = None @@ -63,6 +66,8 @@ class AgentSession: self, sid: str, file_store: FileStore, + llm_registry: LLMRegistry, + convo_stats: ConversationStats, status_callback: Callable | None = None, user_id: str | None = None, ) -> None: @@ -80,6 +85,8 @@ class AgentSession: self.logger = OpenHandsLoggerAdapter( extra={'session_id': sid, 'user_id': user_id} ) + self.llm_registry = llm_registry + self.convo_stats = convo_stats async def start( self, @@ -340,6 +347,7 @@ class AgentSession: self.runtime = runtime_cls( config=config, event_stream=self.event_stream, + llm_registry=self.llm_registry, sid=self.sid, plugins=agent.sandbox_plugins, status_callback=self._status_callback, @@ -360,6 +368,7 @@ class AgentSession: self.runtime = runtime_cls( config=config, event_stream=self.event_stream, + llm_registry=self.llm_registry, sid=self.sid, plugins=agent.sandbox_plugins, status_callback=self._status_callback, @@ -441,6 +450,7 @@ class AgentSession: user_id=self.user_id, file_store=self.file_store, event_stream=self.event_stream, + convo_stats=self.convo_stats, agent=agent, iteration_delta=int(max_iterations), budget_per_task_delta=max_budget_per_task, @@ -490,6 +500,15 @@ class AgentSession: ) return memory + def get_state(self) -> AgentState | None: + controller = self.controller + if controller: + return controller.state.agent_state + if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: + # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong + return AgentState.ERROR + return None + def _maybe_restore_state(self) -> State | None: """Helper method to handle state restore logic.""" restored_state = None @@ -510,14 +529,5 @@ class AgentSession: self.logger.debug('No events found, no state to restore') return restored_state - def get_state(self) -> AgentState | None: - controller = self.controller - if controller: - return controller.state.agent_state - if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: - # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong - return AgentState.ERROR - return None - def is_closed(self) -> bool: return self._closed diff --git a/openhands/server/session/conversation.py b/openhands/server/session/conversation.py index 5972a367e3..66ccbccc51 100644 --- a/openhands/server/session/conversation.py +++ b/openhands/server/session/conversation.py @@ -2,6 +2,7 @@ import asyncio from openhands.core.config import OpenHandsConfig from openhands.events.stream import EventStream +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options @@ -45,6 +46,7 @@ class ServerConversation: else: runtime_cls = get_runtime_cls(self.config.runtime) runtime = runtime_cls( + llm_registry=LLMRegistry(self.config), config=config, event_stream=self.event_stream, sid=self.sid, diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 458f45abe2..fe20df3558 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -1,6 +1,5 @@ import asyncio import time -from copy import deepcopy from logging import LoggerAdapter import socketio @@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation from openhands.events.observation.error import ErrorObservation from openhands.events.serialization import event_from_dict, event_to_dict from openhands.events.stream import EventStreamSubscriber -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.runtime_status import RuntimeStatus from openhands.server.constants import ROOM_KEY +from openhands.server.services.conversation_stats import ConversationStats from openhands.server.session.agent_session import AgentSession from openhands.server.session.conversation_init_data import ConversationInitData from openhands.storage.data_models.settings import Settings @@ -45,6 +45,7 @@ class Session: agent_session: AgentSession loop: asyncio.AbstractEventLoop config: OpenHandsConfig + llm_registry: LLMRegistry file_store: FileStore user_id: str | None logger: LoggerAdapter @@ -53,6 +54,8 @@ class Session: self, sid: str, config: OpenHandsConfig, + llm_registry: LLMRegistry, + convo_stats: ConversationStats, file_store: FileStore, sio: socketio.AsyncServer | None, user_id: str | None = None, @@ -62,17 +65,21 @@ class Session: self.last_active_ts = int(time.time()) self.file_store = file_store self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid}) + self.llm_registry = llm_registry + self.convo_stats = convo_stats self.agent_session = AgentSession( sid, file_store, + llm_registry=self.llm_registry, + convo_stats=convo_stats, status_callback=self.queue_status_message, user_id=user_id, ) self.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, self.on_event, self.sid ) - # Copying this means that when we update variables they are not applied to the shared global configuration! - self.config = deepcopy(config) + self.config = config + # Lazy import to avoid circular dependency from openhands.experiments.experiment_manager import ExperimentManagerImpl @@ -140,13 +147,6 @@ class Session: else self.config.max_budget_per_task ) - # This is a shallow copy of the default LLM config, so changes here will - # persist if we retrieve the default LLM config again when constructing - # the agent - default_llm_config = self.config.get_llm_config() - default_llm_config.model = settings.llm_model or '' - default_llm_config.api_key = settings.llm_api_key - default_llm_config.base_url = settings.llm_base_url self.config.search_api_key = settings.search_api_key if settings.sandbox_api_key: self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value() @@ -181,10 +181,9 @@ class Session: ) # TODO: override other LLM config & agent config groups (#2075) - - llm = self._create_llm(agent_cls) agent_config = self.config.get_agent_config(agent_cls) - + agent_name = agent_cls if agent_cls is not None else 'agent' + llm_config = self.config.get_llm_config_from_agent(agent_name) if settings.enable_default_condenser: # Default condenser chains three condensers together: # 1. a conversation window condenser that handles explicit @@ -200,7 +199,7 @@ class Session: ConversationWindowCondenserConfig(), BrowserOutputCondenserConfig(attention_window=2), LLMSummarizingCondenserConfig( - llm_config=llm.config, keep_first=4, max_size=120 + llm_config=llm_config, keep_first=4, max_size=120 ), ] ) @@ -208,12 +207,14 @@ class Session: self.logger.info( f'Enabling pipeline condenser with:' f' browser_output_masking(attention_window=2), ' - f' llm(model="{llm.config.model}", ' - f' base_url="{llm.config.base_url}", ' + f' llm(model="{llm_config.model}", ' + f' base_url="{llm_config.base_url}", ' f' keep_first=4, max_size=80)' ) agent_config.condenser = default_condenser_config - agent = Agent.get_cls(agent_cls)(llm, agent_config) + agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry) + + self.llm_registry.retry_listner = self._notify_on_llm_retry git_provider_tokens = None selected_repository = None @@ -269,14 +270,6 @@ class Session: ) return - def _create_llm(self, agent_cls: str | None) -> LLM: - """Initialize LLM, extracted for testing.""" - agent_name = agent_cls if agent_cls is not None else 'agent' - return LLM( - config=self.config.get_llm_config_from_agent(agent_name), - retry_listener=self._notify_on_llm_retry, - ) - def _notify_on_llm_retry(self, retries: int, max: int) -> None: self.queue_status_message( 'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}' diff --git a/openhands/storage/locations.py b/openhands/storage/locations.py index 192721da86..6ae8879dc5 100644 --- a/openhands/storage/locations.py +++ b/openhands/storage/locations.py @@ -30,5 +30,13 @@ def get_conversation_agent_state_filename(sid: str, user_id: str | None = None) return f'{get_conversation_dir(sid, user_id)}agent_state.pkl' +def get_conversation_llm_registry_filename(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}llm_registry.json' + + +def get_conversation_stats_filename(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}convo_stats.json' + + def get_experiment_config_filename(sid: str, user_id: str | None = None) -> str: return f'{get_conversation_dir(sid, user_id)}exp_config.json' diff --git a/openhands/utils/conversation_summary.py b/openhands/utils/conversation_summary.py index 11fa030f6a..b20fdafe66 100644 --- a/openhands/utils/conversation_summary.py +++ b/openhands/utils/conversation_summary.py @@ -7,13 +7,16 @@ from openhands.core.logger import openhands_logger as logger from openhands.events.action.message import MessageAction from openhands.events.event import EventSource from openhands.events.event_store import EventStore -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.storage.data_models.settings import Settings from openhands.storage.files import FileStore async def generate_conversation_title( - message: str, llm_config: LLMConfig, max_length: int = 50 + message: str, + llm_config: LLMConfig, + llm_registry: LLMRegistry, + max_length: int = 50, ) -> Optional[str]: """Generate a concise title for a conversation based on the first user message. @@ -35,8 +38,6 @@ async def generate_conversation_title( truncated_message = message try: - llm = LLM(llm_config) - # Create a simple prompt for the LLM to generate a title messages = [ { @@ -49,8 +50,9 @@ async def generate_conversation_title( }, ] - response = llm.completion(messages=messages) - title = response.choices[0].message.content.strip() + title = llm_registry.request_extraneous_completion( + 'convo_title_creator', llm_config, messages + ) # Ensure the title isn't too long if len(title) > max_length: @@ -75,7 +77,11 @@ def get_default_conversation_title(conversation_id: str) -> str: async def auto_generate_title( - conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings + conversation_id: str, + user_id: str | None, + file_store: FileStore, + settings: Settings, + llm_registry: LLMRegistry, ) -> str: """Auto-generate a title for a conversation based on the first user message. Uses LLM-based title generation if available, otherwise falls back to a simple truncation. @@ -116,7 +122,7 @@ async def auto_generate_title( # Try to generate title using LLM llm_title = await generate_conversation_title( - first_user_message, llm_config + first_user_message, llm_config, llm_registry ) if llm_title: logger.info(f'Generated title using LLM: {llm_title}') diff --git a/openhands/utils/utils.py b/openhands/utils/utils.py new file mode 100644 index 0000000000..7abd386f31 --- /dev/null +++ b/openhands/utils/utils.py @@ -0,0 +1,37 @@ +from copy import deepcopy + +from openhands.core.config.openhands_config import OpenHandsConfig +from openhands.llm.llm_registry import LLMRegistry +from openhands.server.services.conversation_stats import ConversationStats +from openhands.storage import get_file_store +from openhands.storage.data_models.settings import Settings + + +def setup_llm_config(config: OpenHandsConfig, settings: Settings) -> OpenHandsConfig: + # Copying this means that when we update variables they are not applied to the shared global configuration! + config = deepcopy(config) + + llm_config = config.get_llm_config() + llm_config.model = settings.llm_model or '' + llm_config.api_key = settings.llm_api_key + llm_config.base_url = settings.llm_base_url + config.set_llm_config(llm_config) + return config + + +def create_registry_and_convo_stats( + config: OpenHandsConfig, + sid: str, + user_id: str | None, + user_settings: Settings | None = None, +) -> tuple[LLMRegistry, ConversationStats, OpenHandsConfig]: + user_config = config + if user_settings: + user_config = setup_llm_config(config, user_settings) + + agent_cls = user_settings.agent if user_settings else None + llm_registry = LLMRegistry(user_config, agent_cls) + file_store = get_file_store(user_config.file_store, user_config.file_store_path) + convo_stats = ConversationStats(file_store, sid, user_id) + llm_registry.subscribe(convo_stats.register_llm) + return llm_registry, convo_stats, user_config diff --git a/pytest.ini b/pytest.ini index 7a222b2d34..c6be6e987a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] addopts = -p no:warnings +asyncio_mode = auto asyncio_default_fixture_loop_scope = function diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/runtime/conftest.py b/tests/runtime/conftest.py index ecd95fac09..d657fc5827 100644 --- a/tests/runtime/conftest.py +++ b/tests/runtime/conftest.py @@ -10,6 +10,7 @@ from pytest import TempPathFactory from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.base import Runtime from openhands.runtime.impl.cli.cli_runtime import CLIRuntime from openhands.runtime.impl.docker.docker_runtime import DockerRuntime @@ -268,9 +269,13 @@ def _load_runtime( ) event_stream = EventStream(sid, file_store) + # Create a LLMRegistry instance for the runtime + llm_registry = LLMRegistry(config=OpenHandsConfig()) + runtime = runtime_cls( config=config, event_stream=event_stream, + llm_registry=llm_registry, sid=sid, plugins=plugins, ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/llm/test_acompletion.py b/tests/unit/llm/test_acompletion.py index 5d519562a3..434bd676f7 100644 --- a/tests/unit/llm/test_acompletion.py +++ b/tests/unit/llm/test_acompletion.py @@ -20,7 +20,7 @@ def test_llm(): def _get_llm(type_: type[LLM]): with _patch_http(): - return type_(config=config.get_llm_config()) + return type_(config=config.get_llm_config(), service_id='test_service') @pytest.fixture diff --git a/tests/unit/llm/test_llm.py b/tests/unit/llm/test_llm.py index 9b6ccea9d2..83c2c50020 100644 --- a/tests/unit/llm/test_llm.py +++ b/tests/unit/llm/test_llm.py @@ -38,7 +38,7 @@ def default_config(): def test_llm_init_with_default_config(default_config): - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') assert llm.config.model == 'gpt-4o' assert llm.config.api_key.get_secret_value() == 'test_key' assert isinstance(llm.metrics, Metrics) @@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config): 'max_input_tokens': 8000, 'max_output_tokens': 2000, } - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') llm.init_model_info() assert llm.config.max_input_tokens == 8000 assert llm.config.max_output_tokens == 2000 @@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config): @patch('openhands.llm.llm.litellm.get_model_info') def test_llm_init_without_model_info(mock_get_model_info, default_config): mock_get_model_info.side_effect = Exception('Model info not available') - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') llm.init_model_info() assert llm.config.max_input_tokens is None assert llm.config.max_output_tokens is None @@ -154,7 +154,7 @@ def test_llm_init_with_custom_config(): top_p=0.9, top_k=None, ) - llm = LLM(custom_config) + llm = LLM(custom_config, service_id='test-service') assert llm.config.model == 'custom-model' assert llm.config.api_key.get_secret_value() == 'custom_key' assert llm.config.max_input_tokens == 5000 @@ -168,7 +168,7 @@ def test_llm_init_with_custom_config(): def test_llm_top_k_in_completion_when_set(mock_litellm_completion): # Create a config with top_k set config_with_top_k = LLMConfig(top_k=50) - llm = LLM(config_with_top_k) + llm = LLM(config_with_top_k, service_id='test-service') # Define a side effect function to check top_k def side_effect(*args, **kwargs): @@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion): def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion): # Create a config with top_k set to None config_without_top_k = LLMConfig(top_k=None) - llm = LLM(config_without_top_k) + llm = LLM(config_without_top_k, service_id='test-service') # Define a side effect function to check top_k def side_effect(*args, **kwargs): @@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion): def test_llm_init_with_metrics(): config = LLMConfig(model='gpt-4o', api_key='test_key') metrics = Metrics() - llm = LLM(config, metrics=metrics) + llm = LLM(config, metrics=metrics, service_id='test-service') assert llm.metrics is metrics assert ( llm.metrics.model_name == 'default' @@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion): # Create LLM instance and make a completion call config = LLMConfig(model='gpt-4o', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}]) # Verify the response latency was tracked correctly @@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config): 'max_input_tokens': 7000, 'max_output_tokens': 1500, } - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') llm.init_model_info() assert llm.config.max_input_tokens == 7000 assert llm.config.max_output_tokens == 1500 @@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config): default_config.model = ( 'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS ) - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], tools=[ @@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config): # Test with Grok-4 model that doesn't support stop parameter default_config.model = 'xai/grok-4-0709' - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], tools=[ @@ -314,7 +314,7 @@ def test_completion_with_mocked_logger( 'choices': [{'message': {'content': 'Test response'}}] } - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -345,7 +345,7 @@ def test_completion_retries( {'choices': [{'message': {'content': 'Retry successful'}}]}, ] - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config {'choices': [{'message': {'content': 'Retry successful'}}]}, ] - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config def test_completion_operation_cancelled(mock_litellm_completion, default_config): mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled') - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') with pytest.raises(OperationCancelled): llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], @@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config): mock_litellm_completion.side_effect = side_effect - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') with pytest.raises(OperationCancelled): try: llm.completion( @@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_ mock_litellm_completion.side_effect = side_effect - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') result = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp( mock_litellm_completion.side_effect = side_effect # Create LLM instance and make a completion call - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp( 'LLM did not return a response' ) - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') with pytest.raises(LLMNoResponseError): llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], @@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info): for model_name, expected_support in test_cases: config = LLMConfig(model=model_name, api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') assert llm.is_function_calling_active() == expected_support, ( f'Expected function calling support to be {expected_support} for model {model_name}' @@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret mock_litellm_completion.side_effect = side_effect # Create LLM instance and make a completion call with non-zero temperature - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry( mock_litellm_completion.side_effect = side_effect # Create LLM instance and make a completion call with explicit temperature=0 - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config): } mock_litellm_completion.return_value = mock_response - test_llm = LLM(config=default_config) + test_llm = LLM(config=default_config, service_id='test-service') response = test_llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config): } # Initialize LLM and call completion - llm = LLM(config=gemini_config) + llm = LLM(config=gemini_config, service_id='test-service') llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}]) # Verify that litellm_completion was called with the 'thinking' parameter @@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config): @patch('openhands.llm.llm.litellm.token_counter') def test_get_token_count_with_dict_messages(mock_token_counter, default_config): mock_token_counter.return_value = 42 - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') messages = [{'role': 'user', 'content': 'Hello!'}] token_count = llm.get_token_count(messages) @@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config): def test_get_token_count_with_message_objects( mock_token_counter, default_config, mock_logger ): - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') # Create a Message object and its equivalent dict message_obj = Message(role='user', content=[TextContent(text='Hello!')]) @@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer( config = copy.deepcopy(default_config) config.custom_tokenizer = 'custom/tokenizer' - llm = LLM(config) + llm = LLM(config, service_id='test-service') messages = [{'role': 'user', 'content': 'Hello!'}] token_count = llm.get_token_count(messages) @@ -823,7 +823,7 @@ def test_get_token_count_error_handling( mock_token_counter, default_config, mock_logger ): mock_token_counter.side_effect = Exception('Token counting failed') - llm = LLM(default_config) + llm = LLM(default_config, service_id='test-service') messages = [{'role': 'user', 'content': 'Hello!'}] token_count = llm.get_token_count(messages) @@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config): # We'll make mock_litellm_completion return these responses in sequence mock_litellm_completion.side_effect = [mock_response_1, mock_response_2] - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') # First call llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}]) @@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config): mock_litellm_completion.side_effect = [mock_response_1, mock_response_2] # Create LLM instance - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') # First call llm.completion(messages=[{'role': 'user', 'content': 'First message'}]) @@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config } mock_litellm_completion.return_value = mock_response - test_llm = LLM(config=default_config) + test_llm = LLM(config=default_config, service_id='test-service') response = test_llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get): mock_get.return_value.status_code = 200 mock_get.return_value.json.return_value = {'model': 'fake'} - llm = LLM(config=config) + llm = LLM(config=config, service_id='test-service') llm.init_model_info() called_url = mock_get.call_args[0][0] @@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits(): """Test that models without known token limits get None for both max_output_tokens and max_input_tokens.""" # Create LLM instance with a non-existent model to avoid litellm having model info for it config = LLMConfig(model='non-existent-model', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') # Verify max_output_tokens and max_input_tokens are initialized to None (default value) assert llm.config.max_output_tokens is None @@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info(): """Test that max_output_tokens and max_input_tokens are correctly initialized from model info.""" # Create LLM instance with GPT-4 model which has known token limits config = LLMConfig(model='gpt-4', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') # GPT-4 has specific token limits # These are the expected values from litellm @@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens(): """Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens.""" # Create LLM instance with Claude 3.7 Sonnet model config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') # Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet assert llm.config.max_output_tokens == 64000 @@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens(): """Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values.""" # Create LLM instance with a Claude Sonnet 4 model config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') # Verify max_output_tokens is set to the expected value assert llm.config.max_output_tokens == 64000 @@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens(): """Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value.""" # Create LLM instance with SambaNova DeepSeek model config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key') - llm = LLM(config) + llm = LLM(config, service_id='test-service') # SambaNova DeepSeek model has specific token limits # This is the expected value from litellm @@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config(): config = LLMConfig( model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048 ) - llm = LLM(config) + llm = LLM(config, service_id='test-service') # Verify the config has the overridden max_output_tokens value assert llm.config.max_output_tokens == 2048 @@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens(): ) # Create LLM instance with Azure model - llm = LLM(azure_config) + llm = LLM(azure_config, service_id='test-service') # Verify the config has the default max_output_tokens value assert llm.config.max_output_tokens is None # Default value @@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion): 'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, } - llm = LLM(config) + llm = LLM(config, service_id='test-service') sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] llm.completion(messages=sample_messages) @@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion): assert config.reasoning_effort is None # Create LLM and make completion - llm = LLM(config) + llm = LLM(config, service_id='test-service') messages = [{'role': 'user', 'content': 'Solve this complex problem'}] response = llm.completion(messages=messages) diff --git a/tests/unit/resolver/github/test_issue_handler_error_handling.py b/tests/unit/resolver/github/test_issue_handler_error_handling.py index 63a79b3558..ec58f18f68 100644 --- a/tests/unit/resolver/github/test_issue_handler_error_handling.py +++ b/tests/unit/resolver/github/test_issue_handler_error_handling.py @@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con ), ] - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') handler = ServiceContextIssue( GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config ) @@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config): ) # Initialize LLM and handler - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') handler = ServiceContextPR( GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config ) diff --git a/tests/unit/resolver/github/test_resolve_issues.py b/tests/unit/resolver/github/test_resolve_issues.py index 6bda19ef05..002735e9b7 100644 --- a/tests/unit/resolver/github/test_resolve_issues.py +++ b/tests/unit/resolver/github/test_resolve_issues.py @@ -463,7 +463,7 @@ async def test_process_issue( [], ) handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' - handler_instance.llm = LLM(llm_config) + handler_instance.llm = LLM(llm_config, service_id='test-service') # Mock the runtime and its methods mock_runtime = MagicMock() diff --git a/tests/unit/resolver/gitlab/test_gitlab_issue_handler_error_handling.py b/tests/unit/resolver/gitlab/test_gitlab_issue_handler_error_handling.py index c3be6277ce..1dbd826376 100644 --- a/tests/unit/resolver/gitlab/test_gitlab_issue_handler_error_handling.py +++ b/tests/unit/resolver/gitlab/test_gitlab_issue_handler_error_handling.py @@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con ), ] - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') handler = ServiceContextIssue( GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config ) @@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config): ) # Initialize LLM and handler - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') handler = ServiceContextPR( GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config ) diff --git a/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py b/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py index 20c4fa7dd7..2cedb21b49 100644 --- a/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py +++ b/tests/unit/resolver/gitlab/test_gitlab_resolve_issues.py @@ -500,7 +500,7 @@ async def test_process_issue( [], ) handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' - handler_instance.llm = LLM(llm_config) + handler_instance.llm = LLM(llm_config, service_id='test-service') # Create mock runtime and mock run_controller mock_runtime = MagicMock() diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index c0bfe21573..75ad435dbd 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import ( from openhands.controller.state.state import State from openhands.core.config import OpenHandsConfig from openhands.core.config.agent_config import AgentConfig +from openhands.core.config.llm_config import LLMConfig from openhands.core.main import run_controller from openhands.core.schema import AgentState from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber @@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation from openhands.events.observation.empty import NullObservation from openhands.events.serialization import event_to_dict from openhands.llm import LLM +from openhands.llm.llm_registry import LLMRegistry, RegistryEvent from openhands.llm.metrics import Metrics, TokenUsage from openhands.memory.condenser.condenser import Condensation from openhands.memory.condenser.impl.conversation_window_condenser import ( @@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) from openhands.runtime.runtime_status import RuntimeStatus +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.memory import InMemoryFileStore @@ -61,15 +64,43 @@ def event_loop(): @pytest.fixture -def mock_agent(): - agent = MagicMock(spec=Agent) - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = OpenHandsConfig().get_llm_config() +def mock_agent_with_stats(): + """Create a mock agent with properly connected LLM registry and conversation stats.""" + import uuid - # Add config with enable_mcp attribute - agent.config = MagicMock(spec=AgentConfig) - agent.config.enable_mcp = True + # Create LLM registry + config = OpenHandsConfig() + llm_registry = LLMRegistry(config=config) + + # Create conversation stats + file_store = InMemoryFileStore({}) + conversation_id = f'test-conversation-{uuid.uuid4()}' + conversation_stats = ConversationStats( + file_store=file_store, conversation_id=conversation_id, user_id='test-user' + ) + + # Connect registry to stats (this is the key requirement) + llm_registry.subscribe(conversation_stats.register_llm) + + # Create mock agent + agent = MagicMock(spec=Agent) + agent_config = MagicMock(spec=AgentConfig) + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + agent_config.disabled_microagents = [] + agent_config.enable_mcp = True + llm_registry.service_to_llm.clear() + mock_llm = llm_registry.get_llm('agent_llm', llm_config) + agent.llm = mock_llm + agent.name = 'test-agent' + agent.sandbox_plugins = [] + agent.config = agent_config + agent.prompt_manager = MagicMock() # Add a proper system message mock system_message = SystemMessageAction( @@ -79,7 +110,7 @@ def mock_agent(): system_message._id = -1 # Set invalid ID to avoid the ID check agent.get_system_message.return_value = system_message - return agent + return agent, conversation_stats, llm_registry @pytest.fixture @@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event): @pytest.mark.asyncio -async def test_set_agent_state(mock_agent, mock_event_stream): +async def test_set_agent_state(mock_agent_with_stats, mock_event_stream): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_on_event_message_action(mock_agent, mock_event_stream): +async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream): +async def test_on_event_change_agent_state_action( + mock_agent_with_stats, mock_event_stream +): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream) @pytest.mark.asyncio -async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback): +async def test_react_to_exception( + mock_agent_with_stats, + mock_event_stream, + mock_status_callback, +): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, status_callback=mock_status_callback, iteration_delta=10, sid='test', @@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal @pytest.mark.asyncio async def test_react_to_content_policy_violation( - mock_agent, mock_event_stream, mock_status_callback + mock_agent_with_stats, + mock_event_stream, + mock_status_callback, ): """Test that the controller properly handles content policy violations from the LLM.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, status_callback=mock_status_callback, iteration_delta=10, sid='test', @@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation( @pytest.mark.asyncio async def test_run_controller_with_fatal_error( - test_event_stream, mock_memory, mock_agent + test_event_stream, mock_memory, mock_agent_with_stats ): config = OpenHandsConfig() + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats def agent_step_fn(state): print(f'agent_step_fn received state: {state}') return CmdRunAction(command='ls') mock_agent.step = agent_step_fn - mock_agent.llm = MagicMock(spec=LLM) - mock_agent.llm.metrics = Metrics() - mock_agent.llm.config = config.get_llm_config() runtime = MagicMock(spec=ActionExecutionClient) @@ -284,15 +336,17 @@ async def test_run_controller_with_fatal_error( EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) ) - state = await run_controller( - config=config, - initial_user_action=MessageAction(content='Test message'), - runtime=runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ) print(f'state: {state}') events = list(test_event_stream.get_events()) print(f'event_stream: {events}') @@ -312,18 +366,16 @@ async def test_run_controller_with_fatal_error( @pytest.mark.asyncio async def test_run_controller_stop_with_stuck( - test_event_stream, mock_memory, mock_agent + test_event_stream, mock_memory, mock_agent_with_stats ): config = OpenHandsConfig() + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats def agent_step_fn(state): print(f'agent_step_fn received state: {state}') return CmdRunAction(command='ls') mock_agent.step = agent_step_fn - mock_agent.llm = MagicMock(spec=LLM) - mock_agent.llm.metrics = Metrics() - mock_agent.llm.config = config.get_llm_config() runtime = MagicMock(spec=ActionExecutionClient) @@ -352,15 +404,17 @@ async def test_run_controller_stop_with_stuck( EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) ) - state = await run_controller( - config=config, - initial_user_action=MessageAction(content='Test message'), - runtime=runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ) events = list(test_event_stream.get_events()) print(f'state: {state}') for i, event in enumerate(events): @@ -391,11 +445,14 @@ async def test_run_controller_stop_with_stuck( @pytest.mark.asyncio -async def test_max_iterations_extension(mock_agent, mock_event_stream): +async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream): # Test with headless_mode=False - should extend max_iterations + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream): controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_step_max_budget(mock_agent, mock_event_stream): +async def test_step_max_budget(mock_agent_with_stats, mock_event_stream): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Metrics are always synced with budget flag before metrics = Metrics() metrics.accumulated_cost = 10.1 @@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream): limit_increase_amount=10, current_value=10.1, max_value=10 ) + # Update agent's LLM metrics in place + mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, budget_per_task_delta=10, sid='test', @@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_step_max_budget_headless(mock_agent, mock_event_stream): +async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Metrics are always synced with budget flag before metrics = Metrics() metrics.accumulated_cost = 10.1 @@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream): limit_increase_amount=10, current_value=10.1, max_value=10 ) + # Update agent's LLM metrics in place + mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, budget_per_task_delta=10, sid='test', @@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_budget_reset_on_continue(mock_agent, mock_event_stream): +async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream): """Test that when a user continues after hitting the budget limit: 1. Error is thrown when budget cap is exceeded 2. LLM budget does not reset when user continues 3. Budget is extended by adding the initial budget cap to the current accumulated cost """ + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Create a real Metrics instance shared between controller state and llm metrics = Metrics() metrics.accumulated_cost = 6.0 @@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream): ), ) + # Update agent's LLM metrics in place + mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost + # Create controller with budget cap controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, budget_per_task_delta=initial_budget, sid='test', @@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream): @pytest.mark.asyncio -async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream): +async def test_reset_with_pending_action_no_observation( + mock_agent_with_stats, mock_event_stream +): """Test reset() when there's a pending action with tool call metadata but no observation.""" + # Connect LLM registry to conversation stats + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s @pytest.mark.asyncio -async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream): +async def test_reset_with_pending_action_stopped_state( + mock_agent_with_stats, mock_event_stream +): """Test reset() when there's a pending action and agent state is STOPPED.""" + # Connect LLM registry to conversation stats + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st @pytest.mark.asyncio async def test_reset_with_pending_action_existing_observation( - mock_agent, mock_event_stream + mock_agent_with_stats, mock_event_stream ): """Test reset() when there's a pending action with tool call metadata and an existing observation.""" + # Connect LLM registry to conversation stats + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation( @pytest.mark.asyncio -async def test_reset_without_pending_action(mock_agent, mock_event_stream): +async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream): """Test reset() when there's no pending action.""" + # Connect LLM registry to conversation stats + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream): @pytest.mark.asyncio async def test_reset_with_pending_action_no_metadata( - mock_agent, mock_event_stream, monkeypatch + mock_agent_with_stats, mock_event_stream, monkeypatch ): """Test reset() when there's a pending action without tool call metadata.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata( @pytest.mark.asyncio async def test_run_controller_max_iterations_has_metrics( - test_event_stream, mock_memory, mock_agent + test_event_stream, mock_memory, mock_agent_with_stats ): config = OpenHandsConfig( max_iterations=3, ) event_stream = test_event_stream - - mock_agent.llm = MagicMock(spec=LLM) - mock_agent.llm.metrics = Metrics() - mock_agent.llm.config = config.get_llm_config() + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats step_count = 0 @@ -833,15 +929,17 @@ async def test_run_controller_max_iterations_has_metrics( event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())) - state = await run_controller( - config=config, - initial_user_action=MessageAction(content='Test message'), - runtime=runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ) state.metrics = mock_agent.llm.metrics assert state.iteration_flag.current_value == 3 @@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics( @pytest.mark.asyncio -async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback): +async def test_notify_on_llm_retry( + mock_agent_with_stats, + mock_event_stream, + mock_status_callback, +): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, status_callback=mock_status_callback, iteration_delta=10, sid='test', @@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca ], ) async def test_context_window_exceeded_error_handling( - context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory + context_window_error, + mock_agent_with_stats, + mock_runtime, + test_event_stream, + mock_memory, ): """Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + max_iterations = 5 error_after = 2 @@ -973,18 +1084,20 @@ async def test_context_window_exceeded_error_handling( # state is set to error out before then, if this terminates and we have a # record of the error being thrown we can be confident that the controller # handles the truncation correctly. - final_state = await asyncio.wait_for( - run_controller( - config=config, - initial_user_action=MessageAction(content='INITIAL'), - runtime=mock_runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ), - timeout=10, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + final_state = await asyncio.wait_for( + run_controller( + config=config, + initial_user_action=MessageAction(content='INITIAL'), + runtime=mock_runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ), + timeout=10, + ) # Check that the context window exception was thrown and the controller # called the agent's `step` function the right number of times. @@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling( @pytest.mark.asyncio async def test_run_controller_with_context_window_exceeded_with_truncation( - mock_agent, mock_runtime, mock_memory, test_event_stream + mock_agent_with_stats, + mock_runtime, + mock_memory, + test_event_stream, ): """Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats class StepState: def __init__(self): @@ -1121,18 +1238,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation( mock_runtime.config = copy.deepcopy(config) try: - state = await asyncio.wait_for( - run_controller( - config=config, - initial_user_action=MessageAction(content='INITIAL'), - runtime=mock_runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ), - timeout=10, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await asyncio.wait_for( + run_controller( + config=config, + initial_user_action=MessageAction(content='INITIAL'), + runtime=mock_runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ), + timeout=10, + ) # A timeout error indicates the run_controller entrypoint is not making # progress @@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation( @pytest.mark.asyncio async def test_run_controller_with_context_window_exceeded_without_truncation( - mock_agent, mock_runtime, mock_memory, test_event_stream + mock_agent_with_stats, + mock_runtime, + mock_memory, + test_event_stream, ): """Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats class StepState: def __init__(self): @@ -1199,18 +1322,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( config = OpenHandsConfig(max_iterations=3) mock_runtime.config = copy.deepcopy(config) try: - state = await asyncio.wait_for( - run_controller( - config=config, - initial_user_action=MessageAction(content='INITIAL'), - runtime=mock_runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=mock_memory, - ), - timeout=10, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await asyncio.wait_for( + run_controller( + config=config, + initial_user_action=MessageAction(content='INITIAL'), + runtime=mock_runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, + llm_registry=llm_registry, + ), + timeout=10, + ) # A timeout error indicates the run_controller entrypoint is not making # progress @@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( @pytest.mark.asyncio -async def test_run_controller_with_memory_error(test_event_stream, mock_agent): +async def test_run_controller_with_memory_error( + test_event_stream, mock_agent_with_stats +): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + config = OpenHandsConfig() event_stream = test_event_stream @@ -1273,15 +1402,17 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent): with patch.object( memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge ): - state = await run_controller( - config=config, - initial_user_action=MessageAction(content='Test message'), - runtime=runtime, - sid='test', - agent=mock_agent, - fake_user_response_fn=lambda _: 'repeat', - memory=memory, - ) + # Mock the create_agent function to return our mock agent + with patch('openhands.core.main.create_agent', return_value=mock_agent): + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + fake_user_response_fn=lambda _: 'repeat', + memory=memory, + llm_registry=llm_registry, + ) assert state.iteration_flag.current_value == 0 assert state.agent_state == AgentState.ERROR @@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent): @pytest.mark.asyncio -async def test_action_metrics_copy(mock_agent): +async def test_action_metrics_copy(mock_agent_with_stats): + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Setup file_store = InMemoryFileStore({}) event_stream = EventStream(sid='test', file_store=file_store) @@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent): initial_state = State(metrics=metrics, budget_flag=None) - # Create agent with metrics - mock_agent.llm = MagicMock(spec=LLM) + # Update agent's LLM metrics # Add multiple token usages - we should get the last one in the action usage1 = TokenUsage( @@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent): mock_agent.llm.metrics = metrics + # Register the metrics with the LLM registry + llm_registry.service_to_llm['agent'] = mock_agent.llm + # Manually notify the conversation stats about the LLM registration + llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent')) + # Mock agent step to return an action action = MessageAction(content='Test message') @@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent): controller = AgentController( agent=mock_agent, event_stream=event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -1411,12 +1549,13 @@ async def test_action_metrics_copy(mock_agent): @pytest.mark.asyncio -async def test_condenser_metrics_included(mock_agent, test_event_stream): +async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream): """Test that metrics from the condenser's LLM are included in the action metrics.""" - # Set up agent metrics - agent_metrics = Metrics(model_name='agent-model') - agent_metrics.accumulated_cost = 0.05 - agent_metrics._accumulated_token_usage = TokenUsage( + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + + # Set up agent metrics in place + mock_agent.llm.metrics.accumulated_cost = 0.05 + mock_agent.llm.metrics._accumulated_token_usage = TokenUsage( model='agent-model', prompt_tokens=100, completion_tokens=50, @@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream): cache_write_tokens=10, response_id='agent-accumulated', ) - # mock_agent.llm.metrics = agent_metrics mock_agent.name = 'TestAgent' # Create condenser with its own metrics @@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream): ) condenser.llm.metrics = condenser_metrics + # Register the condenser metrics with the LLM registry + llm_registry.service_to_llm['condenser'] = condenser.llm + # Manually notify the conversation stats about the condenser LLM registration + llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser')) + # Attach the condenser to the mock_agent mock_agent.condenser = condenser @@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream): controller = AgentController( agent=mock_agent, event_stream=test_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, headless_mode=True, - initial_state=State(metrics=agent_metrics, budget_flag=None), + initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None), ) # Execute one step @@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream): @pytest.mark.asyncio -async def test_first_user_message_with_identical_content(test_event_stream, mock_agent): +async def test_first_user_message_with_identical_content( + test_event_stream, mock_agent_with_stats +): """Test that _first_user_message correctly identifies the first user message. This test verifies that messages with identical content but different IDs are properly @@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock The issue we're checking is that the comparison (action == self._first_user_message()) should correctly differentiate between messages with the same content but different IDs. """ - # Create an agent controller - mock_agent.llm = MagicMock(spec=LLM) - mock_agent.llm.metrics = Metrics() - mock_agent.llm.config = OpenHandsConfig().get_llm_config() + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats controller = AgentController( agent=mock_agent, event_stream=test_event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test', confirmation_mode=False, @@ -1569,11 +1713,15 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock @pytest.mark.asyncio -async def test_agent_controller_processes_null_observation_with_cause(): +async def test_agent_controller_processes_null_observation_with_cause( + mock_agent_with_stats, +): """Test that AgentController processes NullObservation events with a cause value. And that the agent's step method is called as a result. """ + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Create an in-memory file store and real event stream file_store = InMemoryFileStore() event_stream = EventStream(sid='test-session', file_store=file_store) @@ -1581,19 +1729,11 @@ async def test_agent_controller_processes_null_observation_with_cause(): # Create a Memory instance - not used directly in this test but needed for setup Memory(event_stream=event_stream, sid='test-session') - # Create a mock agent with necessary attributes - mock_agent = MagicMock(spec=Agent) - mock_agent.get_system_message = MagicMock( - return_value=None, - ) - mock_agent.llm = MagicMock(spec=LLM) - mock_agent.llm.metrics = Metrics() - mock_agent.llm.config = OpenHandsConfig().get_llm_config() - # Create a controller with the mock agent controller = AgentController( agent=mock_agent, event_stream=event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test-session', ) @@ -1655,8 +1795,12 @@ async def test_agent_controller_processes_null_observation_with_cause(): ) -def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent): +def test_agent_controller_should_step_with_null_observation_cause_zero( + mock_agent_with_stats, +): """Test that AgentController's should_step method returns False for NullObservation with cause = 0.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + # Create a mock event stream file_store = InMemoryFileStore() event_stream = EventStream(sid='test-session', file_store=file_store) @@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen controller = AgentController( agent=mock_agent, event_stream=event_stream, + convo_stats=conversation_stats, iteration_delta=10, sid='test-session', ) @@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen ) -def test_system_message_in_event_stream(mock_agent, test_event_stream): +def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream): """Test that SystemMessageAction is added to event stream in AgentController.""" + mock_agent, conversation_stats, llm_registry = mock_agent_with_stats + _ = AgentController( - agent=mock_agent, event_stream=test_event_stream, iteration_delta=10 + agent=mock_agent, + event_stream=test_event_stream, + convo_stats=conversation_stats, + iteration_delta=10, ) # Get events from the event stream diff --git a/tests/unit/test_agent_delegation.py b/tests/unit/test_agent_delegation.py index 1c5bc6f545..fdb9c9d4d8 100644 --- a/tests/unit/test_agent_delegation.py +++ b/tests/unit/test_agent_delegation.py @@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import ( IterationControlFlag, ) from openhands.controller.state.state import State -from openhands.core.config import LLMConfig +from openhands.core.config import OpenHandsConfig from openhands.core.config.agent_config import AgentConfig +from openhands.core.config.llm_config import LLMConfig from openhands.core.schema import AgentState from openhands.events import EventSource, EventStream from openhands.events.action import ( @@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType from openhands.events.observation.agent import RecallObservation from openhands.events.stream import EventStreamSubscriber from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.llm.metrics import Metrics from openhands.memory.memory import Memory +from openhands.server.services.conversation_stats import ConversationStats from openhands.storage.memory import InMemoryFileStore +@pytest.fixture +def llm_registry(): + config = OpenHandsConfig() + return LLMRegistry(config=config) + + +@pytest.fixture +def conversation_stats(): + import uuid + + file_store = InMemoryFileStore({}) + # Use a unique conversation ID for each test to avoid conflicts + conversation_id = f'test-conversation-{uuid.uuid4()}' + return ConversationStats( + file_store=file_store, conversation_id=conversation_id, user_id='test-user' + ) + + +@pytest.fixture +def connected_registry_and_stats(llm_registry, conversation_stats): + """Connect the LLMRegistry and ConversationStats properly""" + # Subscribe to LLM registry events to track metrics + llm_registry.subscribe(conversation_stats.register_llm) + return llm_registry, conversation_stats + + @pytest.fixture def mock_event_stream(): """Creates an event stream in memory.""" @@ -42,15 +71,17 @@ def mock_event_stream(): @pytest.fixture -def mock_parent_agent(): +def mock_parent_agent(llm_registry): """Creates a mock parent agent for testing delegation.""" agent = MagicMock(spec=Agent) agent.name = 'ParentAgent' agent.llm = MagicMock(spec=LLM) + agent.llm.service_id = 'main_agent' agent.llm.metrics = Metrics() agent.llm.config = LLMConfig() agent.llm.retry_listener = None # Add retry_listener attribute agent.config = AgentConfig() + agent.llm_registry = llm_registry # Add the missing llm_registry attribute # Add a proper system message mock system_message = SystemMessageAction(content='Test system message') @@ -61,15 +92,17 @@ def mock_parent_agent(): @pytest.fixture -def mock_child_agent(): +def mock_child_agent(llm_registry): """Creates a mock child agent for testing delegation.""" agent = MagicMock(spec=Agent) agent.name = 'ChildAgent' agent.llm = MagicMock(spec=LLM) + agent.llm.service_id = 'main_agent' agent.llm.metrics = Metrics() agent.llm.config = LLMConfig() agent.llm.retry_listener = None # Add retry_listener attribute agent.config = AgentConfig() + agent.llm_registry = llm_registry # Add the missing llm_registry attribute system_message = SystemMessageAction(content='Test system message') system_message._source = EventSource.AGENT @@ -78,15 +111,37 @@ def mock_child_agent(): return agent +def create_mock_agent_factory(mock_child_agent, llm_registry): + """Helper function to create a mock agent factory with proper LLM registration.""" + + def create_mock_agent(config, llm_registry=None): + # Register the mock agent's LLM in the registry so get_combined_metrics() can find it + if llm_registry: + mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig()) + mock_child_agent.llm_registry = ( + llm_registry # Set the llm_registry attribute + ) + return mock_child_agent + + return create_mock_agent + + @pytest.mark.asyncio -async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream): - """Test that when the parent agent delegates to a child - 1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly. - 2. metrics are accumulated globally (delegate is adding to the parents metrics) - 3. local metrics for the delegate are still accessible +async def test_delegation_flow( + mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats +): """ + Test that when the parent agent delegates to a child + 1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly. + 2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics) + 3. global metrics tracking works correctly through the LLM registry + """ + llm_registry, conversation_stats = connected_registry_and_stats + # Mock the agent class resolution so that AgentController can instantiate mock_child_agent - Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent) + Agent.get_cls = Mock( + return_value=create_mock_agent_factory(mock_child_agent, llm_registry) + ) step_count = 0 @@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s mock_child_agent.step = agent_step_fn + # Set up the parent agent's LLM with initial cost and register it in the registry + # The parent agent's LLM should use the existing registered LLM to ensure proper tracking + parent_llm = llm_registry.service_to_llm['agent'] + parent_llm.metrics.accumulated_cost = 2 + mock_parent_agent.llm = parent_llm + parent_metrics = Metrics() parent_metrics.accumulated_cost = 2 # Create parent controller @@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s parent_controller = AgentController( agent=mock_parent_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=1, # Add the required iteration_delta parameter sid='parent', confirmation_mode=False, @@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s for i in range(4): delegate_controller.state.iteration_flag.step() delegate_controller.agent.step(delegate_controller.state) + # Update the agent's LLM metrics (not the deprecated state metrics) delegate_controller.agent.llm.metrics.add_cost(1.0) assert ( delegate_controller.state.get_local_step() == 4 ) # verify local metrics are accessible via snapshot + # Check that the conversation stats has the combined metrics (parent + delegate) + combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics() assert ( - delegate_controller.state.metrics.accumulated_cost - == 6 # Make sure delegate tracks global cost + combined_metrics.accumulated_cost + == 6 # Make sure delegate tracks global cost (2 from parent + 4 from delegate) ) - assert ( - delegate_controller.state.get_local_metrics().accumulated_cost - == 4 # Delegate spent one dollar per step - ) + # Since metrics are now global via LLM registry, local metrics tracking + # is handled differently. The delegate's LLM shares the same metrics object + # as the parent for global tracking, so we verify the global total is correct. delegate_controller.state.outputs = {'delegate_result': 'done'} @@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s ], ) async def test_delegate_step_different_states( - mock_parent_agent, mock_event_stream, delegate_state + mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats ): """Ensure that delegate is closed or remains open based on the delegate's state.""" + llm_registry, conversation_stats = connected_registry_and_stats + # Create a state with iteration_flag.max_value set to 10 state = State(inputs={}) state.iteration_flag.max_value = 10 controller = AgentController( agent=mock_parent_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=1, # Add the required iteration_delta parameter sid='test', confirmation_mode=False, @@ -292,11 +359,23 @@ async def test_delegate_step_different_states( @pytest.mark.asyncio async def test_delegate_hits_global_limits( - mock_child_agent, mock_event_stream, mock_parent_agent + mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats ): - """Global limits from control flags should apply to delegates""" + """ + Global limits from control flags should apply to delegates + """ + llm_registry, conversation_stats = connected_registry_and_stats + # Mock the agent class resolution so that AgentController can instantiate mock_child_agent - Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent) + Agent.get_cls = Mock( + return_value=create_mock_agent_factory(mock_child_agent, llm_registry) + ) + + # Set up the parent agent's LLM with initial cost and register it in the registry + mock_parent_agent.llm.metrics.accumulated_cost = 2 + mock_parent_agent.llm.service_id = 'main_agent' + # Register the parent agent's LLM in the registry + llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm parent_metrics = Metrics() parent_metrics.accumulated_cost = 2 @@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits( parent_controller = AgentController( agent=mock_parent_agent, event_stream=mock_event_stream, + convo_stats=conversation_stats, iteration_delta=1, # Add the required iteration_delta parameter sid='parent', confirmation_mode=False, diff --git a/tests/unit/test_agent_session.py b/tests/unit/test_agent_session.py index 2697380f4a..e05e69815c 100644 --- a/tests/unit/test_agent_session.py +++ b/tests/unit/test_agent_session.py @@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig from openhands.core.config.agent_config import AgentConfig from openhands.events import EventStream, EventStreamSubscriber from openhands.integrations.service_types import ProviderType -from openhands.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.llm.metrics import Metrics from openhands.memory.memory import Memory from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) +from openhands.server.services.conversation_stats import ConversationStats from openhands.server.session.agent_session import AgentSession from openhands.storage.memory import InMemoryFileStore @@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore @pytest.fixture -def mock_agent(): - """Create a properly configured mock agent with all required nested attributes""" - # Create the base mocks - agent = MagicMock(spec=Agent) - llm = MagicMock(spec=LLM) - metrics = MagicMock(spec=Metrics) - llm_config = MagicMock(spec=LLMConfig) - agent_config = MagicMock(spec=AgentConfig) +def mock_llm_registry(): + """Create a mock LLM registry that properly simulates LLM registration""" + config = OpenHandsConfig() + registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None) + return registry - # Configure the LLM config - llm_config.model = 'test-model' - llm_config.base_url = 'http://test' - llm_config.max_message_chars = 1000 - # Configure the agent config - agent_config.disabled_microagents = [] - agent_config.enable_mcp = True +@pytest.fixture +def mock_conversation_stats(): + """Create a mock ConversationStats that properly simulates metrics tracking""" + file_store = InMemoryFileStore({}) + stats = ConversationStats( + file_store=file_store, conversation_id='test-conversation', user_id='test-user' + ) + return stats - # Set up the chain of mocks - llm.metrics = metrics - llm.config = llm_config - agent.llm = llm - agent.name = 'test-agent' - agent.sandbox_plugins = [] - agent.config = agent_config - agent.prompt_manager = MagicMock() - return agent +@pytest.fixture +def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats): + """Connect the LLMRegistry and ConversationStats properly""" + # Subscribe to LLM registry events to track metrics + mock_llm_registry.subscribe(mock_conversation_stats.register_llm) + return mock_llm_registry, mock_conversation_stats + + +@pytest.fixture +def make_mock_agent(): + def _make_mock_agent(llm_registry): + agent = MagicMock(spec=Agent) + agent_config = MagicMock(spec=AgentConfig) + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + agent_config.disabled_microagents = [] + agent_config.enable_mcp = True + llm_registry.service_to_llm.clear() + mock_llm = llm_registry.get_llm('agent_llm', llm_config) + agent.llm = mock_llm + agent.name = 'test-agent' + agent.sandbox_plugins = [] + agent.config = agent_config + agent.prompt_manager = MagicMock() + return agent + + return _make_mock_agent @pytest.mark.asyncio -async def test_agent_session_start_with_no_state(mock_agent): +async def test_agent_session_start_with_no_state( + make_mock_agent, mock_llm_registry, mock_conversation_stats +): """Test that AgentSession.start() works correctly when there's no state to restore""" + mock_agent = make_mock_agent(mock_llm_registry) # Setup file_store = InMemoryFileStore({}) session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=mock_conversation_stats, ) # Create a mock runtime and set it up @@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent): @pytest.mark.asyncio -async def test_agent_session_start_with_restored_state(mock_agent): +async def test_agent_session_start_with_restored_state( + make_mock_agent, mock_llm_registry, mock_conversation_stats +): """Test that AgentSession.start() works correctly when there's a state to restore""" + mock_agent = make_mock_agent(mock_llm_registry) # Setup file_store = InMemoryFileStore({}) session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=mock_conversation_stats, ) # Create a mock runtime and set it up @@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent): @pytest.mark.asyncio -async def test_metrics_centralization_and_sharing(mock_agent): - """Test that metrics are centralized and shared between controller and agent.""" +async def test_metrics_centralization_via_conversation_stats( + make_mock_agent, connected_registry_and_stats +): + """Test that metrics are centralized through the ConversationStats service.""" + + mock_llm_registry, mock_conversation_stats = connected_registry_and_stats + mock_agent = make_mock_agent(mock_llm_registry) + # Setup file_store = InMemoryFileStore({}) session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=mock_conversation_stats, ) # Create a mock runtime and set it up @@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent): memory = Memory(event_stream=mock_event_stream, sid='test-session') memory.microagents_dir = 'test-dir' + # The registry already has a real metrics object set up in the fixture + # Patch necessary components with ( patch( @@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent): max_iterations=10, ) - # Verify that the agent's LLM metrics and controller's state metrics are the same object - assert session.controller.agent.llm.metrics is session.controller.state.metrics + # Verify that the ConversationStats is properly set up + assert session.controller.state.convo_stats is mock_conversation_stats - # Add some metrics to the agent's LLM + # Add some metrics to the agent's LLM (simulating LLM usage) test_cost = 0.05 session.controller.agent.llm.metrics.add_cost(test_cost) - # Verify that the cost is reflected in the controller's state metrics - assert session.controller.state.metrics.accumulated_cost == test_cost + # Verify that the cost is reflected in the combined metrics from the conversation stats + combined_metrics = session.controller.state.convo_stats.get_combined_metrics() + assert combined_metrics.accumulated_cost == test_cost - # Create a test metrics object to simulate an observation with metrics - test_observation_metrics = Metrics() - test_observation_metrics.add_cost(0.1) + # Add more cost to simulate additional LLM usage + additional_cost = 0.1 + session.controller.agent.llm.metrics.add_cost(additional_cost) - # Get the current accumulated cost before merging - current_cost = session.controller.state.metrics.accumulated_cost + # Verify the combined metrics reflect the total cost + combined_metrics = session.controller.state.convo_stats.get_combined_metrics() + assert combined_metrics.accumulated_cost == test_cost + additional_cost - # Simulate merging metrics from an observation - session.controller.state_tracker.merge_metrics(test_observation_metrics) - - # Verify that the merged metrics are reflected in both agent and controller - assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1 - assert ( - session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1 - ) - - # Reset the agent and verify that metrics are not reset + # Reset the agent and verify that combined metrics are preserved session.controller.agent.reset() - # Metrics should still be the same after reset - assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1 - assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1 - assert session.controller.agent.llm.metrics is session.controller.state.metrics + # Combined metrics should still be preserved after agent reset + assert ( + session.controller.state.convo_stats.get_combined_metrics().accumulated_cost + == test_cost + additional_cost + ) @pytest.mark.asyncio -async def test_budget_control_flag_syncs_with_metrics(mock_agent): +async def test_budget_control_flag_syncs_with_metrics( + make_mock_agent, connected_registry_and_stats +): """Test that BudgetControlFlag's current value matches the accumulated costs.""" + + mock_llm_registry, mock_conversation_stats = connected_registry_and_stats + mock_agent = make_mock_agent(mock_llm_registry) # Setup file_store = InMemoryFileStore({}) session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=mock_conversation_stats, ) # Create a mock runtime and set it up @@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent): memory = Memory(event_stream=mock_event_stream, sid='test-session') memory.microagents_dir = 'test-dir' + # The registry already has a real metrics object set up in the fixture + # Patch necessary components with ( patch( @@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent): assert session.controller.state.budget_flag.max_value == 1.0 assert session.controller.state.budget_flag.current_value == 0.0 - # Add some metrics to the agent's LLM + # Add some metrics to the agent's LLM (simulating LLM usage) test_cost = 0.05 session.controller.agent.llm.metrics.add_cost(test_cost) @@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent): session.controller.state_tracker.sync_budget_flag_with_metrics() assert session.controller.state.budget_flag.current_value == test_cost - # Create a test metrics object to simulate an observation with metrics - test_observation_metrics = Metrics() - test_observation_metrics.add_cost(0.1) + # Add more cost to simulate additional LLM usage + additional_cost = 0.1 + session.controller.agent.llm.metrics.add_cost(additional_cost) - # Simulate merging metrics from an observation - session.controller.state_tracker.merge_metrics(test_observation_metrics) + # Sync again and verify the budget flag is updated + session.controller.state_tracker.sync_budget_flag_with_metrics() + assert ( + session.controller.state.budget_flag.current_value + == test_cost + additional_cost + ) - # Verify that the budget control flag's current value is updated to match the new accumulated cost - assert session.controller.state.budget_flag.current_value == test_cost + 0.1 - - # Reset the agent and verify that metrics and budget flag are not reset + # Reset the agent and verify that budget flag still reflects the accumulated cost session.controller.agent.reset() # Budget control flag should still reflect the accumulated cost after reset - assert session.controller.state.budget_flag.current_value == test_cost + 0.1 + session.controller.state_tracker.sync_budget_flag_with_metrics() + assert ( + session.controller.state.budget_flag.current_value + == test_cost + additional_cost + ) -def test_override_provider_tokens_with_custom_secret(): +def test_override_provider_tokens_with_custom_secret( + mock_llm_registry, mock_conversation_stats +): """Test that override_provider_tokens_with_custom_secret works correctly. This test verifies that the method properly removes provider tokens when @@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret(): session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=mock_conversation_stats, ) # Create test data diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index 4e6b4913e0..051857aa2a 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import ( ) from openhands.controller.state.state import State from openhands.core.config import AgentConfig, LLMConfig +from openhands.core.config.openhands_config import OpenHandsConfig from openhands.core.exceptions import FunctionCallNotExistsError from openhands.core.message import ImageContent, Message, TextContent from openhands.events.action import ( @@ -42,10 +43,20 @@ from openhands.events.observation.commands import ( CmdOutputObservation, ) from openhands.events.tool import ToolCallMetadata -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser import View +@pytest.fixture +def create_llm_registry(): + def _get_registry(llm_config): + config = OpenHandsConfig() + config.set_llm_config(llm_config) + return LLMRegistry(config=config) + + return _get_registry + + @pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent']) def agent_class(request): if request.param == 'CodeActAgent': @@ -57,18 +68,22 @@ def agent_class(request): @pytest.fixture -def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]: +def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]: + llm_config = LLMConfig(model='gpt-4o', api_key='test_key') config = AgentConfig() - agent = agent_class(llm=LLM(LLMConfig()), config=config) + agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config)) agent.llm = Mock() agent.llm.config = Mock() agent.llm.config.max_message_chars = 1000 return agent -def test_agent_with_default_config_has_default_tools(): +def test_agent_with_default_config_has_default_tools(create_llm_registry): + llm_config = LLMConfig(model='gpt-4o', api_key='test_key') config = AgentConfig() - codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config) + codeact_agent = CodeActAgent( + config=config, llm_registry=create_llm_registry(llm_config) + ) assert len(codeact_agent.tools) > 0 default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools] assert { @@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool(): readonly_response_to_actions(mock_response) -def test_step_with_no_pending_actions(mock_state: State): +def test_step_with_no_pending_actions(mock_state: State, create_llm_registry): # Mock the LLM response mock_response = Mock() mock_response.id = 'mock_id' @@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State): llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting # Create agent with mocked LLM + llm_config = LLMConfig(model='gpt-4o', api_key='test_key') config = AgentConfig() config.enable_prompt_extensions = False - agent = CodeActAgent(llm=llm, config=config) + agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config)) + # Replace the LLM with our mock after creation + agent.llm = llm # Test step with no pending actions mock_state.latest_user_message = None @@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State): @pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent']) def test_correct_tool_description_loaded_based_on_model_name( - agent_type, mock_state: State + agent_type, create_llm_registry ): """Tests that the simplified tool descriptions are loaded for specific models.""" - o3_mock_config = Mock() - o3_mock_config.model = 'mock_o3_model' - - llm = Mock() - llm.config = o3_mock_config - + o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key') if agent_type == 'CodeActAgent': from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent @@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name( agent_class = ReadOnlyAgent - agent = agent_class(llm=llm, config=AgentConfig()) + agent = agent_class( + config=AgentConfig(), + llm_registry=create_llm_registry(o3_mock_config), + ) for tool in agent.tools: # Assert all descriptions have less than 1024 characters assert len(tool['function']['description']) < 1024 - sonnet_mock_config = Mock() - sonnet_mock_config.model = 'mock_sonnet_model' - - llm.config = sonnet_mock_config - agent = agent_class(llm=llm, config=AgentConfig()) + sonnect_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key') + agent = agent_class( + config=AgentConfig(), + llm_registry=create_llm_registry(sonnect_mock_config), + ) # Assert existence of the detailed tool descriptions that are longer than 1024 characters if agent_type == 'CodeActAgent': # This only holds for CodeActAgent @@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages( assert isinstance(enhanced_messages[5].content[0], ImageContent) -def test_get_system_message(): +def test_get_system_message(create_llm_registry): """Test that the Agent.get_system_message method returns a SystemMessageAction.""" # Create a mock agent - agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig()) + config = AgentConfig() + llm_config = LLMConfig(model='gpt-4o', api_key='test_key') + agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config)) result = agent.get_system_message() diff --git a/tests/unit/test_api_connection_error_retry.py b/tests/unit/test_api_connection_error_retry.py index 3dc3ee73b8..8bcf15f986 100644 --- a/tests/unit/test_api_connection_error_retry.py +++ b/tests/unit/test_api_connection_error_retry.py @@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error( ] # Create an LLM instance and call completion - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') response = llm.completion( messages=[{'role': 'user', 'content': 'Hello!'}], stream=False, @@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error( ] # Create an LLM instance and call completion - llm = LLM(config=default_config) + llm = LLM(config=default_config, service_id='test-service') # The completion should raise an APIConnectionError after exhausting all retries with pytest.raises(APIConnectionError) as excinfo: diff --git a/tests/unit/test_auto_generate_title.py b/tests/unit/test_auto_generate_title.py index f052c62e12..5d684bd9c0 100644 --- a/tests/unit/test_auto_generate_title.py +++ b/tests/unit/test_auto_generate_title.py @@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from openhands.core.config.llm_config import LLMConfig from openhands.core.config.openhands_config import OpenHandsConfig from openhands.events.action import MessageAction from openhands.events.event import EventSource from openhands.events.event_store import EventStore +from openhands.llm.llm_registry import LLMRegistry from openhands.server.conversation_manager.standalone_conversation_manager import ( StandaloneConversationManager, ) @@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm(): """Test auto-generating a title using LLM.""" # Mock dependencies file_store = InMemoryFileStore() + llm_registry = MagicMock(spec=LLMRegistry) # Create test conversation with a user message conversation_id = 'test-conversation' @@ -46,43 +47,33 @@ async def test_auto_generate_title_with_llm(): mock_event_store.search_events.return_value = [user_message] mock_event_store_cls.return_value = mock_event_store - # Mock the LLM response - with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls: - mock_llm = mock_llm_cls.return_value - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = 'Python Data Analysis Script' - mock_llm.completion.return_value = mock_response + # Mock the LLM registry response + llm_registry.request_extraneous_completion.return_value = ( + 'Python Data Analysis Script' + ) - # Create test settings with LLM config - settings = Settings( - llm_model='test-model', - llm_api_key='test-key', - llm_base_url='test-url', - ) + # Create test settings with LLM config + settings = Settings( + llm_model='test-model', + llm_api_key='test-key', + llm_base_url='test-url', + ) - # Call the auto_generate_title function directly - title = await auto_generate_title( - conversation_id, user_id, file_store, settings - ) + # Call the auto_generate_title function directly + title = await auto_generate_title( + conversation_id, user_id, file_store, settings, llm_registry + ) - # Verify the result - assert title == 'Python Data Analysis Script' + # Verify the result + assert title == 'Python Data Analysis Script' - # Verify EventStore was created with the correct parameters - mock_event_store_cls.assert_called_once_with( - conversation_id, file_store, user_id - ) + # Verify EventStore was created with the correct parameters + mock_event_store_cls.assert_called_once_with( + conversation_id, file_store, user_id + ) - # Verify LLM was called with appropriate parameters - mock_llm_cls.assert_called_once_with( - LLMConfig( - model='test-model', - api_key='test-key', - base_url='test-url', - ) - ) - mock_llm.completion.assert_called_once() + # Verify LLM registry was called with appropriate parameters + llm_registry.request_extraneous_completion.assert_called_once() @pytest.mark.asyncio @@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback(): """Test auto-generating a title with fallback to truncation when LLM fails.""" # Mock dependencies file_store = InMemoryFileStore() + llm_registry = MagicMock(spec=LLMRegistry) # Create test conversation with a user message conversation_id = 'test-conversation' @@ -111,31 +103,29 @@ async def test_auto_generate_title_fallback(): mock_event_store.search_events.return_value = [user_message] mock_event_store_cls.return_value = mock_event_store - # Mock the LLM to raise an exception - with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls: - mock_llm = mock_llm_cls.return_value - mock_llm.completion.side_effect = Exception('Test error') + # Mock the LLM registry to raise an exception + llm_registry.request_extraneous_completion.side_effect = Exception('Test error') - # Create test settings with LLM config - settings = Settings( - llm_model='test-model', - llm_api_key='test-key', - llm_base_url='test-url', - ) + # Create test settings with LLM config + settings = Settings( + llm_model='test-model', + llm_api_key='test-key', + llm_base_url='test-url', + ) - # Call the auto_generate_title function directly - title = await auto_generate_title( - conversation_id, user_id, file_store, settings - ) + # Call the auto_generate_title function directly + title = await auto_generate_title( + conversation_id, user_id, file_store, settings, llm_registry + ) - # Verify the result is a truncated version of the message - assert title == 'This is a very long message th...' - assert len(title) <= 35 + # Verify the result is a truncated version of the message + assert title == 'This is a very long message th...' + assert len(title) <= 35 - # Verify EventStore was created with the correct parameters - mock_event_store_cls.assert_called_once_with( - conversation_id, file_store, user_id - ) + # Verify EventStore was created with the correct parameters + mock_event_store_cls.assert_called_once_with( + conversation_id, file_store, user_id + ) @pytest.mark.asyncio @@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages(): """Test auto-generating a title when there are no user messages.""" # Mock dependencies file_store = InMemoryFileStore() + llm_registry = MagicMock(spec=LLMRegistry) # Create test conversation with no messages conversation_id = 'test-conversation' @@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages(): # Call the auto_generate_title function directly title = await auto_generate_title( - conversation_id, user_id, file_store, settings + conversation_id, user_id, file_store, settings, llm_registry ) # Verify the result is empty @@ -186,6 +177,7 @@ async def test_update_conversation_with_title(): sio.emit = AsyncMock() file_store = InMemoryFileStore() server_config = MagicMock() + llm_registry = MagicMock(spec=LLMRegistry) # Create test conversation conversation_id = 'test-conversation' @@ -222,7 +214,9 @@ async def test_update_conversation_with_title(): AsyncMock(return_value='Generated Title'), ): # Call the method - await manager._update_conversation_for_event(user_id, conversation_id, settings) + await manager._update_conversation_for_event( + user_id, conversation_id, settings, llm_registry + ) # Verify the title was updated assert mock_metadata.title == 'Generated Title' diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 80b53e8869..1ae525b8ec 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -6,6 +6,7 @@ import pytest_asyncio from openhands.cli import main as cli from openhands.controller.state.state import State +from openhands.core.config.llm_config import LLMConfig from openhands.events import EventSource from openhands.events.action import MessageAction @@ -124,12 +125,14 @@ def mock_config(): '' # Empty string, not starting with 'tvly-' ) config.search_api_key = search_api_key_mock + config.get_llm_config_from_agent.return_value = LLMConfig(model='model') # Mock sandbox with volumes attribute to prevent finalize_config issues config.sandbox = MagicMock() config.sandbox.volumes = ( None # This prevents finalize_config from overriding workspace_base ) + config.model_name = 'model' return config @@ -213,7 +216,11 @@ async def test_run_session_without_initial_action( # Assertions for initialization flow mock_display_runtime_init.assert_called_once_with('local') mock_display_animation.assert_called_once() - mock_create_agent.assert_called_once_with(mock_config) + # Check that mock_config is the first parameter to create_agent + mock_create_agent.assert_called_once() + assert mock_create_agent.call_args[0][0] == mock_config, ( + 'First parameter to create_agent should be mock_config' + ) mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory) mock_create_runtime.assert_called_once() mock_create_controller.assert_called_once() diff --git a/tests/unit/test_cli_openhands_provider_auth_error.py b/tests/unit/test_cli_openhands_provider_auth_error.py index a955b86e21..60579ca693 100644 --- a/tests/unit/test_cli_openhands_provider_auth_error.py +++ b/tests/unit/test_cli_openhands_provider_auth_error.py @@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from litellm.exceptions import AuthenticationError +from pydantic import SecretStr from openhands.cli import main as cli +from openhands.core.config.llm_config import LLMConfig from openhands.events import EventSource from openhands.events.action import MessageAction @@ -45,11 +47,10 @@ def mock_config(): config.workspace_base = '/test/dir' # Set up LLM config to use OpenHands provider - llm_config = MagicMock() + llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key')) llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model - llm_config.api_key = MagicMock() - llm_config.api_key.get_secret_value.return_value = 'invalid-api-key' - config.llm = llm_config + config.get_llm_config.return_value = llm_config + config.get_llm_config_from_agent.return_value = llm_config # Mock search_api_key with get_secret_value method search_api_key_mock = MagicMock() diff --git a/tests/unit/test_cli_runtime_mcp.py b/tests/unit/test_cli_runtime_mcp.py index 1ec246d7eb..9330b73711 100644 --- a/tests/unit/test_cli_runtime_mcp.py +++ b/tests/unit/test_cli_runtime_mcp.py @@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import ( from openhands.events.action.mcp import MCPAction from openhands.events.observation import ErrorObservation from openhands.events.observation.mcp import MCPObservation +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.impl.cli.cli_runtime import CLIRuntime @@ -23,8 +24,12 @@ class TestCLIRuntimeMCP: """Set up test fixtures.""" self.config = OpenHandsConfig() self.event_stream = MagicMock() + llm_registry = LLMRegistry(config=OpenHandsConfig()) self.runtime = CLIRuntime( - config=self.config, event_stream=self.event_stream, sid='test-session' + config=self.config, + event_stream=self.event_stream, + sid='test-session', + llm_registry=llm_registry, ) @pytest.mark.asyncio diff --git a/tests/unit/test_cli_workspace.py b/tests/unit/test_cli_workspace.py index 43822fd580..1a0deed394 100644 --- a/tests/unit/test_cli_workspace.py +++ b/tests/unit/test_cli_workspace.py @@ -7,10 +7,18 @@ import pytest from openhands.core.config import OpenHandsConfig from openhands.events import EventStream + +# Mock LLMRegistry from openhands.runtime.impl.cli.cli_runtime import CLIRuntime from openhands.storage import get_file_store +# Create a mock LLMRegistry class +class MockLLMRegistry: + def __init__(self, config): + self.config = config + + @pytest.fixture def temp_dir(): """Create a temporary directory for testing.""" @@ -25,7 +33,8 @@ def cli_runtime(temp_dir): event_stream = EventStream('test', file_store) config = OpenHandsConfig() config.workspace_base = temp_dir - runtime = CLIRuntime(config, event_stream) + llm_registry = MockLLMRegistry(config) + runtime = CLIRuntime(config, event_stream, llm_registry) runtime._runtime_initialized = True # Skip initialization return runtime diff --git a/tests/unit/test_condenser.py b/tests/unit/test_condenser.py index 4eb3616c05..f21f39305a 100644 --- a/tests/unit/test_condenser.py +++ b/tests/unit/test_condenser.py @@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import ( StructuredSummaryCondenserConfig, ) from openhands.core.config.llm_config import LLMConfig +from openhands.core.config.openhands_config import OpenHandsConfig from openhands.core.message import Message, TextContent from openhands.core.schema.action import ActionType from openhands.events.event import Event, EventSource @@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation from openhands.events.observation.agent import AgentCondensationObservation from openhands.events.observation.observation import Observation from openhands.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.memory.condenser import Condenser from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View from openhands.memory.condenser.impl import ( @@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import ( StructuredSummaryCondenser, ) from openhands.memory.condenser.impl.pipeline import CondenserPipeline +from openhands.server.services.conversation_stats import ConversationStats def create_test_event( @@ -56,12 +59,15 @@ def create_test_event( @pytest.fixture def mock_llm() -> LLM: """Mocks an LLM object with a utility function for setting and resetting response contents in unit tests.""" + # Create a real LLMConfig instead of a mock to properly handle SecretStr api_key + real_config = LLMConfig( + model='gpt-4o', api_key='test_key', custom_llm_provider=None + ) + # Create a MagicMock for the LLM object mock_llm = MagicMock( spec=LLM, - config=MagicMock( - spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None - ), + config=real_config, metrics=MagicMock(), ) _mock_content = None @@ -95,6 +101,23 @@ def mock_llm() -> LLM: return mock_llm +@pytest.fixture +def mock_conversation_stats() -> ConversationStats: + """Creates a mock ConversationStats service.""" + mock_stats = MagicMock(spec=ConversationStats) + return mock_stats + + +@pytest.fixture +def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry: + """Creates an actual LLMRegistry that returns real LLMs.""" + # Create an actual LLMRegistry with a basic OpenHandsConfig + config = OpenHandsConfig() + registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None) + + return registry + + class RollingCondenserTestHarness: """Test harness for rolling condensers. @@ -165,10 +188,10 @@ class RollingCondenserTestHarness: return ((index - max_size) // target_size) + 1 -def test_noop_condenser_from_config(): +def test_noop_condenser_from_config(mock_llm_registry): """Test that the NoOpCondenser objects can be made from config.""" config = NoOpCondenserConfig() - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, NoOpCondenser) @@ -189,11 +212,11 @@ def test_noop_condenser(): assert result == View(events=events) -def test_observation_masking_condenser_from_config(): +def test_observation_masking_condenser_from_config(mock_llm_registry): """Test that ObservationMaskingCondenser objects can be made from config.""" attention_window = 5 config = ObservationMaskingCondenserConfig(attention_window=attention_window) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, ObservationMaskingCondenser) assert condenser.attention_window == attention_window @@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window(): assert event == condensed_event -def test_browser_output_condenser_from_config(): +def test_browser_output_condenser_from_config(mock_llm_registry): """Test that BrowserOutputCondenser objects can be made from config.""" attention_window = 5 config = BrowserOutputCondenserConfig(attention_window=attention_window) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, BrowserOutputCondenser) assert condenser.attention_window == attention_window @@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window(): assert event == condensed_event -def test_recent_events_condenser_from_config(): +def test_recent_events_condenser_from_config(mock_llm_registry): """Test that RecentEventsCondenser objects can be made from config.""" max_events = 5 keep_first = True config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, RecentEventsCondenser) assert condenser.max_events == max_events @@ -334,14 +357,14 @@ def test_recent_events_condenser(): assert result[2]._message == 'Event 5' # kept from max_events -def test_llm_summarizing_condenser_from_config(): +def test_llm_summarizing_condenser_from_config(mock_llm_registry): """Test that LLMSummarizingCondenser objects can be made from config.""" config = LLMSummarizingCondenserConfig( max_size=50, keep_first=10, llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True), ) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, LLMSummarizingCondenser) assert condenser.llm.config.model == 'gpt-4o' @@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config(): assert condenser.max_size == 50 assert condenser.keep_first == 10 - # Since this condenser can't take advantage of caching, we intercept the - # passed config and manually flip the caching prompt to False. - assert not condenser.llm.config.caching_prompt - -def test_llm_summarizing_condenser_invalid_config(): +def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry): """Test that LLMSummarizingCondenser raises error when keep_first > max_size.""" pytest.raises( ValueError, LLMSummarizingCondenser, - llm=MagicMock(), + llm=mock_llm, max_size=4, keep_first=2, ) - pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0) - pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1) + pytest.raises( + ValueError, + LLMSummarizingCondenser, + llm=mock_llm, + max_size=0, + ) + pytest.raises( + ValueError, + LLMSummarizingCondenser, + llm=mock_llm, + keep_first=-1, + ) -def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm): +def test_llm_summarizing_condenser_gives_expected_view_size( + mock_llm, mock_llm_registry +): """Test that LLMSummarizingCondenser maintains the correct view size.""" max_size = 10 condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm) @@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm): +def test_llm_summarizing_condenser_keeps_first_and_summary_events( + mock_llm, mock_llm_registry +): """Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events.""" max_size = 10 keep_first = 3 condenser = LLMSummarizingCondenser( - max_size=max_size, keep_first=keep_first, llm=mock_llm + max_size=max_size, + keep_first=keep_first, + llm=mock_llm, ) mock_llm.set_mock_response_content('Summary of forgotten events') @@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm): assert isinstance(view[keep_first], AgentCondensationObservation) -def test_amortized_forgetting_condenser_from_config(): +def test_amortized_forgetting_condenser_from_config(mock_llm_registry): """Test that AmortizedForgettingCondenser objects can be made from config.""" max_size = 50 keep_first = 10 config = AmortizedForgettingCondenserConfig( max_size=max_size, keep_first=keep_first ) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, AmortizedForgettingCondenser) assert condenser.max_size == max_size @@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events(): assert view[:keep_first] == events[: min(keep_first, i + 1)] -def test_llm_attention_condenser_from_config(): +def test_llm_attention_condenser_from_config(mock_llm_registry): """Test that LLMAttentionCondenser objects can be made from config.""" config = LLMAttentionCondenserConfig( max_size=50, @@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config(): caching_prompt=True, ), ) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, LLMAttentionCondenser) assert condenser.llm.config.model == 'gpt-4o' - assert condenser.llm.config.api_key.get_secret_value() == 'test_key' assert condenser.max_size == 50 assert condenser.keep_first == 10 - # Since this condenser can't take advantage of caching, we intercept the - # passed config and manually flip the caching prompt to False. - assert not condenser.llm.config.caching_prompt + # Create a mock LLM that doesn't support function calling + mock_llm = MagicMock() + mock_llm.is_function_calling_active.return_value = False + + # Create a new registry that returns our mock LLM that doesn't support function calling + mock_registry = MagicMock(spec=LLMRegistry) + mock_registry.get_llm.return_value = mock_llm + + pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry) -def test_llm_attention_condenser_invalid_config(): - """Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema.""" - config = LLMAttentionCondenserConfig( - max_size=50, - keep_first=10, - llm_config=LLMConfig( - model='claude-2', # Older model that doesn't support response schema - api_key='test_key', - ), - ) - - pytest.raises(ValueError, LLMAttentionCondenser.from_config, config) - - -def test_llm_attention_condenser_gives_expected_view_size(mock_llm): +def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry): """Test that the LLMAttentionCondenser gives views of the expected size.""" max_size = 10 - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) + condenser = LLMAttentionCondenser( + max_size=max_size, + keep_first=0, + llm=mock_llm, + ) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_events_outside_history(mock_llm): +def test_llm_attention_condenser_handles_events_outside_history( + mock_llm, mock_llm_registry +): """Test that the LLMAttentionCondenser handles event IDs that aren't from the event history.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) + condenser = LLMAttentionCondenser( + max_size=max_size, + keep_first=0, + llm=mock_llm, + ) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_too_many_events(mock_llm): +def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry): """Test that the LLMAttentionCondenser handles when the response contains too many event IDs.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) + condenser = LLMAttentionCondenser( + max_size=max_size, + keep_first=0, + llm=mock_llm, + ) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_too_few_events(mock_llm): +def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry): """Test that the LLMAttentionCondenser handles when the response contains too few event IDs.""" max_size = 2 # Developer note: We must specify keep_first=0 because # keep_first (1) >= max_size//2 (1) is invalid. - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) + condenser = LLMAttentionCondenser( + max_size=max_size, + keep_first=0, + llm=mock_llm, + ) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_keep_first_events(mock_llm): +def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry): """Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size).""" max_size = 12 keep_first = 4 condenser = LLMAttentionCondenser( - max_size=max_size, keep_first=keep_first, llm=mock_llm + max_size=max_size, + keep_first=keep_first, + llm=mock_llm, ) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm): assert view[:keep_first] == events[: min(keep_first, i + 1)] -def test_structured_summary_condenser_from_config(): +def test_structured_summary_condenser_from_config(mock_llm_registry): """Test that StructuredSummaryCondenser objects can be made from config.""" config = StructuredSummaryCondenserConfig( max_size=50, @@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config(): caching_prompt=True, ), ) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, StructuredSummaryCondenser) assert condenser.llm.config.model == 'gpt-4o' @@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config(): assert condenser.max_size == 50 assert condenser.keep_first == 10 - # Since this condenser can't take advantage of caching, we intercept the - # passed config and manually flip the caching prompt to False. - assert not condenser.llm.config.caching_prompt - -def test_structured_summary_condenser_invalid_config(): +def test_structured_summary_condenser_invalid_config(mock_llm): """Test that StructuredSummaryCondenser raises error when keep_first > max_size.""" # Since the condenser only works when function calling is on, we need to # mock up the check for that. - llm = MagicMock() - llm.is_function_calling_active.return_value = True + mock_llm.is_function_calling_active.return_value = True pytest.raises( ValueError, StructuredSummaryCondenser, - llm=llm, + llm=mock_llm, max_size=4, keep_first=2, ) - pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0) - pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1) + pytest.raises( + ValueError, + StructuredSummaryCondenser, + llm=mock_llm, + max_size=0, + ) + pytest.raises( + ValueError, + StructuredSummaryCondenser, + llm=mock_llm, + keep_first=-1, + ) # If all other parameters are good but there's no function calling the # condenser still counts as improperly configured. - llm.is_function_calling_active.return_value = False + # Create a mock LLM that doesn't support function calling + mock_llm_no_func = MagicMock() + mock_llm_no_func.is_function_calling_active.return_value = False + pytest.raises( - ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2 + ValueError, + StructuredSummaryCondenser, + llm=mock_llm_no_func, + max_size=40, + keep_first=2, ) -def test_structured_summary_condenser_gives_expected_view_size(mock_llm): +def test_structured_summary_condenser_gives_expected_view_size( + mock_llm, mock_llm_registry +): """Test that StructuredSummaryCondenser maintains the correct view size.""" max_size = 10 + mock_llm.is_function_calling_active.return_value = True condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm) events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] @@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm): assert len(view) == harness.expected_size(i, max_size) -def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm): +def test_structured_summary_condenser_keeps_first_and_summary_events( + mock_llm, mock_llm_registry +): """Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events.""" max_size = 10 keep_first = 3 + mock_llm.is_function_calling_active.return_value = True condenser = StructuredSummaryCondenser( - max_size=max_size, keep_first=keep_first, llm=mock_llm + max_size=max_size, + keep_first=keep_first, + llm=mock_llm, ) mock_llm.set_mock_response_content('Summary of forgotten events') @@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm): assert isinstance(view[keep_first], AgentCondensationObservation) -def test_condenser_pipeline_from_config(): +def test_condenser_pipeline_from_config(mock_llm_registry): """Test that CondenserPipeline condensers can be created from configuration objects.""" config = CondenserPipelineConfig( condensers=[ @@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config(): ), ] ) - condenser = Condenser.from_config(config) + condenser = Condenser.from_config(config, mock_llm_registry) assert isinstance(condenser, CondenserPipeline) assert len(condenser.condensers) == 3 diff --git a/tests/unit/test_conversation_stats.py b/tests/unit/test_conversation_stats.py new file mode 100644 index 0000000000..61de343f89 --- /dev/null +++ b/tests/unit/test_conversation_stats.py @@ -0,0 +1,490 @@ +import base64 +import pickle +from unittest.mock import patch + +import pytest + +from openhands.core.config import LLMConfig, OpenHandsConfig +from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry, RegistryEvent +from openhands.llm.metrics import Metrics +from openhands.server.services.conversation_stats import ConversationStats +from openhands.storage.memory import InMemoryFileStore + + +@pytest.fixture +def mock_file_store(): + """Create a mock file store for testing.""" + return InMemoryFileStore({}) + + +@pytest.fixture +def conversation_stats(mock_file_store): + """Create a ConversationStats instance for testing.""" + return ConversationStats( + file_store=mock_file_store, + conversation_id='test-conversation-id', + user_id='test-user-id', + ) + + +@pytest.fixture +def mock_llm_registry(): + """Create a mock LLM registry that properly simulates LLM registration.""" + config = OpenHandsConfig() + registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None) + return registry + + +@pytest.fixture +def connected_registry_and_stats(mock_llm_registry, conversation_stats): + """Connect the LLMRegistry and ConversationStats properly.""" + # Subscribe to LLM registry events to track metrics + mock_llm_registry.subscribe(conversation_stats.register_llm) + return mock_llm_registry, conversation_stats + + +def test_conversation_stats_initialization(conversation_stats): + """Test that ConversationStats initializes correctly.""" + assert conversation_stats.conversation_id == 'test-conversation-id' + assert conversation_stats.user_id == 'test-user-id' + assert conversation_stats.service_to_metrics == {} + assert isinstance(conversation_stats.restored_metrics, dict) + + +def test_save_metrics(conversation_stats, mock_file_store): + """Test that metrics are saved correctly.""" + # Add a service with metrics + service_id = 'test-service' + metrics = Metrics(model_name='gpt-4') + metrics.add_cost(0.05) + conversation_stats.service_to_metrics[service_id] = metrics + + # Save metrics + conversation_stats.save_metrics() + + # Verify that metrics were saved to the file store + try: + # Verify the saved content can be decoded and unpickled + encoded = mock_file_store.read(conversation_stats.metrics_path) + pickled = base64.b64decode(encoded) + restored = pickle.loads(pickled) + + assert service_id in restored + assert restored[service_id].accumulated_cost == 0.05 + except FileNotFoundError: + pytest.fail(f'File not found: {conversation_stats.metrics_path}') + + +def test_maybe_restore_metrics(mock_file_store): + """Test that metrics are restored correctly.""" + # Create metrics to save + service_id = 'test-service' + metrics = Metrics(model_name='gpt-4') + metrics.add_cost(0.1) + service_to_metrics = {service_id: metrics} + + # Serialize and save metrics + pickled = pickle.dumps(service_to_metrics) + serialized_metrics = base64.b64encode(pickled).decode('utf-8') + + # Create a new ConversationStats with pre-populated file store + conversation_id = 'test-conversation-id' + user_id = 'test-user-id' + + # Get the correct path using the same function as ConversationStats + from openhands.storage.locations import get_conversation_stats_filename + + metrics_path = get_conversation_stats_filename(conversation_id, user_id) + + # Write to the correct path + mock_file_store.write(metrics_path, serialized_metrics) + + # Create ConversationStats which should restore metrics + stats = ConversationStats( + file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id + ) + + # Verify metrics were restored + assert service_id in stats.restored_metrics + assert stats.restored_metrics[service_id].accumulated_cost == 0.1 + + +def test_get_combined_metrics(conversation_stats): + """Test that combined metrics are calculated correctly.""" + # Add multiple services with metrics + service1 = 'service1' + metrics1 = Metrics(model_name='gpt-4') + metrics1.add_cost(0.05) + metrics1.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id='resp1', + ) + + service2 = 'service2' + metrics2 = Metrics(model_name='gpt-3.5') + metrics2.add_cost(0.02) + metrics2.add_token_usage( + prompt_tokens=200, + completion_tokens=100, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=4000, + response_id='resp2', + ) + + conversation_stats.service_to_metrics[service1] = metrics1 + conversation_stats.service_to_metrics[service2] = metrics2 + + # Get combined metrics + combined = conversation_stats.get_combined_metrics() + + # Verify combined metrics + assert combined.accumulated_cost == 0.07 # 0.05 + 0.02 + assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200 + assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100 + assert ( + combined.accumulated_token_usage.context_window == 8000 + ) # max of 8000 and 4000 + + +def test_get_metrics_for_service(conversation_stats): + """Test that metrics for a specific service are retrieved correctly.""" + # Add a service with metrics + service_id = 'test-service' + metrics = Metrics(model_name='gpt-4') + metrics.add_cost(0.05) + conversation_stats.service_to_metrics[service_id] = metrics + + # Get metrics for the service + retrieved_metrics = conversation_stats.get_metrics_for_service(service_id) + + # Verify metrics + assert retrieved_metrics.accumulated_cost == 0.05 + assert retrieved_metrics is metrics # Should be the same object + + # Test getting metrics for non-existent service + # Use a specific exception message pattern instead of a blind Exception + with pytest.raises(Exception, match='LLM service does not exist'): + conversation_stats.get_metrics_for_service('non-existent-service') + + +def test_register_llm_with_new_service(conversation_stats): + """Test registering a new LLM service.""" + # Create a real LLM instance with a mock config + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Patch the LLM class to avoid actual API calls + with patch('openhands.llm.llm.litellm_completion'): + llm = LLM(service_id='new-service', config=llm_config) + + # Create a registry event + service_id = 'new-service' + event = RegistryEvent(llm=llm, service_id=service_id) + + # Register the LLM + conversation_stats.register_llm(event) + + # Verify the service was registered + assert service_id in conversation_stats.service_to_metrics + assert conversation_stats.service_to_metrics[service_id] is llm.metrics + + +def test_register_llm_with_restored_metrics(conversation_stats): + """Test registering an LLM service with restored metrics.""" + # Create restored metrics + service_id = 'restored-service' + restored_metrics = Metrics(model_name='gpt-4') + restored_metrics.add_cost(0.1) + conversation_stats.restored_metrics = {service_id: restored_metrics} + + # Create a real LLM instance with a mock config + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Patch the LLM class to avoid actual API calls + with patch('openhands.llm.llm.litellm_completion'): + llm = LLM(service_id=service_id, config=llm_config) + + # Create a registry event + event = RegistryEvent(llm=llm, service_id=service_id) + + # Register the LLM + conversation_stats.register_llm(event) + + # Verify the service was registered with restored metrics + assert service_id in conversation_stats.service_to_metrics + assert conversation_stats.service_to_metrics[service_id] is llm.metrics + assert llm.metrics.accumulated_cost == 0.1 # Restored cost + + # Verify the specific service was removed from restored_metrics + assert service_id not in conversation_stats.restored_metrics + assert hasattr( + conversation_stats, 'restored_metrics' + ) # The dict should still exist + + +def test_llm_registry_notifications(connected_registry_and_stats): + """Test that LLM registry notifications update conversation stats.""" + mock_llm_registry, conversation_stats = connected_registry_and_stats + + # Create a new LLM through the registry + service_id = 'test-service' + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Get LLM from registry (this should trigger the notification) + llm = mock_llm_registry.get_llm(service_id, llm_config) + + # Verify the service was registered in conversation stats + assert service_id in conversation_stats.service_to_metrics + assert conversation_stats.service_to_metrics[service_id] is llm.metrics + + # Add some metrics to the LLM + llm.metrics.add_cost(0.05) + llm.metrics.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id='resp1', + ) + + # Verify the metrics are reflected in conversation stats + assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05 + assert ( + conversation_stats.service_to_metrics[ + service_id + ].accumulated_token_usage.prompt_tokens + == 100 + ) + assert ( + conversation_stats.service_to_metrics[ + service_id + ].accumulated_token_usage.completion_tokens + == 50 + ) + + # Get combined metrics and verify + combined = conversation_stats.get_combined_metrics() + assert combined.accumulated_cost == 0.05 + assert combined.accumulated_token_usage.prompt_tokens == 100 + assert combined.accumulated_token_usage.completion_tokens == 50 + + +def test_multiple_llm_services(connected_registry_and_stats): + """Test tracking metrics for multiple LLM services.""" + mock_llm_registry, conversation_stats = connected_registry_and_stats + + # Create multiple LLMs through the registry + service1 = 'service1' + service2 = 'service2' + + llm_config1 = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + llm_config2 = LLMConfig( + model='gpt-3.5-turbo', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Get LLMs from registry (this should trigger notifications) + llm1 = mock_llm_registry.get_llm(service1, llm_config1) + llm2 = mock_llm_registry.get_llm(service2, llm_config2) + + # Add different metrics to each LLM + llm1.metrics.add_cost(0.05) + llm1.metrics.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id='resp1', + ) + + llm2.metrics.add_cost(0.02) + llm2.metrics.add_token_usage( + prompt_tokens=200, + completion_tokens=100, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=4000, + response_id='resp2', + ) + + # Verify services were registered in conversation stats + assert service1 in conversation_stats.service_to_metrics + assert service2 in conversation_stats.service_to_metrics + + # Verify individual metrics + assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05 + assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02 + + # Get combined metrics and verify + combined = conversation_stats.get_combined_metrics() + assert combined.accumulated_cost == 0.07 # 0.05 + 0.02 + assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200 + assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100 + assert ( + combined.accumulated_token_usage.context_window == 8000 + ) # max of 8000 and 4000 + + +def test_register_llm_with_multiple_restored_services_bug(conversation_stats): + """Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service.""" + # Create restored metrics for multiple services + service_id_1 = 'service-1' + service_id_2 = 'service-2' + + restored_metrics_1 = Metrics(model_name='gpt-4') + restored_metrics_1.add_cost(0.1) + + restored_metrics_2 = Metrics(model_name='gpt-3.5') + restored_metrics_2.add_cost(0.05) + + # Set up restored metrics for both services + conversation_stats.restored_metrics = { + service_id_1: restored_metrics_1, + service_id_2: restored_metrics_2, + } + + # Create LLM configs + llm_config_1 = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + llm_config_2 = LLMConfig( + model='gpt-3.5-turbo', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Patch the LLM class to avoid actual API calls + with patch('openhands.llm.llm.litellm_completion'): + # Register first LLM + llm_1 = LLM(service_id=service_id_1, config=llm_config_1) + event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1) + conversation_stats.register_llm(event_1) + + # Verify first service was registered with restored metrics + assert service_id_1 in conversation_stats.service_to_metrics + assert llm_1.metrics.accumulated_cost == 0.1 + + # After registering first service, restored_metrics should still contain service_id_2 + assert service_id_2 in conversation_stats.restored_metrics + + # Register second LLM - this should also work with restored metrics + llm_2 = LLM(service_id=service_id_2, config=llm_config_2) + event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2) + conversation_stats.register_llm(event_2) + + # Verify second service was registered with restored metrics + assert service_id_2 in conversation_stats.service_to_metrics + assert llm_2.metrics.accumulated_cost == 0.05 + + # After both services are registered, restored_metrics should be empty + assert len(conversation_stats.restored_metrics) == 0 + + +def test_save_and_restore_workflow(mock_file_store): + """Test the full workflow of saving and restoring metrics.""" + # Create initial conversation stats + conversation_id = 'test-conversation-id' + user_id = 'test-user-id' + + stats1 = ConversationStats( + file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id + ) + + # Add a service with metrics + service_id = 'test-service' + metrics = Metrics(model_name='gpt-4') + metrics.add_cost(0.05) + metrics.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id='resp1', + ) + stats1.service_to_metrics[service_id] = metrics + + # Save metrics + stats1.save_metrics() + + # Create a new conversation stats instance that should restore the metrics + stats2 = ConversationStats( + file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id + ) + + # Verify metrics were restored + assert service_id in stats2.restored_metrics + assert stats2.restored_metrics[service_id].accumulated_cost == 0.05 + assert ( + stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100 + ) + assert ( + stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens + == 50 + ) + + # Create a real LLM instance with a mock config + llm_config = LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Patch the LLM class to avoid actual API calls + with patch('openhands.llm.llm.litellm_completion'): + llm = LLM(service_id=service_id, config=llm_config) + + # Create a registry event + event = RegistryEvent(llm=llm, service_id=service_id) + + # Register the LLM to trigger restoration + stats2.register_llm(event) + + # Verify metrics were applied to the LLM + assert llm.metrics.accumulated_cost == 0.05 + assert llm.metrics.accumulated_token_usage.prompt_tokens == 100 + assert llm.metrics.accumulated_token_usage.completion_tokens == 50 diff --git a/tests/unit/test_conversation_summary.py b/tests/unit/test_conversation_summary.py index b688b14fd4..4fad19b9cb 100644 --- a/tests/unit/test_conversation_summary.py +++ b/tests/unit/test_conversation_summary.py @@ -1,6 +1,6 @@ """Tests for the conversation summary generator.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -11,55 +11,51 @@ from openhands.utils.conversation_summary import generate_conversation_title @pytest.mark.asyncio async def test_generate_conversation_title_empty_message(): """Test that an empty message returns None.""" - result = await generate_conversation_title('', MagicMock()) + mock_llm_registry = MagicMock() + mock_llm_config = LLMConfig(model='test-model') + + result = await generate_conversation_title('', mock_llm_config, mock_llm_registry) assert result is None - result = await generate_conversation_title(' ', MagicMock()) + result = await generate_conversation_title( + ' ', mock_llm_config, mock_llm_registry + ) assert result is None @pytest.mark.asyncio async def test_generate_conversation_title_success(): """Test successful title generation.""" - # Create a proper mock response - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = 'Generated Title' + # Create a mock LLM registry that returns a title + mock_llm_registry = MagicMock() + mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title' - # Create a mock LLM instance with a synchronous completion method - mock_llm = MagicMock() - mock_llm.completion = MagicMock(return_value=mock_response) + mock_llm_config = LLMConfig(model='test-model') - # Patch the LLM class to return our mock - with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm): - result = await generate_conversation_title( - 'Can you help me with Python?', LLMConfig(model='test-model') - ) + result = await generate_conversation_title( + 'Can you help me with Python?', mock_llm_config, mock_llm_registry + ) assert result == 'Generated Title' # Verify the mock was called with the expected arguments - mock_llm.completion.assert_called_once() + mock_llm_registry.request_extraneous_completion.assert_called_once() @pytest.mark.asyncio async def test_generate_conversation_title_long_title(): """Test that long titles are truncated.""" - # Create a proper mock response with a long title - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[ - 0 - ].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length' + # Create a mock LLM registry that returns a long title + mock_llm_registry = MagicMock() + mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length' - # Create a mock LLM instance with a synchronous completion method - mock_llm = MagicMock() - mock_llm.completion = MagicMock(return_value=mock_response) + mock_llm_config = LLMConfig(model='test-model') - # Patch the LLM class to return our mock - with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm): - result = await generate_conversation_title( - 'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30 - ) + result = await generate_conversation_title( + 'Can you help me with Python?', + mock_llm_config, + mock_llm_registry, + max_length=30, + ) # Verify the title is truncated correctly assert len(result) <= 30 @@ -69,15 +65,17 @@ async def test_generate_conversation_title_long_title(): @pytest.mark.asyncio async def test_generate_conversation_title_exception(): """Test that exceptions are handled gracefully.""" - # Create a mock LLM instance with a synchronous completion method that raises an exception - mock_llm = MagicMock() - mock_llm.completion = MagicMock(side_effect=Exception('Test error')) + # Create a mock LLM registry that raises an exception + mock_llm_registry = MagicMock() + mock_llm_registry.request_extraneous_completion.side_effect = Exception( + 'Test error' + ) - # Patch the LLM class to return our mock - with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm): - result = await generate_conversation_title( - 'Can you help me with Python?', LLMConfig(model='test-model') - ) + mock_llm_config = LLMConfig(model='test-model') + + result = await generate_conversation_title( + 'Can you help me with Python?', mock_llm_config, mock_llm_registry + ) # Verify that None is returned when an exception occurs assert result is None diff --git a/tests/unit/test_docker_runtime.py b/tests/unit/test_docker_runtime.py index de0581ee18..e8cbcdb9c3 100644 --- a/tests/unit/test_docker_runtime.py +++ b/tests/unit/test_docker_runtime.py @@ -4,6 +4,7 @@ import pytest from openhands.core.config import OpenHandsConfig from openhands.events import EventStream +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.impl.docker.docker_runtime import DockerRuntime @@ -40,12 +41,17 @@ def event_stream(): return MagicMock(spec=EventStream) +@pytest.fixture +def llm_registry(): + return MagicMock(spec=LLMRegistry) + + @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers') def test_container_stopped_when_keep_runtime_alive_false( - mock_stop_containers, mock_docker_client, config, event_stream + mock_stop_containers, mock_docker_client, config, event_stream, llm_registry ): # Arrange - runtime = DockerRuntime(config, event_stream, sid='test-sid') + runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid') runtime.container = mock_docker_client.containers.get.return_value # Act @@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false( @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers') def test_container_not_stopped_when_keep_runtime_alive_true( - mock_stop_containers, mock_docker_client, config, event_stream + mock_stop_containers, mock_docker_client, config, event_stream, llm_registry ): # Arrange config.sandbox.keep_runtime_alive = True - runtime = DockerRuntime(config, event_stream, sid='test-sid') + runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid') runtime.container = mock_docker_client.containers.get.return_value # Act diff --git a/tests/unit/test_llm_registry.py b/tests/unit/test_llm_registry.py new file mode 100644 index 0000000000..beb873966d --- /dev/null +++ b/tests/unit/test_llm_registry.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from openhands.core.config.llm_config import LLMConfig +from openhands.core.config.openhands_config import OpenHandsConfig +from openhands.llm.llm_registry import LLMRegistry, RegistryEvent + + +class TestLLMRegistry(unittest.TestCase): + def setUp(self): + """Set up test environment before each test.""" + # Create a basic LLM config for testing + self.llm_config = LLMConfig(model='test-model') + + # Create a basic OpenHands config for testing + self.config = OpenHandsConfig( + llms={'llm': self.llm_config}, default_agent='CodeActAgent' + ) + + # Create a registry for testing + self.registry = LLMRegistry(config=self.config) + + def test_get_llm_creates_new_llm(self): + """Test that get_llm creates a new LLM when service doesn't exist.""" + service_id = 'test-service' + + # Mock the _create_new_llm method to avoid actual LLM initialization + with patch.object(self.registry, '_create_new_llm') as mock_create: + mock_llm = MagicMock() + mock_llm.config = self.llm_config + mock_create.return_value = mock_llm + + # Get LLM for the first time + llm = self.registry.get_llm(service_id, self.llm_config) + + # Verify LLM was created and stored + self.assertEqual(llm, mock_llm) + mock_create.assert_called_once_with( + config=self.llm_config, service_id=service_id + ) + + def test_get_llm_returns_existing_llm(self): + """Test that get_llm returns existing LLM when service already exists.""" + service_id = 'test-service' + + # Mock the _create_new_llm method to avoid actual LLM initialization + with patch.object(self.registry, '_create_new_llm') as mock_create: + mock_llm = MagicMock() + mock_llm.config = self.llm_config + mock_create.return_value = mock_llm + + # Get LLM for the first time + llm1 = self.registry.get_llm(service_id, self.llm_config) + + # Manually add to registry to simulate existing LLM + self.registry.service_to_llm[service_id] = mock_llm + + # Get LLM for the second time - should return the same instance + llm2 = self.registry.get_llm(service_id, self.llm_config) + + # Verify same LLM instance is returned + self.assertEqual(llm1, llm2) + self.assertEqual(llm1, mock_llm) + + # Verify _create_new_llm was only called once + mock_create.assert_called_once() + + def test_get_llm_with_different_config_raises_error(self): + """Test that requesting same service ID with different config raises an error.""" + service_id = 'test-service' + different_config = LLMConfig(model='different-model') + + # Manually add an LLM to the registry to simulate existing service + mock_llm = MagicMock() + mock_llm.config = self.llm_config + self.registry.service_to_llm[service_id] = mock_llm + + # Attempt to get LLM with different config should raise ValueError + with self.assertRaises(ValueError) as context: + self.registry.get_llm(service_id, different_config) + + self.assertIn('Requesting same service ID', str(context.exception)) + self.assertIn('with different config', str(context.exception)) + + def test_get_llm_without_config_raises_error(self): + """Test that requesting new LLM without config raises an error.""" + service_id = 'test-service' + + # Attempt to get LLM without providing config should raise ValueError + with self.assertRaises(ValueError) as context: + self.registry.get_llm(service_id, None) + + self.assertIn( + 'Requesting new LLM without specifying LLM config', str(context.exception) + ) + + def test_request_extraneous_completion(self): + """Test that requesting an extraneous completion creates a new LLM if needed.""" + service_id = 'extraneous-service' + messages = [{'role': 'user', 'content': 'Hello, world!'}] + + # Mock the _create_new_llm method to avoid actual LLM initialization + with patch.object(self.registry, '_create_new_llm') as mock_create: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = ' Hello from the LLM! ' + mock_llm.completion.return_value = mock_response + mock_create.return_value = mock_llm + + # Mock the side effect to add the LLM to the registry + def side_effect(*args, **kwargs): + self.registry.service_to_llm[service_id] = mock_llm + return mock_llm + + mock_create.side_effect = side_effect + + # Request a completion + response = self.registry.request_extraneous_completion( + service_id=service_id, + llm_config=self.llm_config, + messages=messages, + ) + + # Verify the response (should be stripped) + self.assertEqual(response, 'Hello from the LLM!') + + # Verify that _create_new_llm was called with correct parameters + mock_create.assert_called_once_with( + config=self.llm_config, service_id=service_id, with_listener=False + ) + + # Verify completion was called with correct messages + mock_llm.completion.assert_called_once_with(messages=messages) + + def test_get_active_llm(self): + """Test that get_active_llm returns the active agent LLM.""" + active_llm = self.registry.get_active_llm() + self.assertEqual(active_llm, self.registry.active_agent_llm) + + def test_subscribe_and_notify(self): + """Test the subscription and notification system.""" + events_received = [] + + def callback(event: RegistryEvent): + events_received.append(event) + + # Subscribe to events + self.registry.subscribe(callback) + + # Should receive notification for the active agent LLM + self.assertEqual(len(events_received), 1) + self.assertEqual(events_received[0].llm, self.registry.active_agent_llm) + self.assertEqual( + events_received[0].service_id, self.registry.active_agent_llm.service_id + ) + + # Test that the subscriber is set correctly + self.assertIsNotNone(self.registry.subscriber) + + # Test notify method directly with a mock event + with patch.object(self.registry, 'subscriber') as mock_subscriber: + mock_event = MagicMock() + self.registry.notify(mock_event) + mock_subscriber.assert_called_once_with(mock_event) + + def test_registry_has_unique_id(self): + """Test that each registry instance has a unique ID.""" + registry2 = LLMRegistry(config=self.config) + self.assertNotEqual(self.registry.registry_id, registry2.registry_id) + self.assertTrue(len(self.registry.registry_id) > 0) + self.assertTrue(len(registry2.registry_id) > 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_mcp_config.py b/tests/unit/test_mcp_config.py index 7ceb1b685c..e1455bc4ea 100644 --- a/tests/unit/test_mcp_config.py +++ b/tests/unit/test_mcp_config.py @@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import ( MCPSSEServerConfig, MCPStdioServerConfig, ) +from openhands.llm.llm_registry import LLMRegistry +from openhands.server.services.conversation_stats import ConversationStats from openhands.server.session.conversation_init_data import ConversationInitData from openhands.server.session.session import Session from openhands.storage.memory import InMemoryFileStore @@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch): file_store=InMemoryFileStore({}), config=config, sio=AsyncMock(), + llm_registry=LLMRegistry(config=OpenHandsConfig()), + convo_stats=ConversationStats(None, 'test-sid', None), ) # Create empty settings diff --git a/tests/unit/test_mcp_tool_timeout_stall.py b/tests/unit/test_mcp_tool_timeout_stall.py index 6d53ffa907..e5ce5a72e0 100644 --- a/tests/unit/test_mcp_tool_timeout_stall.py +++ b/tests/unit/test_mcp_tool_timeout_stall.py @@ -8,7 +8,8 @@ import pytest from mcp import McpError from openhands.controller.agent import Agent -from openhands.controller.agent_controller import AgentController, AgentState +from openhands.controller.agent_controller import AgentController +from openhands.core.schema import AgentState from openhands.events.action.mcp import MCPAction from openhands.events.action.message import SystemMessageAction from openhands.events.event import EventSource @@ -17,6 +18,8 @@ from openhands.events.stream import EventStream from openhands.mcp.client import MCPClient from openhands.mcp.tool import MCPClientTool from openhands.mcp.utils import call_tool_mcp +from openhands.server.services.conversation_stats import ConversationStats +from openhands.storage.memory import InMemoryFileStore class MockConfig: @@ -34,6 +37,11 @@ class MockLLM: self.config = MockConfig() +@pytest.fixture +def convo_stats(): + return ConversationStats(None, 'convo-id', None) + + class MockAgent(Agent): """Mock agent for testing.""" @@ -53,7 +61,7 @@ class MockAgent(Agent): @pytest.mark.asyncio -async def test_mcp_tool_timeout_error_handling(): +async def test_mcp_tool_timeout_error_handling(convo_stats): """Test that verifies MCP tool timeout errors are properly handled and returned as observations.""" # Create a mock MCPClient mock_client = mock.MagicMock(spec=MCPClient) @@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling(): mock_client.tool_map = {'test_tool': mock_tool} # Create a mock file store - mock_file_store = mock.MagicMock() + mock_file_store = InMemoryFileStore({}) # Create a mock event stream event_stream = EventStream(sid='test-session', file_store=mock_file_store) @@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling(): # Create a mock agent controller controller = AgentController( - sid='test-session', - file_store=mock_file_store, - user_id='test-user', agent=agent, event_stream=event_stream, + convo_stats=convo_stats, iteration_delta=10, budget_per_task_delta=None, + sid='test-session', ) # Set up the agent state @@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling(): @pytest.mark.asyncio -async def test_mcp_tool_timeout_agent_continuation(): +async def test_mcp_tool_timeout_agent_continuation(convo_stats): """Test that verifies the agent can continue processing after an MCP tool timeout.""" # Create a mock MCPClient mock_client = mock.MagicMock(spec=MCPClient) @@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation(): mock_client.tool_map = {'test_tool': mock_tool} # Create a mock file store - mock_file_store = mock.MagicMock() + mock_file_store = InMemoryFileStore({}) # Create a mock event stream event_stream = EventStream(sid='test-session', file_store=mock_file_store) @@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation(): # Create a mock agent controller controller = AgentController( - sid='test-session', - file_store=mock_file_store, - user_id='test-user', agent=agent, event_stream=event_stream, + convo_stats=convo_stats, iteration_delta=10, budget_per_task_delta=None, + sid='test-session', ) # Set up the agent state diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py index 91e6dc261a..b01253ecf2 100644 --- a/tests/unit/test_memory.py +++ b/tests/unit/test_memory.py @@ -21,11 +21,13 @@ from openhands.events.observation.agent import ( from openhands.events.serialization.observation import observation_from_dict from openhands.events.stream import EventStream from openhands.llm import LLM +from openhands.llm.llm_registry import LLMRegistry from openhands.llm.metrics import Metrics from openhands.memory.memory import Memory from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) +from openhands.server.services.conversation_stats import ConversationStats from openhands.server.session.agent_session import AgentSession from openhands.storage.memory import InMemoryFileStore from openhands.utils.prompt import ( @@ -42,6 +44,12 @@ def file_store(): return InMemoryFileStore({}) +@pytest.fixture +def mock_llm_registry(file_store): + """Create a mock LLMRegistry for testing.""" + return MagicMock(spec=LLMRegistry) + + @pytest.fixture def event_stream(file_store): """Create a test event stream.""" @@ -90,24 +98,29 @@ def mock_agent(): @pytest.mark.asyncio -async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent): +async def test_memory_on_event_exception_handling( + memory, event_stream, mock_agent, mock_llm_registry +): """Test that exceptions in Memory.on_event are properly handled via status callback.""" # Create a mock runtime runtime = MagicMock(spec=ActionExecutionClient) runtime.event_stream = event_stream # Mock Memory method to raise an exception - with patch.object( - memory, '_on_workspace_context_recall', side_effect=Exception('Test error') + with ( + patch.object( + memory, '_on_workspace_context_recall', side_effect=Exception('Test error') + ), + patch('openhands.core.main.create_agent', return_value=mock_agent), ): state = await run_controller( config=OpenHandsConfig(), initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=memory, + llm_registry=mock_llm_registry, ) # Verify that the controller's last error was set @@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age @pytest.mark.asyncio async def test_memory_on_workspace_context_recall_exception_handling( - memory, event_stream, mock_agent + memory, event_stream, mock_agent, mock_llm_registry ): """Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback.""" # Create a mock runtime @@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling( runtime.event_stream = event_stream # Mock Memory._on_workspace_context_recall to raise an exception - with patch.object( - memory, - '_find_microagent_knowledge', - side_effect=Exception('Test error from _find_microagent_knowledge'), + with ( + patch.object( + memory, + '_find_microagent_knowledge', + side_effect=Exception('Test error from _find_microagent_knowledge'), + ), + patch('openhands.core.main.create_agent', return_value=mock_agent), ): state = await run_controller( config=OpenHandsConfig(), initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=memory, + llm_registry=mock_llm_registry, ) # Verify that the controller's last error was set @@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository. @pytest.mark.asyncio async def test_conversation_instructions_plumbed_to_memory( - mock_agent, event_stream, file_store + mock_agent, event_stream, file_store, mock_llm_registry ): # Setup session = AgentSession( sid='test-session', file_store=file_store, + llm_registry=mock_llm_registry, + convo_stats=ConversationStats(file_store, 'test-session', None), ) # Create a mock runtime and set it up diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 963b590d3f..60cc0bb16f 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -3,26 +3,30 @@ from litellm import ModelResponse from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent from openhands.core.config import AgentConfig, LLMConfig +from openhands.core.config.openhands_config import OpenHandsConfig from openhands.events.action import MessageAction -from openhands.llm.llm import LLM +from openhands.llm.llm_registry import LLMRegistry @pytest.fixture -def mock_llm(): - llm = LLM( - LLMConfig( - model='claude-3-5-sonnet-20241022', - api_key='fake', - caching_prompt=True, - ) +def llm_config(): + return LLMConfig( + model='claude-3-5-sonnet-20241022', + api_key='fake', + caching_prompt=True, ) - return llm @pytest.fixture -def codeact_agent(mock_llm): +def llm_registry(): + registry = LLMRegistry(config=OpenHandsConfig()) + return registry + + +@pytest.fixture +def codeact_agent(llm_registry): config = AgentConfig() - agent = CodeActAgent(mock_llm, config) + agent = CodeActAgent(config, llm_registry) return agent diff --git a/tests/unit/test_runtime_git_tokens.py b/tests/unit/test_runtime_git_tokens.py index 55ed5b5816..aed15fd777 100644 --- a/tests/unit/test_runtime_git_tokens.py +++ b/tests/unit/test_runtime_git_tokens.py @@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation from openhands.events.stream import EventStream from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType from openhands.integrations.service_types import AuthenticationError, Repository +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.base import Runtime from openhands.storage import get_file_store -class TestRuntime(Runtime): +class MockRuntime(Runtime): """A concrete implementation of Runtime for testing""" def __init__(self, *args, **kwargs): + # Ensure llm_registry is provided if not already in kwargs + if 'llm_registry' not in kwargs and len(args) < 3: + # Create a mock LLMRegistry if not provided + config = ( + kwargs.get('config') + if 'config' in kwargs + else args[0] + if args + else OpenHandsConfig() + ) + kwargs['llm_registry'] = LLMRegistry(config=config) super().__init__(*args, **kwargs) self.run_action_calls = [] self._execute_shell_fn_git_handler = MagicMock( @@ -89,9 +101,11 @@ def runtime(temp_dir): ) file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + llm_registry = LLMRegistry(config=config) + runtime = MockRuntime( config=config, event_stream=event_stream, + llm_registry=llm_registry, sid='test', user_id='test_user', git_provider_tokens=git_provider_tokens, @@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir): config = OpenHandsConfig() file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime(config=config, event_stream=event_stream, sid='test') + runtime = MockRuntime(config=config, event_stream=event_stream, sid='test') # Create a command that would normally trigger token export cmd = CmdRunAction(command='echo $GITHUB_TOKEN') @@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir): config = OpenHandsConfig() file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id='test_user' ) @@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir): ) file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', @@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir): config.init_git_in_empty_workspace = True file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id=None ) @@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di config.workspace_base = '/some/path' # Set workspace_base file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id=None ) @@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir): config = OpenHandsConfig() file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id='test_user' ) @@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch): {ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))} ) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', @@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch): file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id='test_user' ) @@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch): {ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))} ) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', @@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch): file_store = get_file_store('local', temp_dir) event_stream = EventStream('abc', file_store) - runtime = TestRuntime( + runtime = MockRuntime( config=config, event_stream=event_stream, sid='test', user_id='test_user' ) diff --git a/tests/unit/test_runtime_gitlab_microagents.py b/tests/unit/test_runtime_gitlab_microagents.py index a363f4ed01..c4a108386b 100644 --- a/tests/unit/test_runtime_gitlab_microagents.py +++ b/tests/unit/test_runtime_gitlab_microagents.py @@ -9,10 +9,12 @@ import pytest from openhands.core.config import OpenHandsConfig, SandboxConfig from openhands.events import EventStream from openhands.integrations.service_types import ProviderType, Repository +from openhands.llm.llm_registry import LLMRegistry from openhands.microagent.microagent import ( RepoMicroagent, ) from openhands.runtime.base import Runtime +from openhands.storage import get_file_store class MockRuntime(Runtime): @@ -24,12 +26,21 @@ class MockRuntime(Runtime): config.workspace_mount_path_in_sandbox = str(workspace_root) config.sandbox = SandboxConfig() - # Create a mock event stream + # Create a mock event stream and file store + file_store = get_file_store('local', str(workspace_root)) event_stream = MagicMock(spec=EventStream) + event_stream.file_store = file_store + + # Create a mock LLM registry + llm_registry = LLMRegistry(config) # Initialize the parent class properly super().__init__( - config=config, event_stream=event_stream, sid='test', git_provider_tokens={} + config=config, + event_stream=event_stream, + llm_registry=llm_registry, + sid='test', + git_provider_tokens={}, ) self._workspace_root = workspace_root diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index e9fa6c68bc..6c9f9b2f17 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -595,7 +595,7 @@ async def test_check_usertask( analyzer = InvariantAnalyzer(event_stream) mock_response = {'choices': [{'message': {'content': is_appropriate}}]} mock_litellm_completion.return_value = mock_response - analyzer.guardrail_llm = LLM(config=default_config) + analyzer.guardrail_llm = LLM(config=default_config, service_id='test') analyzer.check_browsing_alignment = True data = [ (MessageAction(usertask), EventSource.USER), @@ -657,7 +657,7 @@ async def test_check_fillaction( analyzer = InvariantAnalyzer(event_stream) mock_response = {'choices': [{'message': {'content': is_harmful}}]} mock_litellm_completion.return_value = mock_response - analyzer.guardrail_llm = LLM(config=default_config) + analyzer.guardrail_llm = LLM(config=default_config, service_id='test') analyzer.check_browsing_alignment = True data = [ (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT), diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 3e59bda492..e9211b6a4f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7,7 +7,9 @@ from litellm.exceptions import ( from openhands.core.config.llm_config import LLMConfig from openhands.core.config.openhands_config import OpenHandsConfig +from openhands.llm.llm_registry import LLMRegistry from openhands.runtime.runtime_status import RuntimeStatus +from openhands.server.services.conversation_stats import ConversationStats from openhands.server.session.session import Session from openhands.storage.memory import InMemoryFileStore @@ -33,10 +35,28 @@ def default_llm_config(): ) +@pytest.fixture +def llm_registry(): + config = OpenHandsConfig() + return LLMRegistry(config=config) + + +@pytest.fixture +def conversation_stats(): + file_store = InMemoryFileStore({}) + return ConversationStats( + file_store=file_store, conversation_id='test-conversation', user_id='test-user' + ) + + @pytest.mark.asyncio @patch('openhands.llm.llm.litellm_completion') async def test_notify_on_llm_retry( - mock_litellm_completion, mock_sio, default_llm_config + mock_litellm_completion, + mock_sio, + default_llm_config, + llm_registry, + conversation_stats, ): config = OpenHandsConfig() config.set_llm_config(default_llm_config) @@ -44,6 +64,8 @@ async def test_notify_on_llm_retry( sid='..sid..', file_store=InMemoryFileStore({}), config=config, + llm_registry=llm_registry, + convo_stats=conversation_stats, sio=mock_sio, user_id='..uid..', ) @@ -56,12 +78,20 @@ async def test_notify_on_llm_retry( ), {'choices': [{'message': {'content': 'Retry successful'}}]}, ] - llm = session._create_llm('..cls..') - llm.completion( - messages=[{'role': 'user', 'content': 'Hello!'}], - stream=False, - ) + # Set the retry listener on the registry + llm_registry.retry_listner = session._notify_on_llm_retry + + # Create an LLM through the registry + llm = llm_registry.get_llm( + service_id='test_service', + config=default_llm_config, + ) + + llm.completion( + messages=[{'role': 'user', 'content': 'Hello!'}], + stream=False, + ) assert mock_litellm_completion.call_count == 2 session.queue_status_message.assert_called_once_with(