[Refactor]: Add LLMRegistry for llm services (#9589)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Rohit Malhotra 2025-08-18 02:11:20 -04:00 committed by GitHub
parent 17b1a21296
commit 25d9cf2890
84 changed files with 2376 additions and 817 deletions

View File

@ -18,7 +18,7 @@ from openhands.events.action import (
from openhands.events.event import EventSource
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.observation import Observation
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.plugins import (
PluginRequirement,
)
@ -102,15 +102,15 @@ class BrowsingAgent(Agent):
def __init__(
self,
llm: LLM,
config: AgentConfig,
llm_registry: LLMRegistry,
) -> None:
"""Initializes a new instance of the BrowsingAgent class.
Parameters:
- llm (LLM): The llm to be used by this agent
"""
super().__init__(llm, config)
super().__init__(config, llm_registry)
# define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
# see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
action_subsets = ['chat', 'bid']

View File

@ -3,6 +3,8 @@ import sys
from collections import deque
from typing import TYPE_CHECKING
from openhands.llm.llm_registry import LLMRegistry
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
@ -32,7 +34,6 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.events.action import AgentFinishAction, MessageAction
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.llm.llm_utils import check_tools
from openhands.memory.condenser import Condenser
from openhands.memory.condenser.condenser import Condensation, View
@ -74,18 +75,13 @@ class CodeActAgent(Agent):
JupyterRequirement(),
]
def __init__(
self,
llm: LLM,
config: AgentConfig,
) -> None:
def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
"""Initializes a new instance of the CodeActAgent class.
Parameters:
- llm (LLM): The llm to be used by this agent
- config (AgentConfig): The configuration for this agent
"""
super().__init__(llm, config)
super().__init__(config, llm_registry)
self.pending_actions: deque['Action'] = deque()
self.reset()
self.tools = self._get_tools()
@ -93,7 +89,7 @@ class CodeActAgent(Agent):
# Create a ConversationMemory instance
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
self.condenser = Condenser.from_config(self.config.condenser)
self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
logger.debug(f'Using condenser: {type(self.condenser)}')
@property

View File

@ -22,7 +22,7 @@ from openhands.events.observation import (
Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
"""
FIXME: There are a few problems this surfaced
@ -42,8 +42,12 @@ class DummyAgent(Agent):
without making any LLM calls.
"""
def __init__(self, llm: LLM, config: AgentConfig):
super().__init__(llm, config)
def __init__(
self,
config: AgentConfig,
llm_registry: LLMRegistry,
):
super().__init__(config, llm_registry)
self.steps: list[ActionObs] = [
{
'action': MessageAction('Time to get started!'),

View File

@ -4,7 +4,7 @@ import openhands.agenthub.loc_agent.function_calling as locagent_function_callin
from openhands.agenthub.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
if TYPE_CHECKING:
from openhands.events.action import Action
@ -16,8 +16,8 @@ class LocAgent(CodeActAgent):
def __init__(
self,
llm: LLM,
config: AgentConfig,
llm_registry: LLMRegistry,
) -> None:
"""Initializes a new instance of the LocAgent class.
@ -25,7 +25,8 @@ class LocAgent(CodeActAgent):
- llm (LLM): The llm to be used by this agent
- config (AgentConfig): The configuration for the agent
"""
super().__init__(llm, config)
super().__init__(config, llm_registry)
self.tools = locagent_function_calling.get_tools()
logger.debug(

View File

@ -3,6 +3,8 @@
import os
from typing import TYPE_CHECKING
from openhands.llm.llm_registry import LLMRegistry
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
@ -15,7 +17,6 @@ from openhands.agenthub.readonly_agent import (
)
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.utils.prompt import PromptManager
@ -37,17 +38,16 @@ class ReadOnlyAgent(CodeActAgent):
def __init__(
self,
llm: LLM,
config: AgentConfig,
llm_registry: LLMRegistry,
) -> None:
"""Initializes a new instance of the ReadOnlyAgent class.
Parameters:
- llm (LLM): The llm to be used by this agent
- config (AgentConfig): The configuration for this agent
"""
# Initialize the CodeActAgent class; some of it is overridden with class methods
super().__init__(llm, config)
super().__init__(config, llm_registry)
logger.debug(
f'TOOLS loaded for ReadOnlyAgent: {", ".join([tool.get("function").get("name") for tool in self.tools])}'

View File

@ -16,7 +16,7 @@ from openhands.events.action import (
from openhands.events.event import EventSource
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.observation import Observation
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.plugins import (
PluginRequirement,
)
@ -127,17 +127,13 @@ class VisualBrowsingAgent(Agent):
sandbox_plugins: list[PluginRequirement] = []
response_parser = BrowsingResponseParser()
def __init__(
self,
llm: LLM,
config: AgentConfig,
) -> None:
def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
"""Initializes a new instance of the VisualBrowsingAgent class.
Parameters:
- llm (LLM): The llm to be used by this agent
"""
super().__init__(llm, config)
super().__init__(config, llm_registry)
# define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
# see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
action_subsets = [

View File

@ -83,6 +83,7 @@ from openhands.microagent.microagent import BaseMicroagent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.storage.settings.file_settings_store import FileSettingsStore
from openhands.utils.utils import create_registry_and_convo_stats
async def cleanup_session(
@ -147,9 +148,16 @@ async def run_session(
None, display_initialization_animation, 'Initializing...', is_loaded
)
agent = create_agent(config)
llm_registry, convo_stats, config = create_registry_and_convo_stats(
config,
sid,
None,
)
agent = create_agent(config, llm_registry)
runtime = create_runtime(
config,
llm_registry,
sid=sid,
headless_mode=True,
agent=agent,
@ -161,7 +169,7 @@ async def run_session(
runtime.subscribe_to_shell_stream(stream_to_console)
controller, initial_state = create_controller(agent, runtime, config)
controller, initial_state = create_controller(agent, runtime, config, convo_stats)
event_stream = runtime.event_stream

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from openhands.llm.llm_registry import LLMRegistry
if TYPE_CHECKING:
from openhands.controller.state.state import State
from openhands.events.action import Action
@ -17,7 +19,6 @@ from openhands.core.exceptions import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import EventSource
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement
@ -38,10 +39,11 @@ class Agent(ABC):
def __init__(
self,
llm: LLM,
config: AgentConfig,
llm_registry: LLMRegistry,
):
self.llm = llm
self.llm = llm_registry.get_llm_from_agent_config('agent', config)
self.llm_registry = llm_registry
self.config = config
self._complete = False
self._prompt_manager: 'PromptManager' | None = None
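
With this change the base Agent no longer receives an LLM instance; it resolves one from the registry at construction time. The sketch below shows how an agent would now be built by hand, mirroring the updated create_agent() later in this diff; it assumes an OpenHandsConfig whose defaults resolve to a usable 'llm' config, and is illustrative rather than part of the commit.

import openhands.agenthub  # noqa: F401  (importing registers the agent classes)
from openhands.controller.agent import Agent
from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

config = OpenHandsConfig()  # assumption: the default config resolves to a valid LLM config
llm_registry = LLMRegistry(config)  # one registry per conversation; it creates the default 'agent' LLM

agent_config = config.get_agent_config(config.default_agent)
agent_cls = Agent.get_cls(config.default_agent)
agent = agent_cls(config=agent_config, llm_registry=llm_registry)
# agent.llm is whatever llm_registry.get_llm_from_agent_config('agent', agent_config) returns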

View File

@ -73,9 +73,9 @@ from openhands.events.observation import (
Observation,
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore
# note: RESUME is only available on web GUI
@ -109,6 +109,7 @@ class AgentController:
self,
agent: Agent,
event_stream: EventStream,
convo_stats: ConversationStats,
iteration_delta: int,
budget_per_task_delta: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
@ -148,6 +149,7 @@ class AgentController:
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
self.convo_stats = convo_stats
# the event stream must be set before maybe subscribing to it
self.event_stream = event_stream
@ -163,6 +165,7 @@ class AgentController:
# state from the previous session, state from a parent agent, or a fresh state
self.set_initial_state(
state=initial_state,
convo_stats=convo_stats,
max_iterations=iteration_delta,
max_budget_per_task=budget_per_task_delta,
confirmation_mode=confirmation_mode,
@ -477,11 +480,6 @@ class AgentController:
log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'}
)
# TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics
# In the future, we should have a more principled way of sharing metrics across all LLM instances for a given conversation
if observation.llm_metrics is not None:
self.state_tracker.merge_metrics(observation.llm_metrics)
# this happens for runnable actions and microagent actions
if self._pending_action and self._pending_action.id == observation.cause:
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
@ -657,14 +655,10 @@ class AgentController:
"""
agent_cls: type[Agent] = Agent.get_cls(action.agent)
agent_config = self.agent_configs.get(action.agent, self.agent.config)
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
# Make sure metrics are shared between parent and child for global accumulation
llm = LLM(
config=llm_config,
retry_listener=self.agent.llm.retry_listener,
metrics=self.state.metrics,
delegate_agent = agent_cls(
config=agent_config, llm_registry=self.agent.llm_registry
)
delegate_agent = agent_cls(llm=llm, config=agent_config)
# Take a snapshot of the current metrics before starting the delegate
state = State(
@ -683,7 +677,7 @@ class AgentController:
)
self.log(
'debug',
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
f'start delegate, creating agent {delegate_agent.name}',
)
# Create the delegate with is_delegate=True so it does NOT subscribe directly
@ -693,6 +687,7 @@ class AgentController:
user_id=self.user_id,
agent=delegate_agent,
event_stream=self.event_stream,
convo_stats=self.convo_stats,
iteration_delta=self._initial_max_iterations,
budget_per_task_delta=self._initial_max_budget_per_task,
agent_to_llm_config=self.agent_to_llm_config,
@ -795,13 +790,8 @@ class AgentController:
extra={'msg_type': 'STEP'},
)
# Ensure budget control flag is synchronized with the latest metrics.
# In the future, we should centralize the use of one LLM object per conversation.
# This will help us unify the cost for auto-generating titles, running the condenser, etc.
# Because many microservices will touch the same llm cost field, we should sync with the budget flag for the controller
# and check that we haven't exceeded budget BEFORE executing an agent step.
# Synchronize spend across all llm services with the budget flag
self.state_tracker.sync_budget_flag_with_metrics()
if self._is_stuck():
await self._react_to_exception(
AgentStuckInLoopError('Agent got stuck in a loop')
@ -961,14 +951,15 @@ class AgentController:
def set_initial_state(
self,
state: State | None,
convo_stats: ConversationStats,
max_iterations: int,
max_budget_per_task: float | None,
confirmation_mode: bool = False,
):
self.state_tracker.set_initial_state(
self.id,
self.agent,
state,
convo_stats,
max_iterations,
max_budget_per_task,
confirmation_mode,
@ -1009,37 +1000,20 @@ class AgentController:
action: The action to attach metrics to
"""
# Get metrics from agent LLM
agent_metrics = self.state.metrics
metrics = self.convo_stats.get_combined_metrics()
# Get metrics from condenser LLM if it exists
condenser_metrics: Metrics | None = None
if hasattr(self.agent, 'condenser') and hasattr(self.agent.condenser, 'llm'):
condenser_metrics = self.agent.condenser.llm.metrics
# Create a new minimal metrics object with just what the frontend needs
metrics = Metrics(model_name=agent_metrics.model_name)
# Set accumulated cost (sum of agent and condenser costs)
metrics.accumulated_cost = agent_metrics.accumulated_cost
if condenser_metrics:
metrics.accumulated_cost += condenser_metrics.accumulated_cost
# Create a clean copy with only the fields we want to keep
clean_metrics = Metrics()
clean_metrics.accumulated_cost = metrics.accumulated_cost
clean_metrics._accumulated_token_usage = copy.deepcopy(
metrics.accumulated_token_usage
)
# Add max_budget_per_task to metrics
if self.state.budget_flag:
metrics.max_budget_per_task = self.state.budget_flag.max_value
clean_metrics.max_budget_per_task = self.state.budget_flag.max_value
# Set accumulated token usage (sum of agent and condenser token usage)
# Use a deep copy to ensure we don't modify the original object
metrics._accumulated_token_usage = (
agent_metrics.accumulated_token_usage.model_copy(deep=True)
)
if condenser_metrics:
metrics._accumulated_token_usage = (
metrics._accumulated_token_usage
+ condenser_metrics.accumulated_token_usage
)
action.llm_metrics = metrics
action.llm_metrics = clean_metrics
# Log the metrics information for debugging
# Get the latest usage directly from the agent's metrics

View File

@ -21,6 +21,7 @@ from openhands.events.action.agent import AgentFinishAction
from openhands.events.event import Event, EventSource
from openhands.llm.metrics import Metrics
from openhands.memory.view import View
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_agent_state_filename
@ -84,6 +85,7 @@ class State:
limit_increase_amount=100, current_value=0, max_value=100
)
)
convo_stats: ConversationStats | None = None
budget_flag: BudgetControlFlag | None = None
confirmation_mode: bool = False
history: list[Event] = field(default_factory=list)
@ -91,8 +93,7 @@ class State:
outputs: dict = field(default_factory=dict)
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
# global metrics for the current task
metrics: Metrics = field(default_factory=Metrics)
# root agent has level 0, and every delegate increases the level by one
delegate_level: int = 0
# start_id and end_id track the range of events in history
@ -116,9 +117,14 @@ class State:
local_metrics: Metrics | None = None
delegates: dict[tuple[int, int], tuple[str, str]] | None = None
metrics: Metrics = field(default_factory=Metrics)
def save_to_session(
self, sid: str, file_store: FileStore, user_id: str | None
) -> None:
convo_stats = self.convo_stats
self.convo_stats = None # Don't save convo stats; it handles its own persistence
pickled = pickle.dumps(self)
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
encoded = base64.b64encode(pickled).decode('utf-8')
@ -138,6 +144,8 @@ class State:
logger.error(f'Failed to save state to session: {e}')
raise e
self.convo_stats = convo_stats # restore reference
@staticmethod
def restore_from_session(
sid: str, file_store: FileStore, user_id: str | None = None

View File

@ -1,4 +1,3 @@
from openhands.controller.agent import Agent
from openhands.controller.state.control_flags import (
BudgetControlFlag,
IterationControlFlag,
@ -14,7 +13,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization.event import event_to_trajectory
from openhands.events.stream import EventStream
from openhands.llm.metrics import Metrics
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore
@ -51,8 +50,8 @@ class StateTracker:
def set_initial_state(
self,
id: str,
agent: Agent,
state: State | None,
convo_stats: ConversationStats,
max_iterations: int,
max_budget_per_task: float | None,
confirmation_mode: bool = False,
@ -75,6 +74,7 @@ class StateTracker:
session_id=id.removesuffix('-delegate'),
user_id=self.user_id,
inputs={},
convo_stats=convo_stats,
iteration_flag=IterationControlFlag(
limit_increase_amount=max_iterations,
current_value=0,
@ -99,13 +99,7 @@ class StateTracker:
if self.state.start_id <= -1:
self.state.start_id = 0
logger.info(
f'AgentController {id} initializing history from event {self.state.start_id}',
)
# Share the state metrics with the agent's LLM metrics
# This ensures that all accumulated metrics are always in sync between controller and llm
agent.llm.metrics = self.state.metrics
state.convo_stats = convo_stats
def _init_history(self, event_stream: EventStream) -> None:
"""Initializes the agent's history from the event stream.
@ -254,6 +248,9 @@ class StateTracker:
if self.sid and self.file_store:
self.state.save_to_session(self.sid, self.file_store, self.user_id)
if self.state.convo_stats:
self.state.convo_stats.save_metrics()
def run_control_flags(self):
"""Performs one step of the control flags"""
self.state.iteration_flag.step()
@ -264,20 +261,8 @@ class StateTracker:
"""Ensures that budget flag is up to date with accumulated costs from llm completions
Budget flag will monitor for when budget is exceeded
"""
if self.state.budget_flag:
self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
def merge_metrics(self, metrics: Metrics):
"""Merges metrics with the state metrics
NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc)
use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from
all services
This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them
if we decide to introduce more specialized services that require llm completions
"""
self.state.metrics.merge(metrics)
if self.state.budget_flag:
self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
# Sync cost across all llm services from llm registry
if self.state.budget_flag and self.state.convo_stats:
self.state.budget_flag.current_value = (
self.state.convo_stats.get_combined_metrics().accumulated_cost
)

View File

@ -157,13 +157,16 @@ class OpenHandsConfig(BaseModel):
"""Get a map of agent names to llm configs."""
return {name: self.get_llm_config_from_agent(name) for name in self.agents}
def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
agent_config: AgentConfig = self.get_agent_config(name)
def get_llm_config_from_agent_config(self, agent_config: AgentConfig):
llm_config_name = (
agent_config.llm_config if agent_config.llm_config is not None else 'llm'
)
return self.get_llm_config(llm_config_name)
def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
agent_config: AgentConfig = self.get_agent_config(name)
return self.get_llm_config_from_agent_config(agent_config)
def get_agent_configs(self) -> dict[str, AgentConfig]:
return self.agents

View File

@ -6,7 +6,6 @@ from typing import Callable, Protocol
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
import openhands.cli.suppress_warnings # noqa: F401
from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State
from openhands.core.config import (
@ -33,10 +32,12 @@ from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.io import read_input, read_task
from openhands.llm.llm_registry import LLMRegistry
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.utils import create_registry_and_convo_stats
class FakeUserResponseFunc(Protocol):
@ -53,12 +54,12 @@ async def run_controller(
initial_user_action: Action,
sid: str | None = None,
runtime: Runtime | None = None,
agent: Agent | None = None,
exit_on_message: bool = False,
fake_user_response_fn: FakeUserResponseFunc | None = None,
headless_mode: bool = True,
memory: Memory | None = None,
conversation_instructions: str | None = None,
llm_registry: LLMRegistry | None = None,
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
@ -70,7 +71,6 @@ async def run_controller(
sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
runtime: (optional) A runtime for the agent to run on.
agent: (optional) A agent to run.
exit_on_message: quit if agent asks for a message from user (optional)
fake_user_response_fn: An optional function that receives the current state
(could be None) and returns a fake user response.
@ -98,8 +98,13 @@ async def run_controller(
"""
sid = sid or generate_sid(config)
if agent is None:
agent = create_agent(config)
llm_registry, convo_stats, config = create_registry_and_convo_stats(
config,
sid,
None,
)
agent = create_agent(config, llm_registry)
# when the runtime is created, it will be connected and clone the selected repository
repo_directory = None
@ -108,6 +113,7 @@ async def run_controller(
repo_tokens = get_provider_tokens()
runtime = create_runtime(
config,
llm_registry,
sid=sid,
headless_mode=headless_mode,
agent=agent,
@ -159,7 +165,7 @@ async def run_controller(
)
controller, initial_state = create_controller(
agent, runtime, config, replay_events=replay_events
agent, runtime, config, convo_stats, replay_events=replay_events
)
assert isinstance(initial_user_action, Action), (

View File

@ -21,12 +21,13 @@ from openhands.integrations.provider import (
ProviderToken,
ProviderType,
)
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
@ -34,6 +35,7 @@ from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
def create_runtime(
config: OpenHandsConfig,
llm_registry: LLMRegistry,
sid: str | None = None,
headless_mode: bool = True,
agent: Agent | None = None,
@ -82,6 +84,7 @@ def create_runtime(
sid=session_id,
plugins=agent_cls.sandbox_plugins,
headless_mode=headless_mode,
llm_registry=llm_registry,
git_provider_tokens=git_provider_tokens,
)
@ -203,16 +206,11 @@ def create_memory(
return memory
def create_agent(config: OpenHandsConfig) -> Agent:
def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent:
agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,
)
config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(config=agent_config, llm_registry=llm_registry)
return agent
@ -220,6 +218,7 @@ def create_controller(
agent: Agent,
runtime: Runtime,
config: OpenHandsConfig,
convo_stats: ConversationStats,
headless_mode: bool = True,
replay_events: list[Event] | None = None,
) -> tuple[AgentController, State | None]:
@ -237,6 +236,7 @@ def create_controller(
controller = AgentController(
agent=agent,
convo_stats=convo_stats,
iteration_delta=config.max_iterations,
budget_per_task_delta=config.max_budget_per_task,
agent_to_llm_config=config.get_agent_to_llm_config_map(),

View File

@ -8,6 +8,7 @@ from typing import Any, Callable
import httpx
from openhands.core.config import LLMConfig
from openhands.llm.metrics import Metrics
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@ -34,7 +35,6 @@ from openhands.llm.fn_call_converter import (
convert_fncall_messages_to_non_fncall_messages,
convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin
__all__ = ['LLM']
@ -133,6 +133,7 @@ class LLM(RetryMixin, DebugMixin):
def __init__(
self,
config: LLMConfig,
service_id: str,
metrics: Metrics | None = None,
retry_listener: Callable[[int, int], None] | None = None,
) -> None:
@ -145,11 +146,12 @@ class LLM(RetryMixin, DebugMixin):
metrics: The metrics to use.
"""
self._tried_model_info = False
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.service_id = service_id
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
)
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.model_info: ModelInfo | None = None
self.retry_listener = retry_listener
@ -408,8 +410,7 @@ class LLM(RetryMixin, DebugMixin):
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the metric model name (for draft editor)
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
f'{self.config.model.replace("/", "__")}-{time.time()}.json',
)
# set up the dict to be logged
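
The LLM constructor now takes a required service_id, which the registry uses to key each service's client and which replaces the old metrics-model-name trick for naming completion log files. A minimal standalone construction sketch; the model string is illustrative, and the assumption is that LLMConfig accepts a model name this way:

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

# service_id identifies which service (agent, condenser, draft editor, resolver, ...) owns this LLM;
# completion logs are now named after config.model rather than the metrics model name.
llm = LLM(config=LLMConfig(model='gpt-4o'), service_id='resolver')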

View File

@ -0,0 +1,132 @@
import copy
from typing import Any, Callable
from uuid import uuid4
from pydantic import BaseModel, ConfigDict
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
class RegistryEvent(BaseModel):
llm: LLM
service_id: str
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
class LLMRegistry:
def __init__(
self,
config: OpenHandsConfig,
agent_cls: str | None = None,
retry_listener: Callable[[int, int], None] | None = None,
):
self.registry_id = str(uuid4())
self.config = copy.deepcopy(config)
self.retry_listener = retry_listener
self.agent_to_llm_config = self.config.get_agent_to_llm_config_map()
self.service_to_llm: dict[str, LLM] = {}
self.subscriber: Callable[[Any], None] | None = None
selected_agent_cls = self.config.default_agent
if agent_cls:
selected_agent_cls = agent_cls
agent_name = selected_agent_cls if selected_agent_cls is not None else 'agent'
llm_config = self.config.get_llm_config_from_agent(agent_name)
self.active_agent_llm: LLM = self.get_llm('agent', llm_config)
def _create_new_llm(
self, service_id: str, config: LLMConfig, with_listener: bool = True
) -> LLM:
if with_listener:
llm = LLM(
service_id=service_id, config=config, retry_listener=self.retry_listener
)
else:
llm = LLM(service_id=service_id, config=config)
self.service_to_llm[service_id] = llm
self.notify(RegistryEvent(llm=llm, service_id=service_id))
return llm
def request_extraneous_completion(
self, service_id: str, llm_config: LLMConfig, messages: list[dict[str, str]]
) -> str:
logger.info(f'extraneous completion: {service_id}')
if service_id not in self.service_to_llm:
self._create_new_llm(
config=llm_config, service_id=service_id, with_listener=False
)
llm = self.service_to_llm[service_id]
response = llm.completion(messages=messages)
return response.choices[0].message.content.strip()
def get_llm_from_agent_config(self, service_id: str, agent_config: AgentConfig):
llm_config = self.config.get_llm_config_from_agent_config(agent_config)
if service_id in self.service_to_llm:
if self.service_to_llm[service_id].config != llm_config:
# TODO: update llm config internally
# This happens when an agent delegate has a different config; we should reuse the existing LLM
pass
return self.service_to_llm[service_id]
return self._create_new_llm(config=llm_config, service_id=service_id)
def get_llm(
self,
service_id: str,
config: LLMConfig | None = None,
):
logger.info(
f'[LLM registry {self.registry_id}]: Registering service for {service_id}'
)
# Attempting to switch configs for existing LLM
if (
service_id in self.service_to_llm
and self.service_to_llm[service_id].config != config
):
raise ValueError(
f'Requesting same service ID {service_id} with different config, use a new service ID'
)
if service_id in self.service_to_llm:
return self.service_to_llm[service_id]
if not config:
raise ValueError('Requesting new LLM without specifying LLM config')
return self._create_new_llm(config=config, service_id=service_id)
def get_active_llm(self) -> LLM:
return self.active_agent_llm
def _set_active_llm(self, service_id) -> None:
if service_id not in self.service_to_llm:
raise ValueError(f'Unrecognized service ID: {service_id}')
self.active_agent_llm = self.service_to_llm[service_id]
def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
self.subscriber = callback
# Subscriptions happen after default llm is initialized
# Notify the subscribing service of this llm
self.notify(
RegistryEvent(
llm=self.active_agent_llm, service_id=self.active_agent_llm.service_id
)
)
def notify(self, event: RegistryEvent):
if self.subscriber:
try:
self.subscriber(event)
except Exception as e:
logger.warning(f'Failed to emit event: {e}')
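
This new file is the registry surface the rest of the diff migrates to: one registry per conversation, one LLM per service_id, and a subscriber hook so services such as conversation stats can track every LLM that gets created. A minimal usage sketch, assuming the default config resolves to a valid agent LLM; the model names, the 'conversation_title' service id, and the example message are illustrative:

from openhands.core.config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

config = OpenHandsConfig()
registry = LLMRegistry(config)  # creates and caches the default 'agent' LLM

# Subscribing immediately reports the active agent LLM, then every LLM created afterwards.
registry.subscribe(lambda event: print(event.service_id, event.llm.config.model))

# Same service_id returns the cached LLM; a new service_id requires an explicit config.
condenser_llm = registry.get_llm('condenser', LLMConfig(model='gpt-4o-mini'))

# One-off completion for a side service, keyed by its own service_id.
title = registry.request_extraneous_completion(
    'conversation_title',
    LLMConfig(model='gpt-4o-mini'),
    [{'role': 'user', 'content': 'Summarize this conversation in five words.'}],
)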

View File

@ -10,6 +10,7 @@ from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.view import View
CONDENSER_METADATA_KEY = 'condenser_meta'
@ -144,7 +145,9 @@ class Condenser(ABC):
CONDENSER_REGISTRY[configuration_type] = cls
@classmethod
def from_config(cls, config: CondenserConfig) -> Condenser:
def from_config(
cls, config: CondenserConfig, llm_registry: LLMRegistry
) -> Condenser:
"""Create a condenser from a configuration object.
Args:
@ -158,7 +161,7 @@ class Condenser(ABC):
"""
try:
condenser_class = CONDENSER_REGISTRY[type(config)]
return condenser_class.from_config(config)
return condenser_class.from_config(config, llm_registry)
except KeyError:
raise ValueError(f'Unknown condenser config: {config}')

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
Condensation,
RollingCondenser,
@ -58,7 +59,9 @@ class AmortizedForgettingCondenser(RollingCondenser):
@classmethod
def from_config(
cls, config: AmortizedForgettingCondenserConfig
cls,
config: AmortizedForgettingCondenserConfig,
llm_registry: LLMRegistry,
) -> AmortizedForgettingCondenser:
return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'}))

View File

@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
from openhands.events.event import Event
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@ -40,7 +41,7 @@ class BrowserOutputCondenser(Condenser):
@classmethod
def from_config(
cls, config: BrowserOutputCondenserConfig
cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry
) -> BrowserOutputCondenser:
return BrowserOutputCondenser(**config.model_dump(exclude={'type'}))

View File

@ -9,6 +9,7 @@ from openhands.events.action.agent import (
from openhands.events.action.message import MessageAction, SystemMessageAction
from openhands.events.event import EventSource
from openhands.events.observation import Observation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
@ -177,7 +178,9 @@ class ConversationWindowCondenser(RollingCondenser):
@classmethod
def from_config(
cls, _config: ConversationWindowCondenserConfig
cls,
_config: ConversationWindowCondenserConfig,
llm_registry: LLMRegistry,
) -> ConversationWindowCondenser:
return ConversationWindowCondenser()

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel
from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
Condensation,
RollingCondenser,
@ -22,7 +23,12 @@ class ImportantEventSelection(BaseModel):
class LLMAttentionCondenser(RollingCondenser):
"""Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""
def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
def __init__(
self,
llm: LLM,
max_size: int = 100,
keep_first: int = 1,
):
if keep_first >= max_size // 2:
raise ValueError(
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
@ -113,15 +119,19 @@ class LLMAttentionCondenser(RollingCondenser):
return len(view) > self.max_size
@classmethod
def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser:
def from_config(
cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry
) -> LLMAttentionCondenser:
# This condenser cannot take advantage of prompt caching. If it happens
# to be set, we'll pay for the cache writes but never get a chance to
# save on a read.
llm_config = config.llm_config.model_copy()
llm_config.caching_prompt = False
llm = llm_registry.get_llm('condenser', llm_config)
return LLMAttentionCondenser(
llm=LLM(config=llm_config),
llm=llm,
max_size=config.max_size,
keep_first=config.keep_first,
)

View File

@ -5,7 +5,8 @@ from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
from openhands.llm import LLM
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
Condensation,
RollingCondenser,
@ -154,16 +155,17 @@ CURRENT_STATE: Last flip: Heads, Haiku count: 15/20"""
@classmethod
def from_config(
cls, config: LLMSummarizingCondenserConfig
cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry
) -> LLMSummarizingCondenser:
# This condenser cannot take advantage of prompt caching. If it happens
# to be set, we'll pay for the cache writes but never get a chance to
# save on a read.
llm_config = config.llm_config.model_copy()
llm_config.caching_prompt = False
llm = llm_registry.get_llm('condenser', llm_config)
return LLMSummarizingCondenser(
llm=LLM(config=llm_config),
llm=llm,
max_size=config.max_size,
keep_first=config.keep_first,
max_event_length=config.max_event_length,

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@ -12,7 +13,9 @@ class NoOpCondenser(Condenser):
return view
@classmethod
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
def from_config(
cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry
) -> NoOpCondenser:
return NoOpCondenser()

View File

@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserCo
from openhands.events.event import Event
from openhands.events.observation import Observation
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@ -28,7 +29,9 @@ class ObservationMaskingCondenser(Condenser):
@classmethod
def from_config(
cls, config: ObservationMaskingCondenserConfig
cls,
config: ObservationMaskingCondenserConfig,
llm_registry: LLMRegistry,
) -> ObservationMaskingCondenser:
return ObservationMaskingCondenser(**config.model_dump(exclude={'type'}))

View File

@ -4,6 +4,7 @@ from contextlib import contextmanager
from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserPipelineConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser
from openhands.memory.view import View
@ -39,8 +40,10 @@ class CondenserPipeline(Condenser):
return result
@classmethod
def from_config(cls, config: CondenserPipelineConfig) -> CondenserPipeline:
condensers = [Condenser.from_config(c) for c in config.condensers]
def from_config(
cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry
) -> CondenserPipeline:
condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers]
return CondenserPipeline(*condensers)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@ -21,7 +22,9 @@ class RecentEventsCondenser(Condenser):
return View(events=head + tail)
@classmethod
def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser:
def from_config(
cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry
) -> RecentEventsCondenser:
return RecentEventsCondenser(**config.model_dump(exclude={'type'}))

View File

@ -13,7 +13,8 @@ from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
from openhands.llm import LLM
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
Condensation,
RollingCondenser,
@ -180,15 +181,14 @@ class StructuredSummaryCondenser(RollingCondenser):
if max_size < 1:
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
if not llm.is_function_calling_active():
raise ValueError(
'LLM must support function calling to use StructuredSummaryCondenser'
)
self.max_size = max_size
self.keep_first = keep_first
self.max_event_length = max_event_length
self.llm = llm
if not self.llm.is_function_calling_active():
raise ValueError(
'LLM must support function calling to use StructuredSummaryCondenser'
)
super().__init__()
@ -309,16 +309,17 @@ Capture all relevant information, especially:
@classmethod
def from_config(
cls, config: StructuredSummaryCondenserConfig
cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry
) -> StructuredSummaryCondenser:
# This condenser cannot take advantage of prompt caching. If it happens
# to be set, we'll pay for the cache writes but never get a chance to
# save on a read.
llm_config = config.llm_config.model_copy()
llm_config.caching_prompt = False
llm = llm_registry.get_llm('condenser', llm_config)
return StructuredSummaryCondenser(
llm=LLM(config=llm_config),
llm=llm,
max_size=config.max_size,
keep_first=config.keep_first,
max_event_length=config.max_event_length,

View File

@ -23,7 +23,7 @@ class ServiceContext:
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
self._strategy = strategy
if llm_config is not None:
self.llm = LLM(llm_config)
self.llm = LLM(llm_config, service_id='resolver')
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
self._strategy = strategy

View File

@ -28,6 +28,7 @@ from openhands.events.observation import (
)
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.llm.llm_registry import LLMRegistry
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
@ -412,7 +413,8 @@ class IssueResolver:
shutil.rmtree(self.workspace_base)
shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)
runtime = create_runtime(self.app_config)
llm_registry = LLMRegistry(self.app_config)
runtime = create_runtime(self.app_config, llm_registry)
await runtime.connect()
def on_event(evt: Event) -> None:

View File

@ -463,7 +463,7 @@ def update_existing_pull_request(
# Summarize with LLM if provided
if llm_config is not None:
llm = LLM(llm_config)
llm = LLM(llm_config, service_id='resolver')
with open(
os.path.join(
os.path.dirname(__file__),

View File

@ -54,6 +54,7 @@ from openhands.integrations.provider import (
ProviderType,
)
from openhands.integrations.service_types import AuthenticationError
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent import (
BaseMicroagent,
load_microagents_from_dir,
@ -125,6 +126,7 @@ class Runtime(FileEditRuntimeMixin):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -178,7 +180,9 @@ class Runtime(FileEditRuntimeMixin):
# Load mixins
FileEditRuntimeMixin.__init__(
self, enable_llm_editor=config.get_agent_config().enable_llm_editor
self,
enable_llm_editor=config.get_agent_config().enable_llm_editor,
llm_registry=llm_registry,
)
self.user_id = user_id

View File

@ -43,6 +43,7 @@ from openhands.events.observation import (
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
@ -68,6 +69,7 @@ class ActionExecutionClient(Runtime):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -85,6 +87,7 @@ class ActionExecutionClient(Runtime):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,

View File

@ -46,6 +46,7 @@ from openhands.events.observation import (
Observation,
)
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime_status import RuntimeStatus
@ -107,6 +108,7 @@ class CLIRuntime(Runtime):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -119,6 +121,7 @@ class CLIRuntime(Runtime):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,

View File

@ -20,6 +20,7 @@ from openhands.core.logger import DEBUG, DEBUG_RUNTIME
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
@ -90,6 +91,7 @@ class DockerRuntime(ActionExecutionClient):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -143,6 +145,7 @@ class DockerRuntime(ActionExecutionClient):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,

View File

@ -43,6 +43,7 @@ from openhands.core.logger import DEBUG
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
@ -81,6 +82,7 @@ class KubernetesRuntime(ActionExecutionClient):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -137,6 +139,7 @@ class KubernetesRuntime(ActionExecutionClient):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,

View File

@ -26,6 +26,7 @@ from openhands.events.observation import (
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
@ -135,6 +136,7 @@ class LocalRuntime(ActionExecutionClient):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -186,6 +188,7 @@ class LocalRuntime(ActionExecutionClient):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,
@ -801,12 +804,6 @@ def _create_warm_server_in_background(
def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]:
from openhands.controller.agent import Agent
from openhands.llm.llm import LLM
agent_config = config.get_agent_config(config.default_agent)
llm = LLM(
config=config.get_llm_config_from_agent(config.default_agent),
)
agent = Agent.get_cls(config.default_agent)(llm, agent_config)
plugins = agent.sandbox_plugins
plugins = Agent.get_cls(config.default_agent).sandbox_plugins
return plugins

View File

@ -19,6 +19,7 @@ from openhands.core.exceptions import (
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
@ -51,6 +52,7 @@ class RemoteRuntime(ActionExecutionClient):
self,
config: OpenHandsConfig,
event_stream: EventStream,
llm_registry: LLMRegistry,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
@ -64,6 +66,7 @@ class RemoteRuntime(ActionExecutionClient):
super().__init__(
config,
event_stream,
llm_registry,
sid,
plugins,
env_vars,

View File

@ -23,7 +23,7 @@ from openhands.events.observation import (
)
from openhands.linter import DefaultLinter
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.llm.llm_registry import LLMRegistry
from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
USER_MSG = """
@ -128,7 +128,13 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
# This restricts the number of lines we can edit to avoid exceeding the token limit.
MAX_LINES_TO_EDIT = 300
def __init__(self, enable_llm_editor: bool, *args: Any, **kwargs: Any) -> None:
def __init__(
self,
enable_llm_editor: bool,
llm_registry: LLMRegistry,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.enable_llm_editor = enable_llm_editor
@ -138,7 +144,6 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
draft_editor_config = self.config.get_llm_config('draft_editor')
# manually set the model name for the draft editor LLM to distinguish token costs
llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model)
if draft_editor_config.caching_prompt:
logger.debug(
'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
@ -146,7 +151,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
)
draft_editor_config.caching_prompt = False
self.draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics)
self.draft_editor_llm = llm_registry.get_llm(
'draft_editor_llm', draft_editor_config
)
logger.debug(
f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
)

View File

@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
import socketio
from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.events.action import MessageAction
from openhands.server.config.server_config import ServerConfig
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@ -136,6 +137,16 @@ class ConversationManager(ABC):
) -> list[AgentLoopInfo]:
"""Get the AgentLoopInfo for conversations."""
@abstractmethod
async def request_llm_completion(
self,
sid: str,
service_id: str,
llm_config: LLMConfig,
messages: list[dict[str, str]],
) -> str:
"""Request extraneous llm completions for a conversation"""
@classmethod
@abstractmethod
def get_instance(

View File

@ -15,13 +15,13 @@ from docker.models.containers import Container
from openhands.controller.agent import Agent
from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import MessageAction
from openhands.events.nested_event_store import NestedEventStore
from openhands.events.stream import EventStream
from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
from openhands.server.config.server_config import ServerConfig
@ -42,6 +42,7 @@ from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_dir
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.import_utils import get_impl
from openhands.utils.utils import create_registry_and_convo_stats
@dataclass
@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager):
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
async def request_llm_completion(
self,
sid: str,
service_id: str,
llm_config: LLMConfig,
messages: list[dict[str, str]],
) -> str:
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
async def send_event_to_conversation(self, sid, data):
async with httpx.AsyncClient(
headers={
@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager):
# This session is created here only because it is the easiest way to get a runtime, which
# is the easiest way to create the needed docker container
# Run experiment manager variant test before creating session
config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test(
user_id, sid, self.config
)
llm_registry, convo_stats, config = create_registry_and_convo_stats(
config, sid, user_id, settings
)
session = Session(
sid=sid,
llm_registry=llm_registry,
convo_stats=convo_stats,
file_store=self.file_store,
config=config,
sio=self.sio,
user_id=user_id,
)
llm_registry.retry_listener = session._notify_on_llm_retry
agent_cls = settings.agent or config.default_agent
agent_name = agent_cls if agent_cls is not None else 'agent'
llm = LLM(
config=config.get_llm_config_from_agent(agent_name),
retry_listener=session._notify_on_llm_retry,
)
llm = session._create_llm(agent_cls)
agent_config = config.get_agent_config(agent_cls)
agent = Agent.get_cls(agent_cls)(llm, agent_config)
agent = Agent.get_cls(agent_cls)(agent_config, llm_registry)
config = config.model_copy(deep=True)
env_vars = config.sandbox.runtime_startup_env_vars
@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager):
headless_mode=False,
attach_to_existing=False,
main_module='openhands.server',
llm_registry=llm_registry,
)
# Hack - disable setting initial env.

View File

@ -6,12 +6,14 @@ from typing import Callable, Iterable
import socketio
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.stream import EventStreamSubscriber, session_exists
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime import get_runtime_cls
from openhands.server.config.server_config import ServerConfig
from openhands.server.constants import ROOM_KEY
@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import (
)
from openhands.utils.import_utils import get_impl
from openhands.utils.shutdown_listener import should_continue
from openhands.utils.utils import create_registry_and_convo_stats
from .conversation_manager import ConversationManager
@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager):
)
await self.close_session(oldest_conversation_id)
config = self.config.model_copy(deep=True)
llm_registry, convo_stats, config = create_registry_and_convo_stats(
self.config, sid, user_id, settings
)
session = Session(
sid=sid,
file_store=self.file_store,
config=config,
llm_registry=llm_registry,
convo_stats=convo_stats,
sio=self.sio,
user_id=user_id,
)
@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager):
try:
session.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER,
self._create_conversation_update_callback(user_id, sid, settings),
self._create_conversation_update_callback(
user_id, sid, settings, session.llm_registry
),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager):
raise RuntimeError(f'no_conversation:{sid}')
await session.dispatch(data)
async def request_llm_completion(
self,
sid: str,
service_id: str,
llm_config: LLMConfig,
messages: list[dict[str, str]],
):
session = self._local_agent_loops_by_sid.get(sid)
if not session:
raise RuntimeError(f'no_conversation:{sid}')
llm_registry = session.llm_registry
return llm_registry.request_extraneous_completion(
service_id, llm_config, messages
)
async def disconnect_from_session(self, connection_id: str):
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
logger.info(
@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager):
user_id: str | None,
conversation_id: str,
settings: Settings,
llm_registry: LLMRegistry,
) -> Callable:
def callback(event, *args, **kwargs):
call_async_from_sync(
@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager):
user_id,
conversation_id,
settings,
llm_registry,
event,
)
@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager):
user_id: str,
conversation_id: str,
settings: Settings,
llm_registry: LLMRegistry,
event=None,
):
conversation_store = await self._get_conversation_store(user_id)
@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager):
conversation.title == default_title
): # attempt to autogenerate if default title is in use
title = await auto_generate_title(
conversation_id, user_id, self.file_store, settings
conversation_id, user_id, self.file_store, settings, llm_registry
)
if title and not title.isspace():
conversation.title = title
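Hedged usage sketch (not part of this commit's diff): a caller holding only a conversation id could use the manager's new request_llm_completion hook, which resolves the session's LLMRegistry and delegates to request_extraneous_completion; the sid, service id, model, and message below are illustrative.

from openhands.core.config.llm_config import LLMConfig


async def summarize_last_run(conversation_manager, sid: str) -> str:
    # Illustrative service id and config; the caller never constructs an LLM itself.
    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
    messages = [
        {'role': 'user', 'content': 'Summarize the last agent run in one sentence.'}
    ]
    return await conversation_manager.request_llm_completion(
        sid=sid,
        service_id='adhoc_summary',
        llm_config=llm_config,
        messages=messages,
    )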

View File

@ -33,7 +33,6 @@ from openhands.integrations.service_types import (
ProviderType,
SuggestedTask,
)
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import (
setup_init_conversation_settings,
)
from openhands.server.shared import (
ConversationManagerImpl,
ConversationStoreImpl,
config,
conversation_manager,
@ -364,7 +364,7 @@ async def get_prompt(
)
prompt_template = generate_prompt_template(stringified_events)
prompt = generate_prompt(llm_config, prompt_template)
prompt = generate_prompt(llm_config, prompt_template, conversation_id)
return JSONResponse(
{
@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str:
return template.render(events=events)
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
llm = LLM(llm_config)
def generate_prompt(
llm_config: LLMConfig, prompt_template: str, conversation_id: str
) -> str:
messages = [
{
'role': 'system',
@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
},
]
response = llm.completion(messages=messages)
raw_prompt = response['choices'][0]['message']['content'].strip()
raw_prompt = ConversationManagerImpl.request_llm_completion(
'remember_prompt', conversation_id, llm_config, messages
)
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
if prompt:

View File

@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.conversation_summary import get_default_conversation_title
async def create_new_conversation(
async def initialize_conversation(
user_id: str | None,
conversation_id: str | None,
selected_repository: str | None,
selected_branch: str | None,
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
git_provider: ProviderType | None = None,
) -> ConversationMetadata | None:
if conversation_id is None:
conversation_id = uuid.uuid4().hex
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
if not await conversation_store.exists(conversation_id):
logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)
conversation_title = get_default_conversation_title(conversation_id)
logger.info(f'Saving metadata for conversation {conversation_id}')
convo_metadata = ConversationMetadata(
trigger=conversation_trigger,
conversation_id=conversation_id,
title=conversation_title,
user_id=user_id,
selected_repository=selected_repository,
selected_branch=selected_branch,
git_provider=git_provider,
)
await conversation_store.save_metadata(convo_metadata)
return convo_metadata
try:
convo_metadata = await conversation_store.get_metadata(conversation_id)
return convo_metadata
except Exception:
pass
return None
async def start_conversation(
user_id: str | None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
replay_json: str | None,
conversation_instructions: str | None = None,
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_conversation_id: bool = False,
git_provider: ProviderType | None = None,
conversation_id: str | None = None,
conversation_id: str,
convo_metadata: ConversationMetadata,
conversation_instructions: str | None,
mcp_config: MCPConfig | None = None,
) -> AgentLoopInfo:
logger.info(
@ -52,7 +92,7 @@ async def create_new_conversation(
extra={
'signal': 'create_conversation',
'user_id': user_id,
'trigger': conversation_trigger.value,
'trigger': convo_metadata.trigger,
},
)
logger.info('Loading settings')
@ -79,53 +119,25 @@ async def create_new_conversation(
raise MissingSettingsError('Settings not found')
session_init_args['git_provider_tokens'] = git_provider_tokens
session_init_args['selected_repository'] = selected_repository
session_init_args['selected_repository'] = convo_metadata.selected_repository
session_init_args['custom_secrets'] = custom_secrets
session_init_args['selected_branch'] = selected_branch
session_init_args['git_provider'] = git_provider
session_init_args['selected_branch'] = convo_metadata.selected_branch
session_init_args['git_provider'] = convo_metadata.git_provider
session_init_args['conversation_instructions'] = conversation_instructions
if mcp_config:
session_init_args['mcp_config'] = mcp_config
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('ServerConversation store loaded')
# For nested runtimes, we allow a single conversation id, passed in on container creation
if conversation_id is None:
conversation_id = uuid.uuid4().hex
if not await conversation_store.exists(conversation_id):
logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
user_id, conversation_id, conversation_init_data
)
conversation_title = get_default_conversation_title(conversation_id)
logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
trigger=conversation_trigger,
conversation_id=conversation_id,
title=conversation_title,
user_id=user_id,
selected_repository=selected_repository,
selected_branch=selected_branch,
git_provider=git_provider,
llm_model=conversation_init_data.llm_model,
)
)
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
user_id, conversation_id, conversation_init_data
)
logger.info(
f'Starting agent loop for conversation {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)
initial_message_action = None
if initial_user_msg or image_urls:
initial_message_action = MessageAction(
@ -133,9 +145,6 @@ async def create_new_conversation(
image_urls=image_urls or [],
)
if attach_conversation_id:
logger.warning('Attaching conversation ID is deprecated, skipping process')
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
conversation_id,
conversation_init_data,
@ -147,6 +156,47 @@ async def create_new_conversation(
return agent_loop_info
async def create_new_conversation(
user_id: str | None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
replay_json: str | None,
conversation_instructions: str | None = None,
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
git_provider: ProviderType | None = None,
conversation_id: str | None = None,
mcp_config: MCPConfig | None = None,
) -> AgentLoopInfo:
conversation_metadata = await initialize_conversation(
user_id,
conversation_id,
selected_repository,
selected_branch,
conversation_trigger,
git_provider,
)
if not conversation_metadata:
raise Exception('Failed to initialize conversation')
return await start_conversation(
user_id,
git_provider_tokens,
custom_secrets,
initial_user_msg,
image_urls,
replay_json,
conversation_metadata.conversation_id,
conversation_metadata,
conversation_instructions,
mcp_config,
)
def create_provider_tokens_object(
providers_set: list[ProviderType],
) -> PROVIDER_TOKEN_TYPE:

View File

@ -0,0 +1,77 @@
import base64
import pickle
from threading import Lock
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm_registry import RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_stats_filename
class ConversationStats:
def __init__(
self,
file_store: FileStore | None,
conversation_id: str,
user_id: str | None,
):
self.metrics_path = get_conversation_stats_filename(conversation_id, user_id)
self.file_store = file_store
self.conversation_id = conversation_id
self.user_id = user_id
self._save_lock = Lock()
self.service_to_metrics: dict[str, Metrics] = {}
self.restored_metrics: dict[str, Metrics] = {}
# Always attempt to restore previously saved metrics if they exist
self.maybe_restore_metrics()
def save_metrics(self):
if not self.file_store:
return
with self._save_lock:
pickled = pickle.dumps(self.service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
self.file_store.write(self.metrics_path, serialized_metrics)
def maybe_restore_metrics(self):
if not self.file_store or not self.conversation_id:
return
try:
encoded = self.file_store.read(self.metrics_path)
pickled = base64.b64decode(encoded)
self.restored_metrics = pickle.loads(pickled)
logger.info(f'restored metrics: {self.conversation_id}')
except FileNotFoundError:
pass
def get_combined_metrics(self) -> Metrics:
total_metrics = Metrics()
for metrics in self.service_to_metrics.values():
total_metrics.merge(metrics)
logger.info(f'metrics by all services: {self.service_to_metrics}')
logger.info(f'combined metrics\n\n{total_metrics}')
return total_metrics
def get_metrics_for_service(self, service_id: str) -> Metrics:
if service_id not in self.service_to_metrics:
raise Exception(f'LLM service does not exist {service_id}')
return self.service_to_metrics[service_id]
def register_llm(self, event: RegistryEvent):
# Listen for llm creations and track their metrics
llm = event.llm
service_id = event.service_id
if service_id in self.restored_metrics:
llm.metrics = self.restored_metrics[service_id].copy()
del self.restored_metrics[service_id]
self.service_to_metrics[service_id] = llm.metrics
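Hedged usage sketch (not part of this commit's diff) of how ConversationStats pairs with LLMRegistry: the registry announces every LLM it creates and the stats object adopts that LLM's Metrics, so a per-conversation total stays available. The service id and in-memory file store are illustrative choices.

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

config = OpenHandsConfig()
registry = LLMRegistry(config)
stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='demo-conversation',
    user_id=None,
)
registry.subscribe(stats.register_llm)  # every LLM the registry creates is now tracked
agent_llm = registry.get_llm('agent_llm', config.get_llm_config())
print(stats.get_combined_metrics())  # merged Metrics across all registered services
stats.save_metrics()  # persists to the conversation's convo_stats.json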

View File

@ -21,6 +21,7 @@ from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
from openhands.llm.llm_registry import LLMRegistry
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.security import SecurityAnalyzer, options
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.files import FileStore
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
@ -48,6 +50,7 @@ class AgentSession:
sid: str
user_id: str | None
event_stream: EventStream
llm_registry: LLMRegistry
file_store: FileStore
controller: AgentController | None = None
runtime: Runtime | None = None
@ -63,6 +66,8 @@ class AgentSession:
self,
sid: str,
file_store: FileStore,
llm_registry: LLMRegistry,
convo_stats: ConversationStats,
status_callback: Callable | None = None,
user_id: str | None = None,
) -> None:
@ -80,6 +85,8 @@ class AgentSession:
self.logger = OpenHandsLoggerAdapter(
extra={'session_id': sid, 'user_id': user_id}
)
self.llm_registry = llm_registry
self.convo_stats = convo_stats
async def start(
self,
@ -340,6 +347,7 @@ class AgentSession:
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
llm_registry=self.llm_registry,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_callback=self._status_callback,
@ -360,6 +368,7 @@ class AgentSession:
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
llm_registry=self.llm_registry,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_callback=self._status_callback,
@ -441,6 +450,7 @@ class AgentSession:
user_id=self.user_id,
file_store=self.file_store,
event_stream=self.event_stream,
convo_stats=self.convo_stats,
agent=agent,
iteration_delta=int(max_iterations),
budget_per_task_delta=max_budget_per_task,
@ -490,6 +500,15 @@ class AgentSession:
)
return memory
def get_state(self) -> AgentState | None:
controller = self.controller
if controller:
return controller.state.agent_state
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
return AgentState.ERROR
return None
def _maybe_restore_state(self) -> State | None:
"""Helper method to handle state restore logic."""
restored_state = None
@ -510,14 +529,5 @@ class AgentSession:
self.logger.debug('No events found, no state to restore')
return restored_state
def get_state(self) -> AgentState | None:
controller = self.controller
if controller:
return controller.state.agent_state
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
return AgentState.ERROR
return None
def is_closed(self) -> bool:
return self._closed

View File

@ -2,6 +2,7 @@ import asyncio
from openhands.core.config import OpenHandsConfig
from openhands.events.stream import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
@ -45,6 +46,7 @@ class ServerConversation:
else:
runtime_cls = get_runtime_cls(self.config.runtime)
runtime = runtime_cls(
llm_registry=LLMRegistry(self.config),
config=config,
event_stream=self.event_stream,
sid=self.sid,

View File

@ -1,6 +1,5 @@
import asyncio
import time
from copy import deepcopy
from logging import LoggerAdapter
import socketio
@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.constants import ROOM_KEY
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.storage.data_models.settings import Settings
@ -45,6 +45,7 @@ class Session:
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
config: OpenHandsConfig
llm_registry: LLMRegistry
file_store: FileStore
user_id: str | None
logger: LoggerAdapter
@ -53,6 +54,8 @@ class Session:
self,
sid: str,
config: OpenHandsConfig,
llm_registry: LLMRegistry,
convo_stats: ConversationStats,
file_store: FileStore,
sio: socketio.AsyncServer | None,
user_id: str | None = None,
@ -62,17 +65,21 @@ class Session:
self.last_active_ts = int(time.time())
self.file_store = file_store
self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid})
self.llm_registry = llm_registry
self.convo_stats = convo_stats
self.agent_session = AgentSession(
sid,
file_store,
llm_registry=self.llm_registry,
convo_stats=convo_stats,
status_callback=self.queue_status_message,
user_id=user_id,
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
# Copying this means that when we update variables they are not applied to the shared global configuration!
self.config = deepcopy(config)
self.config = config
# Lazy import to avoid circular dependency
from openhands.experiments.experiment_manager import ExperimentManagerImpl
@ -140,13 +147,6 @@ class Session:
else self.config.max_budget_per_task
)
# This is a shallow copy of the default LLM config, so changes here will
# persist if we retrieve the default LLM config again when constructing
# the agent
default_llm_config = self.config.get_llm_config()
default_llm_config.model = settings.llm_model or ''
default_llm_config.api_key = settings.llm_api_key
default_llm_config.base_url = settings.llm_base_url
self.config.search_api_key = settings.search_api_key
if settings.sandbox_api_key:
self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
@ -181,10 +181,9 @@ class Session:
)
# TODO: override other LLM config & agent config groups (#2075)
llm = self._create_llm(agent_cls)
agent_config = self.config.get_agent_config(agent_cls)
agent_name = agent_cls if agent_cls is not None else 'agent'
llm_config = self.config.get_llm_config_from_agent(agent_name)
if settings.enable_default_condenser:
# Default condenser chains three condensers together:
# 1. a conversation window condenser that handles explicit
@ -200,7 +199,7 @@ class Session:
ConversationWindowCondenserConfig(),
BrowserOutputCondenserConfig(attention_window=2),
LLMSummarizingCondenserConfig(
llm_config=llm.config, keep_first=4, max_size=120
llm_config=llm_config, keep_first=4, max_size=120
),
]
)
@ -208,12 +207,14 @@ class Session:
self.logger.info(
f'Enabling pipeline condenser with:'
f' browser_output_masking(attention_window=2), '
f' llm(model="{llm.config.model}", '
f' base_url="{llm.config.base_url}", '
f' llm(model="{llm_config.model}", '
f' base_url="{llm_config.base_url}", '
f' keep_first=4, max_size=80)'
)
agent_config.condenser = default_condenser_config
agent = Agent.get_cls(agent_cls)(llm, agent_config)
agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry)
self.llm_registry.retry_listner = self._notify_on_llm_retry
git_provider_tokens = None
selected_repository = None
@ -269,14 +270,6 @@ class Session:
)
return
def _create_llm(self, agent_cls: str | None) -> LLM:
"""Initialize LLM, extracted for testing."""
agent_name = agent_cls if agent_cls is not None else 'agent'
return LLM(
config=self.config.get_llm_config_from_agent(agent_name),
retry_listener=self._notify_on_llm_retry,
)
def _notify_on_llm_retry(self, retries: int, max: int) -> None:
self.queue_status_message(
'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}'
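Hedged sketch (not part of this commit's diff) of the pattern that replaces _create_llm above: components ask the shared registry for an LLM under a service id instead of constructing one, so retry notification and metrics tracking stay centralized; the service id and agent name are illustrative.

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

config = OpenHandsConfig()
registry = LLMRegistry(config)
# Status-callback hook, mirroring how Session assigns _notify_on_llm_retry above.
registry.retry_listner = lambda retries, max_retries: None
agent_llm = registry.get_llm('agent', config.get_llm_config_from_agent('CodeActAgent'))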

View File

@ -30,5 +30,13 @@ def get_conversation_agent_state_filename(sid: str, user_id: str | None = None)
return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'
def get_conversation_llm_registry_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}llm_registry.json'
def get_conversation_stats_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}convo_stats.json'
def get_experiment_config_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}exp_config.json'

View File

@ -7,13 +7,16 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.event_store import EventStore
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
async def generate_conversation_title(
message: str, llm_config: LLMConfig, max_length: int = 50
message: str,
llm_config: LLMConfig,
llm_registry: LLMRegistry,
max_length: int = 50,
) -> Optional[str]:
"""Generate a concise title for a conversation based on the first user message.
@ -35,8 +38,6 @@ async def generate_conversation_title(
truncated_message = message
try:
llm = LLM(llm_config)
# Create a simple prompt for the LLM to generate a title
messages = [
{
@ -49,8 +50,9 @@ async def generate_conversation_title(
},
]
response = llm.completion(messages=messages)
title = response.choices[0].message.content.strip()
title = llm_registry.request_extraneous_completion(
'convo_title_creator', llm_config, messages
)
# Ensure the title isn't too long
if len(title) > max_length:
@ -75,7 +77,11 @@ def get_default_conversation_title(conversation_id: str) -> str:
async def auto_generate_title(
conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
conversation_id: str,
user_id: str | None,
file_store: FileStore,
settings: Settings,
llm_registry: LLMRegistry,
) -> str:
"""Auto-generate a title for a conversation based on the first user message.
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
@ -116,7 +122,7 @@ async def auto_generate_title(
# Try to generate title using LLM
llm_title = await generate_conversation_title(
first_user_message, llm_config
first_user_message, llm_config, llm_registry
)
if llm_title:
logger.info(f'Generated title using LLM: {llm_title}')
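Hedged usage sketch (not part of this commit's diff): the title helper now takes a registry and issues its completion under the 'convo_title_creator' service id; the model and message are illustrative, and a real run needs valid credentials.

from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.utils.conversation_summary import generate_conversation_title


async def demo_title() -> str | None:
    registry = LLMRegistry(OpenHandsConfig())
    return await generate_conversation_title(
        'Please fix the failing unit tests in my repository',
        LLMConfig(model='gpt-4o', api_key='test_key'),
        registry,
        max_length=50,
    )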

openhands/utils/utils.py (new file, 37 lines added)
View File

@ -0,0 +1,37 @@
from copy import deepcopy
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings
def setup_llm_config(config: OpenHandsConfig, settings: Settings) -> OpenHandsConfig:
# Copying this means that when we update variables they are not applied to the shared global configuration!
config = deepcopy(config)
llm_config = config.get_llm_config()
llm_config.model = settings.llm_model or ''
llm_config.api_key = settings.llm_api_key
llm_config.base_url = settings.llm_base_url
config.set_llm_config(llm_config)
return config
def create_registry_and_convo_stats(
config: OpenHandsConfig,
sid: str,
user_id: str | None,
user_settings: Settings | None = None,
) -> tuple[LLMRegistry, ConversationStats, OpenHandsConfig]:
user_config = config
if user_settings:
user_config = setup_llm_config(config, user_settings)
agent_cls = user_settings.agent if user_settings else None
llm_registry = LLMRegistry(user_config, agent_cls)
file_store = get_file_store(user_config.file_store, user_config.file_store_path)
convo_stats = ConversationStats(file_store, sid, user_id)
llm_registry.subscribe(convo_stats.register_llm)
return llm_registry, convo_stats, user_config
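Hedged usage sketch (not part of this commit's diff) of the helper above, mirroring how StandaloneConversationManager wires a new session: settings-derived LLM values are folded into a copied config, and the returned registry is already subscribed to the conversation's stats. The Settings values are illustrative.

from openhands.core.config import OpenHandsConfig
from openhands.storage.data_models.settings import Settings
from openhands.utils.utils import create_registry_and_convo_stats

settings = Settings(llm_model='gpt-4o', llm_api_key='test_key')
llm_registry, convo_stats, user_config = create_registry_and_convo_stats(
    OpenHandsConfig(), sid='demo-session', user_id=None, user_settings=settings
)
# Any llm_registry.get_llm(...) call from here on is reflected in convo_stats, and
# user_config carries the per-user LLM overrides without touching the global config.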

View File

@ -1,3 +1,4 @@
[pytest]
addopts = -p no:warnings
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function

tests/__init__.py (new empty file)
View File

View File

@ -10,6 +10,7 @@ from pytest import TempPathFactory
from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
@ -268,9 +269,13 @@ def _load_runtime(
)
event_stream = EventStream(sid, file_store)
# Create a LLMRegistry instance for the runtime
llm_registry = LLMRegistry(config=OpenHandsConfig())
runtime = runtime_cls(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid=sid,
plugins=plugins,
)

tests/unit/__init__.py (new empty file)
View File

View File

@ -20,7 +20,7 @@ def test_llm():
def _get_llm(type_: type[LLM]):
with _patch_http():
return type_(config=config.get_llm_config())
return type_(config=config.get_llm_config(), service_id='test_service')
@pytest.fixture

View File

@ -38,7 +38,7 @@ def default_config():
def test_llm_init_with_default_config(default_config):
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
assert llm.config.model == 'gpt-4o'
assert llm.config.api_key.get_secret_value() == 'test_key'
assert isinstance(llm.metrics, Metrics)
@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
'max_input_tokens': 8000,
'max_output_tokens': 2000,
}
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens == 8000
assert llm.config.max_output_tokens == 2000
@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
mock_get_model_info.side_effect = Exception('Model info not available')
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens is None
assert llm.config.max_output_tokens is None
@ -154,7 +154,7 @@ def test_llm_init_with_custom_config():
top_p=0.9,
top_k=None,
)
llm = LLM(custom_config)
llm = LLM(custom_config, service_id='test-service')
assert llm.config.model == 'custom-model'
assert llm.config.api_key.get_secret_value() == 'custom_key'
assert llm.config.max_input_tokens == 5000
@ -168,7 +168,7 @@ def test_llm_init_with_custom_config():
def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
# Create a config with top_k set
config_with_top_k = LLMConfig(top_k=50)
llm = LLM(config_with_top_k)
llm = LLM(config_with_top_k, service_id='test-service')
# Define a side effect function to check top_k
def side_effect(*args, **kwargs):
@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
# Create a config with top_k set to None
config_without_top_k = LLMConfig(top_k=None)
llm = LLM(config_without_top_k)
llm = LLM(config_without_top_k, service_id='test-service')
# Define a side effect function to check top_k
def side_effect(*args, **kwargs):
@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
def test_llm_init_with_metrics():
config = LLMConfig(model='gpt-4o', api_key='test_key')
metrics = Metrics()
llm = LLM(config, metrics=metrics)
llm = LLM(config, metrics=metrics, service_id='test-service')
assert llm.metrics is metrics
assert (
llm.metrics.model_name == 'default'
@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):
# Create LLM instance and make a completion call
config = LLMConfig(model='gpt-4o', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify the response latency was tracked correctly
@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
'max_input_tokens': 7000,
'max_output_tokens': 1500,
}
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens == 7000
assert llm.config.max_output_tokens == 1500
@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
default_config.model = (
'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS
)
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[
@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
# Test with Grok-4 model that doesn't support stop parameter
default_config.model = 'xai/grok-4-0709'
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[
@ -314,7 +314,7 @@ def test_completion_with_mocked_logger(
'choices': [{'message': {'content': 'Test response'}}]
}
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -345,7 +345,7 @@ def test_completion_retries(
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled):
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled):
try:
llm.completion(
@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_
mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
result = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp(
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp(
'LLM did not return a response'
)
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(LLMNoResponseError):
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info):
for model_name, expected_support in test_cases:
config = LLMConfig(model=model_name, api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
assert llm.is_function_calling_active() == expected_support, (
f'Expected function calling support to be {expected_support} for model {model_name}'
@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with non-zero temperature
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry(
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with explicit temperature=0
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
}
mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config)
test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
}
# Initialize LLM and call completion
llm = LLM(config=gemini_config)
llm = LLM(config=gemini_config, service_id='test-service')
llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify that litellm_completion was called with the 'thinking' parameter
@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
mock_token_counter.return_value = 42
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
def test_get_token_count_with_message_objects(
mock_token_counter, default_config, mock_logger
):
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
# Create a Message object and its equivalent dict
message_obj = Message(role='user', content=[TextContent(text='Hello!')])
@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer(
config = copy.deepcopy(default_config)
config.custom_tokenizer = 'custom/tokenizer'
llm = LLM(config)
llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
@ -823,7 +823,7 @@ def test_get_token_count_error_handling(
mock_token_counter, default_config, mock_logger
):
mock_token_counter.side_effect = Exception('Token counting failed')
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
# We'll make mock_litellm_completion return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
# First call
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
# Create LLM instance
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
# First call
llm.completion(messages=[{'role': 'user', 'content': 'First message'}])
@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config
}
mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config)
test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get):
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {'model': 'fake'}
llm = LLM(config=config)
llm = LLM(config=config, service_id='test-service')
llm.init_model_info()
called_url = mock_get.call_args[0][0]
@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits():
"""Test that models without known token limits get None for both max_output_tokens and max_input_tokens."""
# Create LLM instance with a non-existent model to avoid litellm having model info for it
config = LLMConfig(model='non-existent-model', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens and max_input_tokens are initialized to None (default value)
assert llm.config.max_output_tokens is None
@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info():
"""Test that max_output_tokens and max_input_tokens are correctly initialized from model info."""
# Create LLM instance with GPT-4 model which has known token limits
config = LLMConfig(model='gpt-4', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# GPT-4 has specific token limits
# These are the expected values from litellm
@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens():
"""Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens."""
# Create LLM instance with Claude 3.7 Sonnet model
config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet
assert llm.config.max_output_tokens == 64000
@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens():
"""Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
# Create LLM instance with a Claude Sonnet 4 model
config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to the expected value
assert llm.config.max_output_tokens == 64000
@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens():
"""Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value."""
# Create LLM instance with SambaNova DeepSeek model
config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# SambaNova DeepSeek model has specific token limits
# This is the expected value from litellm
@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config():
config = LLMConfig(
model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048
)
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify the config has the overridden max_output_tokens value
assert llm.config.max_output_tokens == 2048
@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens():
)
# Create LLM instance with Azure model
llm = LLM(azure_config)
llm = LLM(azure_config, service_id='test-service')
# Verify the config has the default max_output_tokens value
assert llm.config.max_output_tokens is None # Default value
@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)
@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion):
assert config.reasoning_effort is None
# Create LLM and make completion
llm = LLM(config)
llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
response = llm.completion(messages=messages)

View File

@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
),
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue(
GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
)
@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
)
# Initialize LLM and handler
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR(
GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config
)

View File

@ -463,7 +463,7 @@ async def test_process_issue(
[],
)
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config)
handler_instance.llm = LLM(llm_config, service_id='test-service')
# Mock the runtime and its methods
mock_runtime = MagicMock()

View File

@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
),
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue(
GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
)
@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
)
# Initialize LLM and handler
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR(
GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config
)

View File

@ -500,7 +500,7 @@ async def test_process_issue(
[],
)
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config)
handler_instance.llm = LLM(llm_config, service_id='test-service')
# Create mock runtime and mock run_controller
mock_runtime = MagicMock()

View File

@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import (
from openhands.controller.state.state import State
from openhands.core.config import OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.condenser.condenser import Condensation
from openhands.memory.condenser.impl.conversation_window_condenser import (
@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
@ -61,15 +64,43 @@ def event_loop():
@pytest.fixture
def mock_agent():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = OpenHandsConfig().get_llm_config()
def mock_agent_with_stats():
"""Create a mock agent with properly connected LLM registry and conversation stats."""
import uuid
# Add config with enable_mcp attribute
agent.config = MagicMock(spec=AgentConfig)
agent.config.enable_mcp = True
# Create LLM registry
config = OpenHandsConfig()
llm_registry = LLMRegistry(config=config)
# Create conversation stats
file_store = InMemoryFileStore({})
conversation_id = f'test-conversation-{uuid.uuid4()}'
conversation_stats = ConversationStats(
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
)
# Connect registry to stats (this is the key requirement)
llm_registry.subscribe(conversation_stats.register_llm)
# Create mock agent
agent = MagicMock(spec=Agent)
agent_config = MagicMock(spec=AgentConfig)
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
llm_registry.service_to_llm.clear()
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
agent.llm = mock_llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
# Add a proper system message mock
system_message = SystemMessageAction(
@ -79,7 +110,7 @@ def mock_agent():
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
return agent
return agent, conversation_stats, llm_registry
@pytest.fixture
@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event):
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
async def test_set_agent_state(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
async def test_on_event_change_agent_state_action(
mock_agent_with_stats, mock_event_stream
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
@pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
async def test_react_to_exception(
mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback,
iteration_delta=10,
sid='test',
@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
@pytest.mark.asyncio
async def test_react_to_content_policy_violation(
mock_agent, mock_event_stream, mock_status_callback
mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
):
"""Test that the controller properly handles content policy violations from the LLM."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback,
iteration_delta=10,
sid='test',
@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation(
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error(
test_event_stream, mock_memory, mock_agent
test_event_stream, mock_memory, mock_agent_with_stats
):
config = OpenHandsConfig()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=ActionExecutionClient)
@ -284,15 +336,17 @@ async def test_run_controller_with_fatal_error(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
)
print(f'state: {state}')
events = list(test_event_stream.get_events())
print(f'event_stream: {events}')
@ -312,18 +366,16 @@ async def test_run_controller_with_fatal_error(
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(
test_event_stream, mock_memory, mock_agent
test_event_stream, mock_memory, mock_agent_with_stats
):
config = OpenHandsConfig()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=ActionExecutionClient)
@ -352,15 +404,17 @@ async def test_run_controller_stop_with_stuck(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
)
events = list(test_event_stream.get_events())
print(f'state: {state}')
for i, event in enumerate(events):
@ -391,11 +445,14 @@ async def test_run_controller_stop_with_stuck(
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
async def test_step_max_budget(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Metrics are always synced with budget flag before
metrics = Metrics()
metrics.accumulated_cost = 10.1
@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
limit_increase_amount=10, current_value=10.1, max_value=10
)
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
budget_per_task_delta=10,
sid='test',
@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Metrics are always synced with budget flag before
metrics = Metrics()
metrics.accumulated_cost = 10.1
@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
limit_increase_amount=10, current_value=10.1, max_value=10
)
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
budget_per_task_delta=10,
sid='test',
@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream):
"""Test that when a user continues after hitting the budget limit:
1. Error is thrown when budget cap is exceeded
2. LLM budget does not reset when user continues
3. Budget is extended by adding the initial budget cap to the current accumulated cost
"""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create a real Metrics instance shared between controller state and llm
metrics = Metrics()
metrics.accumulated_cost = 6.0
@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
),
)
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
# Create controller with budget cap
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
budget_per_task_delta=initial_budget,
sid='test',
@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
async def test_reset_with_pending_action_no_observation(
mock_agent_with_stats, mock_event_stream
):
"""Test reset() when there's a pending action with tool call metadata but no observation."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
@pytest.mark.asyncio
async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream):
async def test_reset_with_pending_action_stopped_state(
mock_agent_with_stats, mock_event_stream
):
"""Test reset() when there's a pending action and agent state is STOPPED."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st
@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
mock_agent, mock_event_stream
mock_agent_with_stats, mock_event_stream
):
"""Test reset() when there's a pending action with tool call metadata and an existing observation."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation(
@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream):
"""Test reset() when there's no pending action."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
mock_agent, mock_event_stream, monkeypatch
mock_agent_with_stats, mock_event_stream, monkeypatch
):
"""Test reset() when there's a pending action without tool call metadata."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata(
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics(
test_event_stream, mock_memory, mock_agent
test_event_stream, mock_memory, mock_agent_with_stats
):
config = OpenHandsConfig(
max_iterations=3,
)
event_stream = test_event_stream
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
step_count = 0
@ -833,15 +929,17 @@ async def test_run_controller_max_iterations_has_metrics(
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
)
state.metrics = mock_agent.llm.metrics
assert state.iteration_flag.current_value == 3
@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics(
@pytest.mark.asyncio
async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
async def test_notify_on_llm_retry(
mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback,
iteration_delta=10,
sid='test',
@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
],
)
async def test_context_window_exceeded_error_handling(
context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory
context_window_error,
mock_agent_with_stats,
mock_runtime,
test_event_stream,
mock_memory,
):
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
max_iterations = 5
error_after = 2
@ -973,18 +1084,20 @@ async def test_context_window_exceeded_error_handling(
# state is set to error out before then, if this terminates and we have a
# record of the error being thrown we can be confident that the controller
# handles the truncation correctly.
final_state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
final_state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
),
timeout=10,
)
# Check that the context window exception was thrown and the controller
# called the agent's `step` function the right number of times.
@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling(
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_agent, mock_runtime, mock_memory, test_event_stream
mock_agent_with_stats,
mock_runtime,
mock_memory,
test_event_stream,
):
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
class StepState:
def __init__(self):
@ -1121,18 +1238,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_runtime.config = copy.deepcopy(config)
try:
state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
),
timeout=10,
)
# A timeout error indicates the run_controller entrypoint is not making
# progress
@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_without_truncation(
mock_agent, mock_runtime, mock_memory, test_event_stream
mock_agent_with_stats,
mock_runtime,
mock_memory,
test_event_stream,
):
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
class StepState:
def __init__(self):
@ -1199,18 +1322,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
config = OpenHandsConfig(max_iterations=3)
mock_runtime.config = copy.deepcopy(config)
try:
state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await asyncio.wait_for(
run_controller(
config=config,
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
llm_registry=llm_registry,
),
timeout=10,
)
# A timeout error indicates the run_controller entrypoint is not making
# progress
@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
@pytest.mark.asyncio
async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
async def test_run_controller_with_memory_error(
test_event_stream, mock_agent_with_stats
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
config = OpenHandsConfig()
event_stream = test_event_stream
@ -1273,15 +1402,17 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
with patch.object(
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
):
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
)
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=llm_registry,
)
assert state.iteration_flag.current_value == 0
assert state.agent_state == AgentState.ERROR
@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
@pytest.mark.asyncio
async def test_action_metrics_copy(mock_agent):
async def test_action_metrics_copy(mock_agent_with_stats):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Setup
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent):
initial_state = State(metrics=metrics, budget_flag=None)
# Create agent with metrics
mock_agent.llm = MagicMock(spec=LLM)
# Update agent's LLM metrics
# Add multiple token usages - we should get the last one in the action
usage1 = TokenUsage(
@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent):
mock_agent.llm.metrics = metrics
# Register the metrics with the LLM registry
llm_registry.service_to_llm['agent'] = mock_agent.llm
# Manually notify the conversation stats about the LLM registration
llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))
# Mock agent step to return an action
action = MessageAction(content='Test message')
@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -1411,12 +1549,13 @@ async def test_action_metrics_copy(mock_agent):
@pytest.mark.asyncio
async def test_condenser_metrics_included(mock_agent, test_event_stream):
async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream):
"""Test that metrics from the condenser's LLM are included in the action metrics."""
# Set up agent metrics
agent_metrics = Metrics(model_name='agent-model')
agent_metrics.accumulated_cost = 0.05
agent_metrics._accumulated_token_usage = TokenUsage(
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Set up agent metrics in place
mock_agent.llm.metrics.accumulated_cost = 0.05
mock_agent.llm.metrics._accumulated_token_usage = TokenUsage(
model='agent-model',
prompt_tokens=100,
completion_tokens=50,
@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
cache_write_tokens=10,
response_id='agent-accumulated',
)
# mock_agent.llm.metrics = agent_metrics
mock_agent.name = 'TestAgent'
# Create condenser with its own metrics
@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
)
condenser.llm.metrics = condenser_metrics
# Register the condenser metrics with the LLM registry
llm_registry.service_to_llm['condenser'] = condenser.llm
# Manually notify the conversation stats about the condenser LLM registration
llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser'))
# Attach the condenser to the mock_agent
mock_agent.condenser = condenser
@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
initial_state=State(metrics=agent_metrics, budget_flag=None),
initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None),
)
# Execute one step
@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
@pytest.mark.asyncio
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
async def test_first_user_message_with_identical_content(
test_event_stream, mock_agent_with_stats
):
"""Test that _first_user_message correctly identifies the first user message.
This test verifies that messages with identical content but different IDs are properly
@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
The issue we're checking is that the comparison (action == self._first_user_message())
should correctly differentiate between messages with the same content but different IDs.
"""
# Create an agent controller
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController(
agent=mock_agent,
event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test',
confirmation_mode=False,
@ -1569,11 +1713,15 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
@pytest.mark.asyncio
async def test_agent_controller_processes_null_observation_with_cause():
async def test_agent_controller_processes_null_observation_with_cause(
mock_agent_with_stats,
):
"""Test that AgentController processes NullObservation events with a cause value.
And that the agent's step method is called as a result.
"""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create an in-memory file store and real event stream
file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store)
@ -1581,19 +1729,11 @@ async def test_agent_controller_processes_null_observation_with_cause():
# Create a Memory instance - not used directly in this test but needed for setup
Memory(event_stream=event_stream, sid='test-session')
# Create a mock agent with necessary attributes
mock_agent = MagicMock(spec=Agent)
mock_agent.get_system_message = MagicMock(
return_value=None,
)
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
# Create a controller with the mock agent
controller = AgentController(
agent=mock_agent,
event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test-session',
)
@ -1655,8 +1795,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
)
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
def test_agent_controller_should_step_with_null_observation_cause_zero(
mock_agent_with_stats,
):
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create a mock event stream
file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store)
@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
controller = AgentController(
agent=mock_agent,
event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
sid='test-session',
)
@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
)
def test_system_message_in_event_stream(mock_agent, test_event_stream):
def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream):
"""Test that SystemMessageAction is added to event stream in AgentController."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
_ = AgentController(
agent=mock_agent, event_stream=test_event_stream, iteration_delta=10
agent=mock_agent,
event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
)
# Get events from the event stream
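Note: the controller tests above all unpack a three-part fixture (mock agent, conversation stats, LLM registry) and let the registry feed metrics into the stats service. A minimal sketch of that wiring, assembled only from calls visible in these hunks — the service name and cost values are illustrative, and the real fixture lives elsewhere in this commit:

from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

config = OpenHandsConfig()
llm_registry = LLMRegistry(config=config)
conversation_stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='sketch-conversation',
    user_id='test-user',
)

# Every LLM the registry hands out is announced to subscribers, so the stats
# service can aggregate per-service metrics without the controller owning them.
llm_registry.subscribe(conversation_stats.register_llm)

agent_llm = llm_registry.get_llm('agent', LLMConfig(model='gpt-4o', api_key='test_key'))
agent_llm.metrics.add_cost(0.05)

# Combined metrics roll up across every registered service.
assert conversation_stats.get_combined_metrics().accumulated_cost == 0.05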

View File

@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import (
IterationControlFlag,
)
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config import OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream
from openhands.events.action import (
@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import RecallObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)
@pytest.fixture
def conversation_stats():
import uuid
file_store = InMemoryFileStore({})
# Use a unique conversation ID for each test to avoid conflicts
conversation_id = f'test-conversation-{uuid.uuid4()}'
return ConversationStats(
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
)
@pytest.fixture
def connected_registry_and_stats(llm_registry, conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly"""
# Subscribe to LLM registry events to track metrics
llm_registry.subscribe(conversation_stats.register_llm)
return llm_registry, conversation_stats
@pytest.fixture
def mock_event_stream():
"""Creates an event stream in memory."""
@ -42,15 +71,17 @@ def mock_event_stream():
@pytest.fixture
def mock_parent_agent():
def mock_parent_agent(llm_registry):
"""Creates a mock parent agent for testing delegation."""
agent = MagicMock(spec=Agent)
agent.name = 'ParentAgent'
agent.llm = MagicMock(spec=LLM)
agent.llm.service_id = 'main_agent'
agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig()
agent.llm.retry_listener = None # Add retry_listener attribute
agent.config = AgentConfig()
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
# Add a proper system message mock
system_message = SystemMessageAction(content='Test system message')
@ -61,15 +92,17 @@ def mock_parent_agent():
@pytest.fixture
def mock_child_agent():
def mock_child_agent(llm_registry):
"""Creates a mock child agent for testing delegation."""
agent = MagicMock(spec=Agent)
agent.name = 'ChildAgent'
agent.llm = MagicMock(spec=LLM)
agent.llm.service_id = 'main_agent'
agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig()
agent.llm.retry_listener = None # Add retry_listener attribute
agent.config = AgentConfig()
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
@ -78,15 +111,37 @@ def mock_child_agent():
return agent
def create_mock_agent_factory(mock_child_agent, llm_registry):
"""Helper function to create a mock agent factory with proper LLM registration."""
def create_mock_agent(config, llm_registry=None):
# Register the mock agent's LLM in the registry so get_combined_metrics() can find it
if llm_registry:
mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig())
mock_child_agent.llm_registry = (
llm_registry # Set the llm_registry attribute
)
return mock_child_agent
return create_mock_agent
@pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
"""Test that when the parent agent delegates to a child
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
2. metrics are accumulated globally (delegate is adding to the parents metrics)
3. local metrics for the delegate are still accessible
async def test_delegation_flow(
mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats
):
"""
Test that when the parent agent delegates to a child
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics)
3. global metrics tracking works correctly through the LLM registry
"""
llm_registry, conversation_stats = connected_registry_and_stats
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
Agent.get_cls = Mock(
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
)
step_count = 0
@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
mock_child_agent.step = agent_step_fn
# Set up the parent agent's LLM with initial cost and register it in the registry
# The parent agent's LLM should use the existing registered LLM to ensure proper tracking
parent_llm = llm_registry.service_to_llm['agent']
parent_llm.metrics.accumulated_cost = 2
mock_parent_agent.llm = parent_llm
parent_metrics = Metrics()
parent_metrics.accumulated_cost = 2
# Create parent controller
@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
parent_controller = AgentController(
agent=mock_parent_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter
sid='parent',
confirmation_mode=False,
@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
for i in range(4):
delegate_controller.state.iteration_flag.step()
delegate_controller.agent.step(delegate_controller.state)
# Update the agent's LLM metrics (not the deprecated state metrics)
delegate_controller.agent.llm.metrics.add_cost(1.0)
assert (
delegate_controller.state.get_local_step() == 4
) # verify local metrics are accessible via snapshot
# Check that the conversation stats has the combined metrics (parent + delegate)
combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics()
assert (
delegate_controller.state.metrics.accumulated_cost
== 6 # Make sure delegate tracks global cost
combined_metrics.accumulated_cost
== 6 # Make sure delegate tracks global cost (2 from parent + 4 from delegate)
)
assert (
delegate_controller.state.get_local_metrics().accumulated_cost
== 4 # Delegate spent one dollar per step
)
# Since metrics are now global via LLM registry, local metrics tracking
# is handled differently. The delegate's LLM shares the same metrics object
# as the parent for global tracking, so we verify the global total is correct.
delegate_controller.state.outputs = {'delegate_result': 'done'}
@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
],
)
async def test_delegate_step_different_states(
mock_parent_agent, mock_event_stream, delegate_state
mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats
):
"""Ensure that delegate is closed or remains open based on the delegate's state."""
llm_registry, conversation_stats = connected_registry_and_stats
# Create a state with iteration_flag.max_value set to 10
state = State(inputs={})
state.iteration_flag.max_value = 10
controller = AgentController(
agent=mock_parent_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter
sid='test',
confirmation_mode=False,
@ -292,11 +359,23 @@ async def test_delegate_step_different_states(
@pytest.mark.asyncio
async def test_delegate_hits_global_limits(
mock_child_agent, mock_event_stream, mock_parent_agent
mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats
):
"""Global limits from control flags should apply to delegates"""
"""
Global limits from control flags should apply to delegates
"""
llm_registry, conversation_stats = connected_registry_and_stats
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
Agent.get_cls = Mock(
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
)
# Set up the parent agent's LLM with initial cost and register it in the registry
mock_parent_agent.llm.metrics.accumulated_cost = 2
mock_parent_agent.llm.service_id = 'main_agent'
# Register the parent agent's LLM in the registry
llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm
parent_metrics = Metrics()
parent_metrics.accumulated_cost = 2
@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits(
parent_controller = AgentController(
agent=mock_parent_agent,
event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter
sid='parent',
confirmation_mode=False,
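The delegation tests now verify cost accounting through the registry rather than through copied per-controller metrics. A condensed sketch of the arithmetic they assert (2 from the parent plus four one-dollar delegate steps), with illustrative service names and a fresh registry/stats pair wired as in the sketch above:

from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

llm_registry = LLMRegistry(config=OpenHandsConfig())
conversation_stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='delegation-sketch',
    user_id='test-user',
)
llm_registry.subscribe(conversation_stats.register_llm)

parent_llm = llm_registry.get_llm('agent', LLMConfig())
parent_llm.metrics.accumulated_cost = 2.0

delegate_llm = llm_registry.get_llm('agent_llm', LLMConfig())
for _ in range(4):  # the test charges one dollar per delegate step
    delegate_llm.metrics.add_cost(1.0)

# Neither controller copies metrics; the stats service combines them.
assert conversation_stats.get_combined_metrics().accumulated_cost == 6.0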

View File

@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_agent():
"""Create a properly configured mock agent with all required nested attributes"""
# Create the base mocks
agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)
agent_config = MagicMock(spec=AgentConfig)
def mock_llm_registry():
"""Create a mock LLM registry that properly simulates LLM registration"""
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.max_message_chars = 1000
# Configure the agent config
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
@pytest.fixture
def mock_conversation_stats():
"""Create a mock ConversationStats that properly simulates metrics tracking"""
file_store = InMemoryFileStore({})
stats = ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)
return stats
# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
return agent
@pytest.fixture
def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly"""
# Subscribe to LLM registry events to track metrics
mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
return mock_llm_registry, mock_conversation_stats
@pytest.fixture
def make_mock_agent():
def _make_mock_agent(llm_registry):
agent = MagicMock(spec=Agent)
agent_config = MagicMock(spec=AgentConfig)
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
llm_registry.service_to_llm.clear()
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
agent.llm = mock_llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
return agent
return _make_mock_agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
async def test_agent_session_start_with_no_state(
make_mock_agent, mock_llm_registry, mock_conversation_stats
):
"""Test that AgentSession.start() works correctly when there's no state to restore"""
mock_agent = make_mock_agent(mock_llm_registry)
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
)
# Create a mock runtime and set it up
@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
async def test_agent_session_start_with_restored_state(
make_mock_agent, mock_llm_registry, mock_conversation_stats
):
"""Test that AgentSession.start() works correctly when there's a state to restore"""
mock_agent = make_mock_agent(mock_llm_registry)
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
)
# Create a mock runtime and set it up
@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
@pytest.mark.asyncio
async def test_metrics_centralization_and_sharing(mock_agent):
"""Test that metrics are centralized and shared between controller and agent."""
async def test_metrics_centralization_via_conversation_stats(
make_mock_agent, connected_registry_and_stats
):
"""Test that metrics are centralized through the ConversationStats service."""
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
mock_agent = make_mock_agent(mock_llm_registry)
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
)
# Create a mock runtime and set it up
@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir'
# The registry already has a real metrics object set up in the fixture
# Patch necessary components
with (
patch(
@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
max_iterations=10,
)
# Verify that the agent's LLM metrics and controller's state metrics are the same object
assert session.controller.agent.llm.metrics is session.controller.state.metrics
# Verify that the ConversationStats is properly set up
assert session.controller.state.convo_stats is mock_conversation_stats
# Add some metrics to the agent's LLM
# Add some metrics to the agent's LLM (simulating LLM usage)
test_cost = 0.05
session.controller.agent.llm.metrics.add_cost(test_cost)
# Verify that the cost is reflected in the controller's state metrics
assert session.controller.state.metrics.accumulated_cost == test_cost
# Verify that the cost is reflected in the combined metrics from the conversation stats
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
assert combined_metrics.accumulated_cost == test_cost
# Create a test metrics object to simulate an observation with metrics
test_observation_metrics = Metrics()
test_observation_metrics.add_cost(0.1)
# Add more cost to simulate additional LLM usage
additional_cost = 0.1
session.controller.agent.llm.metrics.add_cost(additional_cost)
# Get the current accumulated cost before merging
current_cost = session.controller.state.metrics.accumulated_cost
# Verify the combined metrics reflect the total cost
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
assert combined_metrics.accumulated_cost == test_cost + additional_cost
# Simulate merging metrics from an observation
session.controller.state_tracker.merge_metrics(test_observation_metrics)
# Verify that the merged metrics are reflected in both agent and controller
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
assert (
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
)
# Reset the agent and verify that metrics are not reset
# Reset the agent and verify that combined metrics are preserved
session.controller.agent.reset()
# Metrics should still be the same after reset
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
assert session.controller.agent.llm.metrics is session.controller.state.metrics
# Combined metrics should still be preserved after agent reset
assert (
session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
== test_cost + additional_cost
)
@pytest.mark.asyncio
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
async def test_budget_control_flag_syncs_with_metrics(
make_mock_agent, connected_registry_and_stats
):
"""Test that BudgetControlFlag's current value matches the accumulated costs."""
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
mock_agent = make_mock_agent(mock_llm_registry)
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
)
# Create a mock runtime and set it up
@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir'
# The registry already has a real metrics object set up in the fixture
# Patch necessary components
with (
patch(
@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
assert session.controller.state.budget_flag.max_value == 1.0
assert session.controller.state.budget_flag.current_value == 0.0
# Add some metrics to the agent's LLM
# Add some metrics to the agent's LLM (simulating LLM usage)
test_cost = 0.05
session.controller.agent.llm.metrics.add_cost(test_cost)
@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
session.controller.state_tracker.sync_budget_flag_with_metrics()
assert session.controller.state.budget_flag.current_value == test_cost
# Create a test metrics object to simulate an observation with metrics
test_observation_metrics = Metrics()
test_observation_metrics.add_cost(0.1)
# Add more cost to simulate additional LLM usage
additional_cost = 0.1
session.controller.agent.llm.metrics.add_cost(additional_cost)
# Simulate merging metrics from an observation
session.controller.state_tracker.merge_metrics(test_observation_metrics)
# Sync again and verify the budget flag is updated
session.controller.state_tracker.sync_budget_flag_with_metrics()
assert (
session.controller.state.budget_flag.current_value
== test_cost + additional_cost
)
# Verify that the budget control flag's current value is updated to match the new accumulated cost
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
# Reset the agent and verify that metrics and budget flag are not reset
# Reset the agent and verify that budget flag still reflects the accumulated cost
session.controller.agent.reset()
# Budget control flag should still reflect the accumulated cost after reset
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
session.controller.state_tracker.sync_budget_flag_with_metrics()
assert (
session.controller.state.budget_flag.current_value
== test_cost + additional_cost
)
def test_override_provider_tokens_with_custom_secret():
def test_override_provider_tokens_with_custom_secret(
mock_llm_registry, mock_conversation_stats
):
"""Test that override_provider_tokens_with_custom_secret works correctly.
This test verifies that the method properly removes provider tokens when
@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
)
# Create test data
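AgentSession now receives the registry and the conversation stats up front instead of deriving metrics from the agent's LLM; the budget test above then syncs the budget flag from the combined cost. A minimal construction sketch mirroring the fixtures in this file (identifiers illustrative):

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore

file_store = InMemoryFileStore({})
llm_registry = LLMRegistry(config=OpenHandsConfig())
convo_stats = ConversationStats(
    file_store=file_store,
    conversation_id='sketch-conversation',
    user_id='test-user',
)
# Keep the stats service subscribed so every LLM the session's agent acquires
# is tracked, exactly as connected_registry_and_stats does above.
llm_registry.subscribe(convo_stats.register_llm)

session = AgentSession(
    sid='sketch-session',
    file_store=file_store,
    llm_registry=llm_registry,
    convo_stats=convo_stats,
)

After session.start(...), the budget flag is refreshed from the combined metrics via session.controller.state_tracker.sync_budget_flag_with_metrics(), which is the behaviour the budget test asserts.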

View File

@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import (
)
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.exceptions import FunctionCallNotExistsError
from openhands.core.message import ImageContent, Message, TextContent
from openhands.events.action import (
@ -42,10 +43,20 @@ from openhands.events.observation.commands import (
CmdOutputObservation,
)
from openhands.events.tool import ToolCallMetadata
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import View
@pytest.fixture
def create_llm_registry():
def _get_registry(llm_config):
config = OpenHandsConfig()
config.set_llm_config(llm_config)
return LLMRegistry(config=config)
return _get_registry
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
def agent_class(request):
if request.param == 'CodeActAgent':
@ -57,18 +68,22 @@ def agent_class(request):
@pytest.fixture
def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]:
def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]:
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig()
agent = agent_class(llm=LLM(LLMConfig()), config=config)
agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config))
agent.llm = Mock()
agent.llm.config = Mock()
agent.llm.config.max_message_chars = 1000
return agent
def test_agent_with_default_config_has_default_tools():
def test_agent_with_default_config_has_default_tools(create_llm_registry):
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig()
codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config)
codeact_agent = CodeActAgent(
config=config, llm_registry=create_llm_registry(llm_config)
)
assert len(codeact_agent.tools) > 0
default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools]
assert {
@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool():
readonly_response_to_actions(mock_response)
def test_step_with_no_pending_actions(mock_state: State):
def test_step_with_no_pending_actions(mock_state: State, create_llm_registry):
# Mock the LLM response
mock_response = Mock()
mock_response.id = 'mock_id'
@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State):
llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting
# Create agent with mocked LLM
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig()
config.enable_prompt_extensions = False
agent = CodeActAgent(llm=llm, config=config)
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
# Replace the LLM with our mock after creation
agent.llm = llm
# Test step with no pending actions
mock_state.latest_user_message = None
@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State):
@pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent'])
def test_correct_tool_description_loaded_based_on_model_name(
agent_type, mock_state: State
agent_type, create_llm_registry
):
"""Tests that the simplified tool descriptions are loaded for specific models."""
o3_mock_config = Mock()
o3_mock_config.model = 'mock_o3_model'
llm = Mock()
llm.config = o3_mock_config
o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key')
if agent_type == 'CodeActAgent':
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name(
agent_class = ReadOnlyAgent
agent = agent_class(llm=llm, config=AgentConfig())
agent = agent_class(
config=AgentConfig(),
llm_registry=create_llm_registry(o3_mock_config),
)
for tool in agent.tools:
# Assert all descriptions have less than 1024 characters
assert len(tool['function']['description']) < 1024
sonnet_mock_config = Mock()
sonnet_mock_config.model = 'mock_sonnet_model'
llm.config = sonnet_mock_config
agent = agent_class(llm=llm, config=AgentConfig())
sonnet_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
agent = agent_class(
config=AgentConfig(),
llm_registry=create_llm_registry(sonnet_mock_config),
)
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
if agent_type == 'CodeActAgent':
# This only holds for CodeActAgent
@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
assert isinstance(enhanced_messages[5].content[0], ImageContent)
def test_get_system_message():
def test_get_system_message(create_llm_registry):
"""Test that the Agent.get_system_message method returns a SystemMessageAction."""
# Create a mock agent
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
config = AgentConfig()
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
result = agent.get_system_message()
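Agents no longer take an LLM instance; they take the registry and resolve their LLM from it, which is why every construction in this file switched to the create_llm_registry fixture. A minimal sketch of the new construction path (model and key are placeholders):

from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = OpenHandsConfig()
config.set_llm_config(llm_config)

agent = CodeActAgent(config=AgentConfig(), llm_registry=LLMRegistry(config=config))
system_message = agent.get_system_message()
assert len(agent.tools) > 0  # the default toolset is still populated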

View File

@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error(
]
# Create an LLM instance and call completion
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error(
]
# Create an LLM instance and call completion
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
# The completion should raise an APIConnectionError after exhausting all retries
with pytest.raises(APIConnectionError) as excinfo:

View File

@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction
from openhands.events.event import EventSource
from openhands.events.event_store import EventStore
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.conversation_manager.standalone_conversation_manager import (
StandaloneConversationManager,
)
@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm():
"""Test auto-generating a title using LLM."""
# Mock dependencies
file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with a user message
conversation_id = 'test-conversation'
@ -46,43 +47,33 @@ async def test_auto_generate_title_with_llm():
mock_event_store.search_events.return_value = [user_message]
mock_event_store_cls.return_value = mock_event_store
# Mock the LLM response
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
mock_llm = mock_llm_cls.return_value
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Python Data Analysis Script'
mock_llm.completion.return_value = mock_response
# Mock the LLM registry response
llm_registry.request_extraneous_completion.return_value = (
'Python Data Analysis Script'
)
# Create test settings with LLM config
settings = Settings(
llm_model='test-model',
llm_api_key='test-key',
llm_base_url='test-url',
)
# Create test settings with LLM config
settings = Settings(
llm_model='test-model',
llm_api_key='test-key',
llm_base_url='test-url',
)
# Call the auto_generate_title function directly
title = await auto_generate_title(
conversation_id, user_id, file_store, settings
)
# Call the auto_generate_title function directly
title = await auto_generate_title(
conversation_id, user_id, file_store, settings, llm_registry
)
# Verify the result
assert title == 'Python Data Analysis Script'
# Verify the result
assert title == 'Python Data Analysis Script'
# Verify EventStore was created with the correct parameters
mock_event_store_cls.assert_called_once_with(
conversation_id, file_store, user_id
)
# Verify EventStore was created with the correct parameters
mock_event_store_cls.assert_called_once_with(
conversation_id, file_store, user_id
)
# Verify LLM was called with appropriate parameters
mock_llm_cls.assert_called_once_with(
LLMConfig(
model='test-model',
api_key='test-key',
base_url='test-url',
)
)
mock_llm.completion.assert_called_once()
# Verify LLM registry was called with appropriate parameters
llm_registry.request_extraneous_completion.assert_called_once()
@pytest.mark.asyncio
@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback():
"""Test auto-generating a title with fallback to truncation when LLM fails."""
# Mock dependencies
file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with a user message
conversation_id = 'test-conversation'
@ -111,31 +103,29 @@ async def test_auto_generate_title_fallback():
mock_event_store.search_events.return_value = [user_message]
mock_event_store_cls.return_value = mock_event_store
# Mock the LLM to raise an exception
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
mock_llm = mock_llm_cls.return_value
mock_llm.completion.side_effect = Exception('Test error')
# Mock the LLM registry to raise an exception
llm_registry.request_extraneous_completion.side_effect = Exception('Test error')
# Create test settings with LLM config
settings = Settings(
llm_model='test-model',
llm_api_key='test-key',
llm_base_url='test-url',
)
# Create test settings with LLM config
settings = Settings(
llm_model='test-model',
llm_api_key='test-key',
llm_base_url='test-url',
)
# Call the auto_generate_title function directly
title = await auto_generate_title(
conversation_id, user_id, file_store, settings
)
# Call the auto_generate_title function directly
title = await auto_generate_title(
conversation_id, user_id, file_store, settings, llm_registry
)
# Verify the result is a truncated version of the message
assert title == 'This is a very long message th...'
assert len(title) <= 35
# Verify the result is a truncated version of the message
assert title == 'This is a very long message th...'
assert len(title) <= 35
# Verify EventStore was created with the correct parameters
mock_event_store_cls.assert_called_once_with(
conversation_id, file_store, user_id
)
# Verify EventStore was created with the correct parameters
mock_event_store_cls.assert_called_once_with(
conversation_id, file_store, user_id
)
@pytest.mark.asyncio
@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages():
"""Test auto-generating a title when there are no user messages."""
# Mock dependencies
file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with no messages
conversation_id = 'test-conversation'
@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages():
# Call the auto_generate_title function directly
title = await auto_generate_title(
conversation_id, user_id, file_store, settings
conversation_id, user_id, file_store, settings, llm_registry
)
# Verify the result is empty
@ -186,6 +177,7 @@ async def test_update_conversation_with_title():
sio.emit = AsyncMock()
file_store = InMemoryFileStore()
server_config = MagicMock()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation
conversation_id = 'test-conversation'
@ -222,7 +214,9 @@ async def test_update_conversation_with_title():
AsyncMock(return_value='Generated Title'),
):
# Call the method
await manager._update_conversation_for_event(user_id, conversation_id, settings)
await manager._update_conversation_for_event(
user_id, conversation_id, settings, llm_registry
)
# Verify the title was updated
assert mock_metadata.title == 'Generated Title'
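auto_generate_title now takes the LLM registry as an extra argument and performs its one-off completion through request_extraneous_completion, so these tests stub a single registry method instead of patching the LLM class. A small self-contained sketch of that stubbing pattern; the helper call itself is shown only schematically because its import path is not part of these hunks:

from unittest.mock import MagicMock

from openhands.llm.llm_registry import LLMRegistry

llm_registry = MagicMock(spec=LLMRegistry)

# Success path: title generation reduces to one registry call.
llm_registry.request_extraneous_completion.return_value = 'Python Data Analysis Script'

# Failure path: raising from the same stub is what triggers the truncation
# fallback asserted in test_auto_generate_title_fallback above.
# llm_registry.request_extraneous_completion.side_effect = Exception('Test error')

# title = await auto_generate_title(conversation_id, user_id, file_store, settings, llm_registry)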

View File

@ -6,6 +6,7 @@ import pytest_asyncio
from openhands.cli import main as cli
from openhands.controller.state.state import State
from openhands.core.config.llm_config import LLMConfig
from openhands.events import EventSource
from openhands.events.action import MessageAction
@ -124,12 +125,14 @@ def mock_config():
'' # Empty string, not starting with 'tvly-'
)
config.search_api_key = search_api_key_mock
config.get_llm_config_from_agent.return_value = LLMConfig(model='model')
# Mock sandbox with volumes attribute to prevent finalize_config issues
config.sandbox = MagicMock()
config.sandbox.volumes = (
None # This prevents finalize_config from overriding workspace_base
)
config.model_name = 'model'
return config
@ -213,7 +216,11 @@ async def test_run_session_without_initial_action(
# Assertions for initialization flow
mock_display_runtime_init.assert_called_once_with('local')
mock_display_animation.assert_called_once()
mock_create_agent.assert_called_once_with(mock_config)
# Check that mock_config is the first parameter to create_agent
mock_create_agent.assert_called_once()
assert mock_create_agent.call_args[0][0] == mock_config, (
'First parameter to create_agent should be mock_config'
)
mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
mock_create_runtime.assert_called_once()
mock_create_controller.assert_called_once()

View File

@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from litellm.exceptions import AuthenticationError
from pydantic import SecretStr
from openhands.cli import main as cli
from openhands.core.config.llm_config import LLMConfig
from openhands.events import EventSource
from openhands.events.action import MessageAction
@ -45,11 +47,10 @@ def mock_config():
config.workspace_base = '/test/dir'
# Set up LLM config to use OpenHands provider
llm_config = MagicMock()
llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key'))
llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model
llm_config.api_key = MagicMock()
llm_config.api_key.get_secret_value.return_value = 'invalid-api-key'
config.llm = llm_config
config.get_llm_config.return_value = llm_config
config.get_llm_config_from_agent.return_value = llm_config
# Mock search_api_key with get_secret_value method
search_api_key_mock = MagicMock()

View File

@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import (
from openhands.events.action.mcp import MCPAction
from openhands.events.observation import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
@ -23,8 +24,12 @@ class TestCLIRuntimeMCP:
"""Set up test fixtures."""
self.config = OpenHandsConfig()
self.event_stream = MagicMock()
llm_registry = LLMRegistry(config=OpenHandsConfig())
self.runtime = CLIRuntime(
config=self.config, event_stream=self.event_stream, sid='test-session'
config=self.config,
event_stream=self.event_stream,
sid='test-session',
llm_registry=llm_registry,
)
@pytest.mark.asyncio

View File

@ -7,10 +7,18 @@ import pytest
from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream
# Mock LLMRegistry
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.storage import get_file_store
# Create a mock LLMRegistry class
class MockLLMRegistry:
def __init__(self, config):
self.config = config
@pytest.fixture
def temp_dir():
"""Create a temporary directory for testing."""
@ -25,7 +33,8 @@ def cli_runtime(temp_dir):
event_stream = EventStream('test', file_store)
config = OpenHandsConfig()
config.workspace_base = temp_dir
runtime = CLIRuntime(config, event_stream)
llm_registry = MockLLMRegistry(config)
runtime = CLIRuntime(config, event_stream, llm_registry)
runtime._runtime_initialized = True # Skip initialization
return runtime
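Runtimes are constructed with the registry as well; both CLIRuntime fixtures above pass one in, either the real class or a lightweight stand-in. A minimal sketch using a mocked event stream, as the MCP tests do (session id illustrative):

from unittest.mock import MagicMock

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime

config = OpenHandsConfig()
runtime = CLIRuntime(
    config=config,
    event_stream=MagicMock(),
    sid='sketch-session',
    llm_registry=LLMRegistry(config=config),
)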

View File

@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import (
StructuredSummaryCondenserConfig,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.message import Message, TextContent
from openhands.core.schema.action import ActionType
from openhands.events.event import Event, EventSource
@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.observation.observation import Observation
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import Condenser
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
from openhands.memory.condenser.impl import (
@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import (
StructuredSummaryCondenser,
)
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
from openhands.server.services.conversation_stats import ConversationStats
def create_test_event(
@ -56,12 +59,15 @@ def create_test_event(
@pytest.fixture
def mock_llm() -> LLM:
"""Mocks an LLM object with a utility function for setting and resetting response contents in unit tests."""
# Create a real LLMConfig instead of a mock to properly handle SecretStr api_key
real_config = LLMConfig(
model='gpt-4o', api_key='test_key', custom_llm_provider=None
)
# Create a MagicMock for the LLM object
mock_llm = MagicMock(
spec=LLM,
config=MagicMock(
spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None
),
config=real_config,
metrics=MagicMock(),
)
_mock_content = None
@ -95,6 +101,23 @@ def mock_llm() -> LLM:
return mock_llm
@pytest.fixture
def mock_conversation_stats() -> ConversationStats:
"""Creates a mock ConversationStats service."""
mock_stats = MagicMock(spec=ConversationStats)
return mock_stats
@pytest.fixture
def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry:
"""Creates an actual LLMRegistry that returns real LLMs."""
# Create an actual LLMRegistry with a basic OpenHandsConfig
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
class RollingCondenserTestHarness:
"""Test harness for rolling condensers.
@ -165,10 +188,10 @@ class RollingCondenserTestHarness:
return ((index - max_size) // target_size) + 1
def test_noop_condenser_from_config():
def test_noop_condenser_from_config(mock_llm_registry):
"""Test that the NoOpCondenser objects can be made from config."""
config = NoOpCondenserConfig()
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, NoOpCondenser)
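Condenser.from_config now threads the registry through, so condensers that need an LLM can obtain one for their llm_config via the registry. A sketch combining the no-op and LLM-backed cases exercised in this file (model and key are placeholders):

from openhands.core.config.condenser_config import (
    LLMSummarizingCondenserConfig,
    NoOpCondenserConfig,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import Condenser

registry = LLMRegistry(config=OpenHandsConfig())

# No LLM involved: the registry argument is simply threaded through.
noop = Condenser.from_config(NoOpCondenserConfig(), registry)

# LLM-backed condenser: its LLM is resolved for the llm_config through the registry.
summarizer = Condenser.from_config(
    LLMSummarizingCondenserConfig(
        max_size=50,
        keep_first=10,
        llm_config=LLMConfig(model='gpt-4o', api_key='test_key'),
    ),
    registry,
)
assert summarizer.llm.config.model == 'gpt-4o'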
@ -189,11 +212,11 @@ def test_noop_condenser():
assert result == View(events=events)
def test_observation_masking_condenser_from_config():
def test_observation_masking_condenser_from_config(mock_llm_registry):
"""Test that ObservationMaskingCondenser objects can be made from config."""
attention_window = 5
config = ObservationMaskingCondenserConfig(attention_window=attention_window)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, ObservationMaskingCondenser)
assert condenser.attention_window == attention_window
@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window():
assert event == condensed_event
def test_browser_output_condenser_from_config():
def test_browser_output_condenser_from_config(mock_llm_registry):
"""Test that BrowserOutputCondenser objects can be made from config."""
attention_window = 5
config = BrowserOutputCondenserConfig(attention_window=attention_window)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, BrowserOutputCondenser)
assert condenser.attention_window == attention_window
@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window():
assert event == condensed_event
def test_recent_events_condenser_from_config():
def test_recent_events_condenser_from_config(mock_llm_registry):
"""Test that RecentEventsCondenser objects can be made from config."""
max_events = 5
keep_first = True
config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, RecentEventsCondenser)
assert condenser.max_events == max_events
@ -334,14 +357,14 @@ def test_recent_events_condenser():
assert result[2]._message == 'Event 5' # kept from max_events
def test_llm_summarizing_condenser_from_config():
def test_llm_summarizing_condenser_from_config(mock_llm_registry):
"""Test that LLMSummarizingCondenser objects can be made from config."""
config = LLMSummarizingCondenserConfig(
max_size=50,
keep_first=10,
llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True),
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, LLMSummarizingCondenser)
assert condenser.llm.config.model == 'gpt-4o'
@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config():
assert condenser.max_size == 50
assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
def test_llm_summarizing_condenser_invalid_config():
def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry):
"""Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=MagicMock(),
llm=mock_llm,
max_size=4,
keep_first=2,
)
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0)
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1)
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
keep_first=-1,
)
def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
def test_llm_summarizing_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that LLMSummarizingCondenser maintains the correct view size."""
max_size = 10
condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
def test_llm_summarizing_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
max_size = 10
keep_first = 3
condenser = LLMSummarizingCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)
mock_llm.set_mock_response_content('Summary of forgotten events')
@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation)
def test_amortized_forgetting_condenser_from_config():
def test_amortized_forgetting_condenser_from_config(mock_llm_registry):
"""Test that AmortizedForgettingCondenser objects can be made from config."""
max_size = 50
keep_first = 10
config = AmortizedForgettingCondenserConfig(
max_size=max_size, keep_first=keep_first
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, AmortizedForgettingCondenser)
assert condenser.max_size == max_size
@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events():
assert view[:keep_first] == events[: min(keep_first, i + 1)]
def test_llm_attention_condenser_from_config():
def test_llm_attention_condenser_from_config(mock_llm_registry):
"""Test that LLMAttentionCondenser objects can be made from config."""
config = LLMAttentionCondenserConfig(
max_size=50,
@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config():
caching_prompt=True,
),
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, LLMAttentionCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.max_size == 50
assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
# Create a mock LLM that doesn't support function calling
mock_llm = MagicMock()
mock_llm.is_function_calling_active.return_value = False
# Create a new registry that returns our mock LLM that doesn't support function calling
mock_registry = MagicMock(spec=LLMRegistry)
mock_registry.get_llm.return_value = mock_llm
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry)
def test_llm_attention_condenser_invalid_config():
"""Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema."""
config = LLMAttentionCondenserConfig(
max_size=50,
keep_first=10,
llm_config=LLMConfig(
model='claude-2', # Older model that doesn't support response schema
api_key='test_key',
),
)
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config)
def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser gives views of the expected size."""
max_size = 10
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
def test_llm_attention_condenser_handles_events_outside_history(
mock_llm, mock_llm_registry
):
"""Test that the LLMAttentionCondenser handles event IDs that aren't from the event history."""
max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_too_many_events(mock_llm):
def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_too_few_events(mock_llm):
def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too few event IDs."""
max_size = 2
# Developer note: We must specify keep_first=0 because
# keep_first (1) >= max_size//2 (1) is invalid.
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry):
"""Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size)."""
max_size = 12
keep_first = 4
condenser = LLMAttentionCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
assert view[:keep_first] == events[: min(keep_first, i + 1)]
def test_structured_summary_condenser_from_config():
def test_structured_summary_condenser_from_config(mock_llm_registry):
"""Test that StructuredSummaryCondenser objects can be made from config."""
config = StructuredSummaryCondenserConfig(
max_size=50,
@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config():
caching_prompt=True,
),
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, StructuredSummaryCondenser)
assert condenser.llm.config.model == 'gpt-4o'
@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config():
assert condenser.max_size == 50
assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
def test_structured_summary_condenser_invalid_config():
def test_structured_summary_condenser_invalid_config(mock_llm):
"""Test that StructuredSummaryCondenser raises error when keep_first > max_size."""
# Since the condenser only works when function calling is on, we need to
# mock up the check for that.
llm = MagicMock()
llm.is_function_calling_active.return_value = True
mock_llm.is_function_calling_active.return_value = True
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=llm,
llm=mock_llm,
max_size=4,
keep_first=2,
)
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0)
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1)
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
keep_first=-1,
)
# If all other parameters are good but there's no function calling, the
# condenser still counts as improperly configured.
llm.is_function_calling_active.return_value = False
# Create a mock LLM that doesn't support function calling
mock_llm_no_func = MagicMock()
mock_llm_no_func.is_function_calling_active.return_value = False
pytest.raises(
ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2
ValueError,
StructuredSummaryCondenser,
llm=mock_llm_no_func,
max_size=40,
keep_first=2,
)
def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
def test_structured_summary_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that StructuredSummaryCondenser maintains the correct view size."""
max_size = 10
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
def test_structured_summary_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
max_size = 10
keep_first = 3
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)
mock_llm.set_mock_response_content('Summary of forgotten events')
@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation)
def test_condenser_pipeline_from_config():
def test_condenser_pipeline_from_config(mock_llm_registry):
"""Test that CondenserPipeline condensers can be created from configuration objects."""
config = CondenserPipelineConfig(
condensers=[
@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config():
),
]
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, CondenserPipeline)
assert len(condenser.condensers) == 3

View File

@ -0,0 +1,490 @@
import base64
import pickle
from unittest.mock import patch
import pytest
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_file_store():
"""Create a mock file store for testing."""
return InMemoryFileStore({})
@pytest.fixture
def conversation_stats(mock_file_store):
"""Create a ConversationStats instance for testing."""
return ConversationStats(
file_store=mock_file_store,
conversation_id='test-conversation-id',
user_id='test-user-id',
)
@pytest.fixture
def mock_llm_registry():
"""Create a mock LLM registry that properly simulates LLM registration."""
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
@pytest.fixture
def connected_registry_and_stats(mock_llm_registry, conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly."""
# Subscribe to LLM registry events to track metrics
mock_llm_registry.subscribe(conversation_stats.register_llm)
return mock_llm_registry, conversation_stats
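The fixture above captures the wiring the rest of these tests depend on: ConversationStats subscribes to registry events, so every LLM created through get_llm shares its Metrics object with the stats service. A minimal sketch of that flow (not part of the diff; the 'agent' service id is hypothetical, other names come from the imports above):
# Sketch only: mirrors the subscribe/get_llm flow exercised in the tests below.
registry = LLMRegistry(config=OpenHandsConfig())
stats = ConversationStats(
    file_store=InMemoryFileStore({}), conversation_id='cid', user_id=None
)
registry.subscribe(stats.register_llm)  # stats now receives RegistryEvent callbacks
llm = registry.get_llm('agent', LLMConfig(model='gpt-4o', api_key='test_key'))
llm.metrics.add_cost(0.05)  # accrues on the Metrics object shared with stats
assert stats.get_metrics_for_service('agent').accumulated_cost == 0.05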
def test_conversation_stats_initialization(conversation_stats):
"""Test that ConversationStats initializes correctly."""
assert conversation_stats.conversation_id == 'test-conversation-id'
assert conversation_stats.user_id == 'test-user-id'
assert conversation_stats.service_to_metrics == {}
assert isinstance(conversation_stats.restored_metrics, dict)
def test_save_metrics(conversation_stats, mock_file_store):
"""Test that metrics are saved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics
# Save metrics
conversation_stats.save_metrics()
# Verify that metrics were saved to the file store
try:
# Verify the saved content can be decoded and unpickled
encoded = mock_file_store.read(conversation_stats.metrics_path)
pickled = base64.b64decode(encoded)
restored = pickle.loads(pickled)
assert service_id in restored
assert restored[service_id].accumulated_cost == 0.05
except FileNotFoundError:
pytest.fail(f'File not found: {conversation_stats.metrics_path}')
def test_maybe_restore_metrics(mock_file_store):
"""Test that metrics are restored correctly."""
# Create metrics to save
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.1)
service_to_metrics = {service_id: metrics}
# Serialize and save metrics
pickled = pickle.dumps(service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
# Create a new ConversationStats with pre-populated file store
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'
# Get the correct path using the same function as ConversationStats
from openhands.storage.locations import get_conversation_stats_filename
metrics_path = get_conversation_stats_filename(conversation_id, user_id)
# Write to the correct path
mock_file_store.write(metrics_path, serialized_metrics)
# Create ConversationStats which should restore metrics
stats = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Verify metrics were restored
assert service_id in stats.restored_metrics
assert stats.restored_metrics[service_id].accumulated_cost == 0.1
def test_get_combined_metrics(conversation_stats):
"""Test that combined metrics are calculated correctly."""
# Add multiple services with metrics
service1 = 'service1'
metrics1 = Metrics(model_name='gpt-4')
metrics1.add_cost(0.05)
metrics1.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
service2 = 'service2'
metrics2 = Metrics(model_name='gpt-3.5')
metrics2.add_cost(0.02)
metrics2.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)
conversation_stats.service_to_metrics[service1] = metrics1
conversation_stats.service_to_metrics[service2] = metrics2
# Get combined metrics
combined = conversation_stats.get_combined_metrics()
# Verify combined metrics
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000
def test_get_metrics_for_service(conversation_stats):
"""Test that metrics for a specific service are retrieved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics
# Get metrics for the service
retrieved_metrics = conversation_stats.get_metrics_for_service(service_id)
# Verify metrics
assert retrieved_metrics.accumulated_cost == 0.05
assert retrieved_metrics is metrics # Should be the same object
# Test getting metrics for non-existent service
# Use a specific exception message pattern instead of a blind Exception
with pytest.raises(Exception, match='LLM service does not exist'):
conversation_stats.get_metrics_for_service('non-existent-service')
def test_register_llm_with_new_service(conversation_stats):
"""Test registering a new LLM service."""
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id='new-service', config=llm_config)
# Create a registry event
service_id = 'new-service'
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM
conversation_stats.register_llm(event)
# Verify the service was registered
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
def test_register_llm_with_restored_metrics(conversation_stats):
"""Test registering an LLM service with restored metrics."""
# Create restored metrics
service_id = 'restored-service'
restored_metrics = Metrics(model_name='gpt-4')
restored_metrics.add_cost(0.1)
conversation_stats.restored_metrics = {service_id: restored_metrics}
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)
# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM
conversation_stats.register_llm(event)
# Verify the service was registered with restored metrics
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
assert llm.metrics.accumulated_cost == 0.1 # Restored cost
# Verify the specific service was removed from restored_metrics
assert service_id not in conversation_stats.restored_metrics
assert hasattr(
conversation_stats, 'restored_metrics'
) # The dict should still exist
def test_llm_registry_notifications(connected_registry_and_stats):
"""Test that LLM registry notifications update conversation stats."""
mock_llm_registry, conversation_stats = connected_registry_and_stats
# Create a new LLM through the registry
service_id = 'test-service'
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Get LLM from registry (this should trigger the notification)
llm = mock_llm_registry.get_llm(service_id, llm_config)
# Verify the service was registered in conversation stats
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
# Add some metrics to the LLM
llm.metrics.add_cost(0.05)
llm.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
# Verify the metrics are reflected in conversation stats
assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.prompt_tokens
== 100
)
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.completion_tokens
== 50
)
# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.05
assert combined.accumulated_token_usage.prompt_tokens == 100
assert combined.accumulated_token_usage.completion_tokens == 50
def test_multiple_llm_services(connected_registry_and_stats):
"""Test tracking metrics for multiple LLM services."""
mock_llm_registry, conversation_stats = connected_registry_and_stats
# Create multiple LLMs through the registry
service1 = 'service1'
service2 = 'service2'
llm_config1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
llm_config2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Get LLMs from registry (this should trigger notifications)
llm1 = mock_llm_registry.get_llm(service1, llm_config1)
llm2 = mock_llm_registry.get_llm(service2, llm_config2)
# Add different metrics to each LLM
llm1.metrics.add_cost(0.05)
llm1.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
llm2.metrics.add_cost(0.02)
llm2.metrics.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)
# Verify services were registered in conversation stats
assert service1 in conversation_stats.service_to_metrics
assert service2 in conversation_stats.service_to_metrics
# Verify individual metrics
assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05
assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02
# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000
def test_register_llm_with_multiple_restored_services_bug(conversation_stats):
"""Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service."""
# Create restored metrics for multiple services
service_id_1 = 'service-1'
service_id_2 = 'service-2'
restored_metrics_1 = Metrics(model_name='gpt-4')
restored_metrics_1.add_cost(0.1)
restored_metrics_2 = Metrics(model_name='gpt-3.5')
restored_metrics_2.add_cost(0.05)
# Set up restored metrics for both services
conversation_stats.restored_metrics = {
service_id_1: restored_metrics_1,
service_id_2: restored_metrics_2,
}
# Create LLM configs
llm_config_1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
llm_config_2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
# Register first LLM
llm_1 = LLM(service_id=service_id_1, config=llm_config_1)
event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1)
conversation_stats.register_llm(event_1)
# Verify first service was registered with restored metrics
assert service_id_1 in conversation_stats.service_to_metrics
assert llm_1.metrics.accumulated_cost == 0.1
# After registering first service, restored_metrics should still contain service_id_2
assert service_id_2 in conversation_stats.restored_metrics
# Register second LLM - this should also work with restored metrics
llm_2 = LLM(service_id=service_id_2, config=llm_config_2)
event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2)
conversation_stats.register_llm(event_2)
# Verify second service was registered with restored metrics
assert service_id_2 in conversation_stats.service_to_metrics
assert llm_2.metrics.accumulated_cost == 0.05
# After both services are registered, restored_metrics should be empty
assert len(conversation_stats.restored_metrics) == 0
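The distinction the assertions above pin down is purely about dict cleanup; a standalone sketch with plain dicts (no OpenHands types) shows the two behaviours:
restored = {'service-1': 0.1, 'service-2': 0.05}

# Buggy cleanup: dropping the whole mapping after restoring 'service-1'
# leaves nothing behind for 'service-2' to restore from later.
buggy = dict(restored)
del buggy

# Correct cleanup: remove only the entry that was just restored.
fixed = dict(restored)
del fixed['service-1']
assert 'service-2' in fixed  # still available when that service registers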
def test_save_and_restore_workflow(mock_file_store):
"""Test the full workflow of saving and restoring metrics."""
# Create initial conversation stats
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'
stats1 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
stats1.service_to_metrics[service_id] = metrics
# Save metrics
stats1.save_metrics()
# Create a new conversation stats instance that should restore the metrics
stats2 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Verify metrics were restored
assert service_id in stats2.restored_metrics
assert stats2.restored_metrics[service_id].accumulated_cost == 0.05
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100
)
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens
== 50
)
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)
# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM to trigger restoration
stats2.register_llm(event)
# Verify metrics were applied to the LLM
assert llm.metrics.accumulated_cost == 0.05
assert llm.metrics.accumulated_token_usage.prompt_tokens == 100
assert llm.metrics.accumulated_token_usage.completion_tokens == 50

View File

@ -1,6 +1,6 @@
"""Tests for the conversation summary generator."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
@ -11,55 +11,51 @@ from openhands.utils.conversation_summary import generate_conversation_title
@pytest.mark.asyncio
async def test_generate_conversation_title_empty_message():
"""Test that an empty message returns None."""
result = await generate_conversation_title('', MagicMock())
mock_llm_registry = MagicMock()
mock_llm_config = LLMConfig(model='test-model')
result = await generate_conversation_title('', mock_llm_config, mock_llm_registry)
assert result is None
result = await generate_conversation_title(' ', MagicMock())
result = await generate_conversation_title(
' ', mock_llm_config, mock_llm_registry
)
assert result is None
@pytest.mark.asyncio
async def test_generate_conversation_title_success():
"""Test successful title generation."""
# Create a proper mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Generated Title'
# Create a mock LLM registry that returns a title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title'
# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)
assert result == 'Generated Title'
# Verify the mock was called with the expected arguments
mock_llm.completion.assert_called_once()
mock_llm_registry.request_extraneous_completion.assert_called_once()
@pytest.mark.asyncio
async def test_generate_conversation_title_long_title():
"""Test that long titles are truncated."""
# Create a proper mock response with a long title
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[
0
].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length'
# Create a mock LLM registry that returns a long title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length'
# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30
)
result = await generate_conversation_title(
'Can you help me with Python?',
mock_llm_config,
mock_llm_registry,
max_length=30,
)
# Verify the title is truncated correctly
assert len(result) <= 30
@ -69,15 +65,17 @@ async def test_generate_conversation_title_long_title():
@pytest.mark.asyncio
async def test_generate_conversation_title_exception():
"""Test that exceptions are handled gracefully."""
# Create a mock LLM instance with a synchronous completion method that raises an exception
mock_llm = MagicMock()
mock_llm.completion = MagicMock(side_effect=Exception('Test error'))
# Create a mock LLM registry that raises an exception
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.side_effect = Exception(
'Test error'
)
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
mock_llm_config = LLMConfig(model='test-model')
result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)
# Verify that None is returned when an exception occurs
assert result is None

View File

@ -4,6 +4,7 @@ import pytest
from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
@ -40,12 +41,17 @@ def event_stream():
return MagicMock(spec=EventStream)
@pytest.fixture
def llm_registry():
return MagicMock(spec=LLMRegistry)
@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_stopped_when_keep_runtime_alive_false(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value
# Act
@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false(
@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_not_stopped_when_keep_runtime_alive_true(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
config.sandbox.keep_runtime_alive = True
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value
# Act

View File

@ -0,0 +1,178 @@
from __future__ import annotations
import unittest
from unittest.mock import MagicMock, patch
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
class TestLLMRegistry(unittest.TestCase):
def setUp(self):
"""Set up test environment before each test."""
# Create a basic LLM config for testing
self.llm_config = LLMConfig(model='test-model')
# Create a basic OpenHands config for testing
self.config = OpenHandsConfig(
llms={'llm': self.llm_config}, default_agent='CodeActAgent'
)
# Create a registry for testing
self.registry = LLMRegistry(config=self.config)
def test_get_llm_creates_new_llm(self):
"""Test that get_llm creates a new LLM when service doesn't exist."""
service_id = 'test-service'
# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm
# Get LLM for the first time
llm = self.registry.get_llm(service_id, self.llm_config)
# Verify LLM was created and stored
self.assertEqual(llm, mock_llm)
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id
)
def test_get_llm_returns_existing_llm(self):
"""Test that get_llm returns existing LLM when service already exists."""
service_id = 'test-service'
# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm
# Get LLM for the first time
llm1 = self.registry.get_llm(service_id, self.llm_config)
# Manually add to registry to simulate existing LLM
self.registry.service_to_llm[service_id] = mock_llm
# Get LLM for the second time - should return the same instance
llm2 = self.registry.get_llm(service_id, self.llm_config)
# Verify same LLM instance is returned
self.assertEqual(llm1, llm2)
self.assertEqual(llm1, mock_llm)
# Verify _create_new_llm was only called once
mock_create.assert_called_once()
def test_get_llm_with_different_config_raises_error(self):
"""Test that requesting same service ID with different config raises an error."""
service_id = 'test-service'
different_config = LLMConfig(model='different-model')
# Manually add an LLM to the registry to simulate existing service
mock_llm = MagicMock()
mock_llm.config = self.llm_config
self.registry.service_to_llm[service_id] = mock_llm
# Attempt to get LLM with different config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, different_config)
self.assertIn('Requesting same service ID', str(context.exception))
self.assertIn('with different config', str(context.exception))
def test_get_llm_without_config_raises_error(self):
"""Test that requesting new LLM without config raises an error."""
service_id = 'test-service'
# Attempt to get LLM without providing config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, None)
self.assertIn(
'Requesting new LLM without specifying LLM config', str(context.exception)
)
def test_request_extraneous_completion(self):
"""Test that requesting an extraneous completion creates a new LLM if needed."""
service_id = 'extraneous-service'
messages = [{'role': 'user', 'content': 'Hello, world!'}]
# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = ' Hello from the LLM! '
mock_llm.completion.return_value = mock_response
mock_create.return_value = mock_llm
# Mock the side effect to add the LLM to the registry
def side_effect(*args, **kwargs):
self.registry.service_to_llm[service_id] = mock_llm
return mock_llm
mock_create.side_effect = side_effect
# Request a completion
response = self.registry.request_extraneous_completion(
service_id=service_id,
llm_config=self.llm_config,
messages=messages,
)
# Verify the response (should be stripped)
self.assertEqual(response, 'Hello from the LLM!')
# Verify that _create_new_llm was called with correct parameters
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id, with_listener=False
)
# Verify completion was called with correct messages
mock_llm.completion.assert_called_once_with(messages=messages)
def test_get_active_llm(self):
"""Test that get_active_llm returns the active agent LLM."""
active_llm = self.registry.get_active_llm()
self.assertEqual(active_llm, self.registry.active_agent_llm)
def test_subscribe_and_notify(self):
"""Test the subscription and notification system."""
events_received = []
def callback(event: RegistryEvent):
events_received.append(event)
# Subscribe to events
self.registry.subscribe(callback)
# Should receive notification for the active agent LLM
self.assertEqual(len(events_received), 1)
self.assertEqual(events_received[0].llm, self.registry.active_agent_llm)
self.assertEqual(
events_received[0].service_id, self.registry.active_agent_llm.service_id
)
# Test that the subscriber is set correctly
self.assertIsNotNone(self.registry.subscriber)
# Test notify method directly with a mock event
with patch.object(self.registry, 'subscriber') as mock_subscriber:
mock_event = MagicMock()
self.registry.notify(mock_event)
mock_subscriber.assert_called_once_with(mock_event)
def test_registry_has_unique_id(self):
"""Test that each registry instance has a unique ID."""
registry2 = LLMRegistry(config=self.config)
self.assertNotEqual(self.registry.registry_id, registry2.registry_id)
self.assertTrue(len(self.registry.registry_id) > 0)
self.assertTrue(len(registry2.registry_id) > 0)
if __name__ == '__main__':
unittest.main()

View File

@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import (
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore
@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch):
file_store=InMemoryFileStore({}),
config=config,
sio=AsyncMock(),
llm_registry=LLMRegistry(config=OpenHandsConfig()),
convo_stats=ConversationStats(None, 'test-sid', None),
)
# Create empty settings

View File

@ -8,7 +8,8 @@ import pytest
from mcp import McpError
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController, AgentState
from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import EventSource
@ -17,6 +18,8 @@ from openhands.events.stream import EventStream
from openhands.mcp.client import MCPClient
from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import call_tool_mcp
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
class MockConfig:
@ -34,6 +37,11 @@ class MockLLM:
self.config = MockConfig()
@pytest.fixture
def convo_stats():
return ConversationStats(None, 'convo-id', None)
class MockAgent(Agent):
"""Mock agent for testing."""
@ -53,7 +61,7 @@ class MockAgent(Agent):
@pytest.mark.asyncio
async def test_mcp_tool_timeout_error_handling():
async def test_mcp_tool_timeout_error_handling(convo_stats):
"""Test that verifies MCP tool timeout errors are properly handled and returned as observations."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling():
mock_client.tool_map = {'test_tool': mock_tool}
# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})
# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling():
# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)
# Set up the agent state
@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling():
@pytest.mark.asyncio
async def test_mcp_tool_timeout_agent_continuation():
async def test_mcp_tool_timeout_agent_continuation(convo_stats):
"""Test that verifies the agent can continue processing after an MCP tool timeout."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation():
mock_client.tool_map = {'test_tool': mock_tool}
# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})
# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation():
# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)
# Set up the agent state

View File

@ -21,11 +21,13 @@ from openhands.events.observation.agent import (
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import (
@ -42,6 +44,12 @@ def file_store():
return InMemoryFileStore({})
@pytest.fixture
def mock_llm_registry(file_store):
"""Create a mock LLMRegistry for testing."""
return MagicMock(spec=LLMRegistry)
@pytest.fixture
def event_stream(file_store):
"""Create a test event stream."""
@ -90,24 +98,29 @@ def mock_agent():
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
async def test_memory_on_event_exception_handling(
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream
# Mock Memory method to raise an exception
with patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
with (
patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)
# Verify that the controller's last error was set
@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age
@pytest.mark.asyncio
async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream, mock_agent
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a mock runtime
@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling(
runtime.event_stream = event_stream
# Mock Memory._on_workspace_context_recall to raise an exception
with patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
with (
patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)
# Verify that the controller's last error was set
@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.
@pytest.mark.asyncio
async def test_conversation_instructions_plumbed_to_memory(
mock_agent, event_stream, file_store
mock_agent, event_stream, file_store, mock_llm_registry
):
# Setup
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=ConversationStats(file_store, 'test-session', None),
)
# Create a mock runtime and set it up

View File

@ -3,26 +3,30 @@ from litellm import ModelResponse
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
@pytest.fixture
def mock_llm():
llm = LLM(
LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
def llm_config():
return LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
return llm
@pytest.fixture
def codeact_agent(mock_llm):
def llm_registry():
registry = LLMRegistry(config=OpenHandsConfig())
return registry
@pytest.fixture
def codeact_agent(llm_registry):
config = AgentConfig()
agent = CodeActAgent(mock_llm, config)
agent = CodeActAgent(config, llm_registry)
return agent

View File

@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation
from openhands.events.stream import EventStream
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
class TestRuntime(Runtime):
class MockRuntime(Runtime):
"""A concrete implementation of Runtime for testing"""
def __init__(self, *args, **kwargs):
# Ensure llm_registry is provided if not already in kwargs
if 'llm_registry' not in kwargs and len(args) < 3:
# Create a mock LLMRegistry if not provided
config = (
kwargs.get('config')
if 'config' in kwargs
else args[0]
if args
else OpenHandsConfig()
)
kwargs['llm_registry'] = LLMRegistry(config=config)
super().__init__(*args, **kwargs)
self.run_action_calls = []
self._execute_shell_fn_git_handler = MagicMock(
@ -89,9 +101,11 @@ def runtime(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
llm_registry = LLMRegistry(config=config)
runtime = MockRuntime(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
user_id='test_user',
git_provider_tokens=git_provider_tokens,
@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(config=config, event_stream=event_stream, sid='test')
runtime = MockRuntime(config=config, event_stream=event_stream, sid='test')
# Create a command that would normally trigger token export
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir):
config.init_git_in_empty_workspace = True
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)
@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di
config.workspace_base = '/some/path' # Set workspace_base
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)
@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
)
runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)
@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
)
runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

View File

@ -9,10 +9,12 @@ import pytest
from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events import EventStream
from openhands.integrations.service_types import ProviderType, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent.microagent import (
RepoMicroagent,
)
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
class MockRuntime(Runtime):
@ -24,12 +26,21 @@ class MockRuntime(Runtime):
config.workspace_mount_path_in_sandbox = str(workspace_root)
config.sandbox = SandboxConfig()
# Create a mock event stream
# Create a mock event stream and file store
file_store = get_file_store('local', str(workspace_root))
event_stream = MagicMock(spec=EventStream)
event_stream.file_store = file_store
# Create a mock LLM registry
llm_registry = LLMRegistry(config)
# Initialize the parent class properly
super().__init__(
config=config, event_stream=event_stream, sid='test', git_provider_tokens={}
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
git_provider_tokens={},
)
self._workspace_root = workspace_root

View File

@ -595,7 +595,7 @@ async def test_check_usertask(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(MessageAction(usertask), EventSource.USER),
@ -657,7 +657,7 @@ async def test_check_fillaction(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),

View File

@ -7,7 +7,9 @@ from litellm.exceptions import (
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore
@ -33,10 +35,28 @@ def default_llm_config():
)
@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)
@pytest.fixture
def conversation_stats():
file_store = InMemoryFileStore({})
return ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)
@pytest.mark.asyncio
@patch('openhands.llm.llm.litellm_completion')
async def test_notify_on_llm_retry(
mock_litellm_completion, mock_sio, default_llm_config
mock_litellm_completion,
mock_sio,
default_llm_config,
llm_registry,
conversation_stats,
):
config = OpenHandsConfig()
config.set_llm_config(default_llm_config)
@ -44,6 +64,8 @@ async def test_notify_on_llm_retry(
sid='..sid..',
file_store=InMemoryFileStore({}),
config=config,
llm_registry=llm_registry,
convo_stats=conversation_stats,
sio=mock_sio,
user_id='..uid..',
)
@ -56,12 +78,20 @@ async def test_notify_on_llm_retry(
),
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = session._create_llm('..cls..')
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)
# Set the retry listener on the registry
llm_registry.retry_listner = session._notify_on_llm_retry
# Create an LLM through the registry
llm = llm_registry.get_llm(
service_id='test_service',
config=default_llm_config,
)
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)
assert mock_litellm_completion.call_count == 2
session.queue_status_message.assert_called_once_with(