Mirror of https://github.com/OpenHands/OpenHands.git (synced 2025-12-25 21:36:52 +08:00)
[Refactor]: Add LLMRegistry for llm services (#9589)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Graham Neubig <neubig@gmail.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
parent 17b1a21296
commit 25d9cf2890
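The changes below thread a single LLMRegistry through agents, condensers, runtimes, the controller, and the resolver, so every LLM instance is created and tracked in one place instead of being built ad hoc at each call site. As rough orientation, the construction pattern changes as follows (illustrative sketch only; the config objects are placeholders and the class and method names come from the hunks below):

    # before: each call site built its own LLM and handed it to the agent
    # llm = LLM(config=llm_config)
    # agent = CodeActAgent(llm, agent_config)

    # after: a registry owns the LLMs, keyed by service id ('agent', 'condenser', ...)
    from openhands.llm.llm_registry import LLMRegistry

    llm_registry = LLMRegistry(config)  # config: OpenHandsConfig
    agent = CodeActAgent(config=agent_config, llm_registry=llm_registry)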
@@ -18,7 +18,7 @@ from openhands.events.action import (
from openhands.events.event import EventSource
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.observation import Observation
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.plugins import (
    PluginRequirement,
)
@@ -102,15 +102,15 @@ class BrowsingAgent(Agent):

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
        llm_registry: LLMRegistry,
    ) -> None:
        """Initializes a new instance of the BrowsingAgent class.

        Parameters:
        - llm (LLM): The llm to be used by this agent
        """
        super().__init__(llm, config)
        super().__init__(config, llm_registry)
        # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
        # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
        action_subsets = ['chat', 'bid']
@@ -3,6 +3,8 @@ import sys
from collections import deque
from typing import TYPE_CHECKING

from openhands.llm.llm_registry import LLMRegistry

if TYPE_CHECKING:
    from litellm import ChatCompletionToolParam

@@ -32,7 +34,6 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.events.action import AgentFinishAction, MessageAction
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.llm.llm_utils import check_tools
from openhands.memory.condenser import Condenser
from openhands.memory.condenser.condenser import Condensation, View
@@ -74,18 +75,13 @@ class CodeActAgent(Agent):
        JupyterRequirement(),
    ]

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
    ) -> None:
    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
        """Initializes a new instance of the CodeActAgent class.

        Parameters:
        - llm (LLM): The llm to be used by this agent
        - config (AgentConfig): The configuration for this agent
        """
        super().__init__(llm, config)
        super().__init__(config, llm_registry)
        self.pending_actions: deque['Action'] = deque()
        self.reset()
        self.tools = self._get_tools()
@@ -93,7 +89,7 @@ class CodeActAgent(Agent):
        # Create a ConversationMemory instance
        self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)

        self.condenser = Condenser.from_config(self.config.condenser)
        self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
        logger.debug(f'Using condenser: {type(self.condenser)}')

    @property
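CodeActAgent no longer receives an LLM at construction time: the Agent base class resolves one from the registry, and the condenser is now built against the same registry so its completions are accounted for together with the agent's. A minimal sketch of the new call path (assuming an OpenHandsConfig and AgentConfig loaded elsewhere; the import path for CodeActAgent appears in the LocAgent hunk further down):

    from openhands.agenthub.codeact_agent import CodeActAgent
    from openhands.llm.llm_registry import LLMRegistry

    llm_registry = LLMRegistry(config)
    agent = CodeActAgent(config=agent_config, llm_registry=llm_registry)
    # inside __init__ the condenser shares the same registry:
    #   self.condenser = Condenser.from_config(self.config.condenser, llm_registry)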
@@ -22,7 +22,7 @@ from openhands.events.observation import (
    Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry

"""
FIXME: There are a few problems this surfaced
@@ -42,8 +42,12 @@ class DummyAgent(Agent):
    without making any LLM calls.
    """

    def __init__(self, llm: LLM, config: AgentConfig):
        super().__init__(llm, config)
    def __init__(
        self,
        config: AgentConfig,
        llm_registry: LLMRegistry,
    ):
        super().__init__(config, llm_registry)
        self.steps: list[ActionObs] = [
            {
                'action': MessageAction('Time to get started!'),
@@ -4,7 +4,7 @@ import openhands.agenthub.loc_agent.function_calling as locagent_function_calling
from openhands.agenthub.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry

if TYPE_CHECKING:
    from openhands.events.action import Action
@@ -16,8 +16,8 @@ class LocAgent(CodeActAgent):

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
        llm_registry: LLMRegistry,
    ) -> None:
        """Initializes a new instance of the LocAgent class.

@@ -25,7 +25,8 @@ class LocAgent(CodeActAgent):
        - llm (LLM): The llm to be used by this agent
        - config (AgentConfig): The configuration for the agent
        """
        super().__init__(llm, config)

        super().__init__(config, llm_registry)

        self.tools = locagent_function_calling.get_tools()
        logger.debug(
@@ -3,6 +3,8 @@
import os
from typing import TYPE_CHECKING

from openhands.llm.llm_registry import LLMRegistry

if TYPE_CHECKING:
    from litellm import ChatCompletionToolParam

@@ -15,7 +17,6 @@ from openhands.agenthub.readonly_agent import (
)
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.utils.prompt import PromptManager


@@ -37,17 +38,16 @@ class ReadOnlyAgent(CodeActAgent):

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
        llm_registry: LLMRegistry,
    ) -> None:
        """Initializes a new instance of the ReadOnlyAgent class.

        Parameters:
        - llm (LLM): The llm to be used by this agent
        - config (AgentConfig): The configuration for this agent
        """
        # Initialize the CodeActAgent class; some of it is overridden with class methods
        super().__init__(llm, config)
        super().__init__(config, llm_registry)

        logger.debug(
            f'TOOLS loaded for ReadOnlyAgent: {", ".join([tool.get("function").get("name") for tool in self.tools])}'
@@ -16,7 +16,7 @@ from openhands.events.action import (
from openhands.events.event import EventSource
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.observation import Observation
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.plugins import (
    PluginRequirement,
)
@@ -127,17 +127,13 @@ class VisualBrowsingAgent(Agent):
    sandbox_plugins: list[PluginRequirement] = []
    response_parser = BrowsingResponseParser()

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
    ) -> None:
    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
        """Initializes a new instance of the VisualBrowsingAgent class.

        Parameters:
        - llm (LLM): The llm to be used by this agent
        """
        super().__init__(llm, config)
        super().__init__(config, llm_registry)
        # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
        # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
        action_subsets = [
@@ -83,6 +83,7 @@ from openhands.microagent.microagent import BaseMicroagent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.storage.settings.file_settings_store import FileSettingsStore
from openhands.utils.utils import create_registry_and_convo_stats


async def cleanup_session(
@@ -147,9 +148,16 @@ async def run_session(
        None, display_initialization_animation, 'Initializing...', is_loaded
    )

    agent = create_agent(config)
    llm_registry, convo_stats, config = create_registry_and_convo_stats(
        config,
        sid,
        None,
    )

    agent = create_agent(config, llm_registry)
    runtime = create_runtime(
        config,
        llm_registry,
        sid=sid,
        headless_mode=True,
        agent=agent,
@@ -161,7 +169,7 @@ async def run_session(

    runtime.subscribe_to_shell_stream(stream_to_console)

    controller, initial_state = create_controller(agent, runtime, config)
    controller, initial_state = create_controller(agent, runtime, config, convo_stats)

    event_stream = runtime.event_stream
@@ -3,6 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from openhands.llm.llm_registry import LLMRegistry

if TYPE_CHECKING:
    from openhands.controller.state.state import State
    from openhands.events.action import Action
@@ -17,7 +19,6 @@ from openhands.core.exceptions import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import EventSource
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement


@@ -38,10 +39,11 @@ class Agent(ABC):

    def __init__(
        self,
        llm: LLM,
        config: AgentConfig,
        llm_registry: LLMRegistry,
    ):
        self.llm = llm
        self.llm = llm_registry.get_llm_from_agent_config('agent', config)
        self.llm_registry = llm_registry
        self.config = config
        self._complete = False
        self._prompt_manager: 'PromptManager' | None = None
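The base class now owns the registry handle: it asks the registry for the LLM mapped to the 'agent' service and keeps the registry around for condensers and delegates. A subclass therefore only forwards config and registry; a hypothetical minimal agent would look like this (EchoAgent is an illustrative name, not part of this commit):

    from openhands.controller.agent import Agent
    from openhands.core.config import AgentConfig
    from openhands.llm.llm_registry import LLMRegistry

    class EchoAgent(Agent):  # hypothetical subclass for illustration
        def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
            super().__init__(config, llm_registry)
            # self.llm was resolved by the base class via
            # llm_registry.get_llm_from_agent_config('agent', config)

        def step(self, state):
            # a real agent would inspect state.history and return an Action
            raise NotImplementedError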
@@ -73,9 +73,9 @@ from openhands.events.observation import (
    Observation,
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore

# note: RESUME is only available on web GUI
@@ -109,6 +109,7 @@ class AgentController:
        self,
        agent: Agent,
        event_stream: EventStream,
        convo_stats: ConversationStats,
        iteration_delta: int,
        budget_per_task_delta: float | None = None,
        agent_to_llm_config: dict[str, LLMConfig] | None = None,
@@ -148,6 +149,7 @@ class AgentController:
        self.agent = agent
        self.headless_mode = headless_mode
        self.is_delegate = is_delegate
        self.convo_stats = convo_stats

        # the event stream must be set before maybe subscribing to it
        self.event_stream = event_stream
@@ -163,6 +165,7 @@ class AgentController:
        # state from the previous session, state from a parent agent, or a fresh state
        self.set_initial_state(
            state=initial_state,
            convo_stats=convo_stats,
            max_iterations=iteration_delta,
            max_budget_per_task=budget_per_task_delta,
            confirmation_mode=confirmation_mode,
@@ -477,11 +480,6 @@ class AgentController:
            log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'}
        )

        # TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics
        # In the future, we should have a more principled way to sharing metrics across all LLM instances for a given conversation
        if observation.llm_metrics is not None:
            self.state_tracker.merge_metrics(observation.llm_metrics)

        # this happens for runnable actions and microagent actions
        if self._pending_action and self._pending_action.id == observation.cause:
            if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
@@ -657,14 +655,10 @@ class AgentController:
        """
        agent_cls: type[Agent] = Agent.get_cls(action.agent)
        agent_config = self.agent_configs.get(action.agent, self.agent.config)
        llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
        # Make sure metrics are shared between parent and child for global accumulation
        llm = LLM(
            config=llm_config,
            retry_listener=self.agent.llm.retry_listener,
            metrics=self.state.metrics,
        delegate_agent = agent_cls(
            config=agent_config, llm_registry=self.agent.llm_registry
        )
        delegate_agent = agent_cls(llm=llm, config=agent_config)

        # Take a snapshot of the current metrics before starting the delegate
        state = State(
@@ -683,7 +677,7 @@ class AgentController:
        )
        self.log(
            'debug',
            f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
            f'start delegate, creating agent {delegate_agent.name}',
        )

        # Create the delegate with is_delegate=True so it does NOT subscribe directly
@@ -693,6 +687,7 @@ class AgentController:
            user_id=self.user_id,
            agent=delegate_agent,
            event_stream=self.event_stream,
            convo_stats=self.convo_stats,
            iteration_delta=self._initial_max_iterations,
            budget_per_task_delta=self._initial_max_budget_per_task,
            agent_to_llm_config=self.agent_to_llm_config,
@@ -795,13 +790,8 @@ class AgentController:
            extra={'msg_type': 'STEP'},
        )

        # Ensure budget control flag is synchronized with the latest metrics.
        # In the future, we should centralized the use of one LLM object per conversation.
        # This will help us unify the cost for auto generating titles, running the condensor, etc.
        # Before many microservices will touh the same llm cost field, we should sync with the budget flag for the controller
        # and check that we haven't exceeded budget BEFORE executing an agent step.
        # Synchronize spend across all llm services with the budget flag
        self.state_tracker.sync_budget_flag_with_metrics()

        if self._is_stuck():
            await self._react_to_exception(
                AgentStuckInLoopError('Agent got stuck in a loop')
@@ -961,14 +951,15 @@ class AgentController:
    def set_initial_state(
        self,
        state: State | None,
        convo_stats: ConversationStats,
        max_iterations: int,
        max_budget_per_task: float | None,
        confirmation_mode: bool = False,
    ):
        self.state_tracker.set_initial_state(
            self.id,
            self.agent,
            state,
            convo_stats,
            max_iterations,
            max_budget_per_task,
            confirmation_mode,
@@ -1009,37 +1000,20 @@ class AgentController:
            action: The action to attach metrics to
        """
        # Get metrics from agent LLM
        agent_metrics = self.state.metrics
        metrics = self.convo_stats.get_combined_metrics()

        # Get metrics from condenser LLM if it exists
        condenser_metrics: Metrics | None = None
        if hasattr(self.agent, 'condenser') and hasattr(self.agent.condenser, 'llm'):
            condenser_metrics = self.agent.condenser.llm.metrics

        # Create a new minimal metrics object with just what the frontend needs
        metrics = Metrics(model_name=agent_metrics.model_name)

        # Set accumulated cost (sum of agent and condenser costs)
        metrics.accumulated_cost = agent_metrics.accumulated_cost
        if condenser_metrics:
            metrics.accumulated_cost += condenser_metrics.accumulated_cost
        # Create a clean copy with only the fields we want to keep
        clean_metrics = Metrics()
        clean_metrics.accumulated_cost = metrics.accumulated_cost
        clean_metrics._accumulated_token_usage = copy.deepcopy(
            metrics.accumulated_token_usage
        )

        # Add max_budget_per_task to metrics
        if self.state.budget_flag:
            metrics.max_budget_per_task = self.state.budget_flag.max_value
            clean_metrics.max_budget_per_task = self.state.budget_flag.max_value

        # Set accumulated token usage (sum of agent and condenser token usage)
        # Use a deep copy to ensure we don't modify the original object
        metrics._accumulated_token_usage = (
            agent_metrics.accumulated_token_usage.model_copy(deep=True)
        )
        if condenser_metrics:
            metrics._accumulated_token_usage = (
                metrics._accumulated_token_usage
                + condenser_metrics.accumulated_token_usage
            )

        action.llm_metrics = metrics
        action.llm_metrics = clean_metrics

        # Log the metrics information for debugging
        # Get the latest usage directly from the agent's metrics
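Note how the per-action metrics are assembled after this change: instead of manually summing the agent's and the condenser's Metrics objects, the controller asks ConversationStats for the combined metrics of every LLM the registry has handed out, and copies only the fields the frontend needs. Reading the overall spend from the outside then reduces to something like this (sketch; get_combined_metrics is the only ConversationStats call assumed here, and it appears in the hunks above):

    combined = controller.convo_stats.get_combined_metrics()
    print(f'accumulated cost: {combined.accumulated_cost}')
    print(f'token usage: {combined.accumulated_token_usage}')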
@@ -21,6 +21,7 @@ from openhands.events.action.agent import AgentFinishAction
from openhands.events.event import Event, EventSource
from openhands.llm.metrics import Metrics
from openhands.memory.view import View
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_agent_state_filename

@@ -84,6 +85,7 @@ class State:
            limit_increase_amount=100, current_value=0, max_value=100
        )
    )
    convo_stats: ConversationStats | None = None
    budget_flag: BudgetControlFlag | None = None
    confirmation_mode: bool = False
    history: list[Event] = field(default_factory=list)
@@ -91,8 +93,7 @@ class State:
    outputs: dict = field(default_factory=dict)
    agent_state: AgentState = AgentState.LOADING
    resume_state: AgentState | None = None
    # global metrics for the current task
    metrics: Metrics = field(default_factory=Metrics)

    # root agent has level 0, and every delegate increases the level by one
    delegate_level: int = 0
    # start_id and end_id track the range of events in history
@@ -116,9 +117,14 @@ class State:
    local_metrics: Metrics | None = None
    delegates: dict[tuple[int, int], tuple[str, str]] | None = None

    metrics: Metrics = field(default_factory=Metrics)

    def save_to_session(
        self, sid: str, file_store: FileStore, user_id: str | None
    ) -> None:
        convo_stats = self.convo_stats
        self.convo_stats = None  # Don't save convo stats, handles itself

        pickled = pickle.dumps(self)
        logger.debug(f'Saving state to session {sid}:{self.agent_state}')
        encoded = base64.b64encode(pickled).decode('utf-8')
@@ -138,6 +144,8 @@ class State:
            logger.error(f'Failed to save state to session: {e}')
            raise e

        self.convo_stats = convo_stats  # restore reference

    @staticmethod
    def restore_from_session(
        sid: str, file_store: FileStore, user_id: str | None = None
@@ -1,4 +1,3 @@
from openhands.controller.agent import Agent
from openhands.controller.state.control_flags import (
    BudgetControlFlag,
    IterationControlFlag,
@@ -14,7 +13,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization.event import event_to_trajectory
from openhands.events.stream import EventStream
from openhands.llm.metrics import Metrics
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.files import FileStore


@@ -51,8 +50,8 @@ class StateTracker:
    def set_initial_state(
        self,
        id: str,
        agent: Agent,
        state: State | None,
        convo_stats: ConversationStats,
        max_iterations: int,
        max_budget_per_task: float | None,
        confirmation_mode: bool = False,
@@ -75,6 +74,7 @@ class StateTracker:
                session_id=id.removesuffix('-delegate'),
                user_id=self.user_id,
                inputs={},
                convo_stats=convo_stats,
                iteration_flag=IterationControlFlag(
                    limit_increase_amount=max_iterations,
                    current_value=0,
@@ -99,13 +99,7 @@ class StateTracker:
            if self.state.start_id <= -1:
                self.state.start_id = 0

            logger.info(
                f'AgentController {id} initializing history from event {self.state.start_id}',
            )

        # Share the state metrics with the agent's LLM metrics
        # This ensures that all accumulated metrics are always in sync between controller and llm
        agent.llm.metrics = self.state.metrics
        state.convo_stats = convo_stats

    def _init_history(self, event_stream: EventStream) -> None:
        """Initializes the agent's history from the event stream.
@@ -254,6 +248,9 @@ class StateTracker:
        if self.sid and self.file_store:
            self.state.save_to_session(self.sid, self.file_store, self.user_id)

        if self.state.convo_stats:
            self.state.convo_stats.save_metrics()

    def run_control_flags(self):
        """Performs one step of the control flags"""
        self.state.iteration_flag.step()
@@ -264,20 +261,8 @@ class StateTracker:
        """Ensures that budget flag is up to date with accumulated costs from llm completions
        Budget flag will monitor for when budget is exceeded
        """
        if self.state.budget_flag:
            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost

    def merge_metrics(self, metrics: Metrics):
        """Merges metrics with the state metrics

        NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc)
        use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from
        all services

        This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them
        if we decide introduce more specialized services that require llm completions

        """
        self.state.metrics.merge(metrics)
        if self.state.budget_flag:
            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
        # Sync cost across all llm services from llm registry
        if self.state.budget_flag and self.state.convo_stats:
            self.state.budget_flag.current_value = (
                self.state.convo_stats.get_combined_metrics().accumulated_cost
            )
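The budget check follows the same principle: the flag is now synchronized against the registry-wide spend held by ConversationStats rather than against the controller's own Metrics object, and the old merge_metrics path is gone. The sync step amounts to this (sketch of the logic in sync_budget_flag_with_metrics above):

    if state.budget_flag and state.convo_stats:
        state.budget_flag.current_value = (
            state.convo_stats.get_combined_metrics().accumulated_cost
        )
        # the flag's own checks (not shown in this diff) compare current_value to max_value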
@@ -157,13 +157,16 @@ class OpenHandsConfig(BaseModel):
        """Get a map of agent names to llm configs."""
        return {name: self.get_llm_config_from_agent(name) for name in self.agents}

    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
        agent_config: AgentConfig = self.get_agent_config(name)
    def get_llm_config_from_agent_config(self, agent_config: AgentConfig):
        llm_config_name = (
            agent_config.llm_config if agent_config.llm_config is not None else 'llm'
        )
        return self.get_llm_config(llm_config_name)

    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
        agent_config: AgentConfig = self.get_agent_config(name)
        return self.get_llm_config_from_agent_config(agent_config)

    def get_agent_configs(self) -> dict[str, AgentConfig]:
        return self.agents
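The new OpenHandsConfig.get_llm_config_from_agent_config helper is what lets the registry map an AgentConfig directly to an LLMConfig, falling back to the default 'llm' section when the agent does not name one. For example (assuming a config that defines a default llm section and, optionally, a named one referenced by the agent; the agent name is a placeholder):

    agent_config = config.get_agent_config('CodeActAgent')
    llm_config = config.get_llm_config_from_agent_config(agent_config)
    # equivalent to:
    # config.get_llm_config(agent_config.llm_config if agent_config.llm_config is not None else 'llm')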
@@ -6,7 +6,6 @@ from typing import Callable, Protocol

import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
import openhands.cli.suppress_warnings  # noqa: F401
from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State
from openhands.core.config import (
@@ -33,10 +32,12 @@ from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.io import read_input, read_task
from openhands.llm.llm_registry import LLMRegistry
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.utils import create_registry_and_convo_stats


class FakeUserResponseFunc(Protocol):
@@ -53,12 +54,12 @@ async def run_controller(
    initial_user_action: Action,
    sid: str | None = None,
    runtime: Runtime | None = None,
    agent: Agent | None = None,
    exit_on_message: bool = False,
    fake_user_response_fn: FakeUserResponseFunc | None = None,
    headless_mode: bool = True,
    memory: Memory | None = None,
    conversation_instructions: str | None = None,
    llm_registry: LLMRegistry | None = None,
) -> State | None:
    """Main coroutine to run the agent controller with task input flexibility.

@@ -70,7 +71,6 @@ async def run_controller(
        sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
            Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
        runtime: (optional) A runtime for the agent to run on.
        agent: (optional) A agent to run.
        exit_on_message: quit if agent asks for a message from user (optional)
        fake_user_response_fn: An optional function that receives the current state
            (could be None) and returns a fake user response.
@@ -98,8 +98,13 @@ async def run_controller(
    """
    sid = sid or generate_sid(config)

    if agent is None:
        agent = create_agent(config)
    llm_registry, convo_stats, config = create_registry_and_convo_stats(
        config,
        sid,
        None,
    )

    agent = create_agent(config, llm_registry)

    # when the runtime is created, it will be connected and clone the selected repository
    repo_directory = None
@@ -108,6 +113,7 @@ async def run_controller(
    repo_tokens = get_provider_tokens()
    runtime = create_runtime(
        config,
        llm_registry,
        sid=sid,
        headless_mode=headless_mode,
        agent=agent,
@@ -159,7 +165,7 @@ async def run_controller(
    )

    controller, initial_state = create_controller(
        agent, runtime, config, replay_events=replay_events
        agent, runtime, config, convo_stats, replay_events=replay_events
    )

    assert isinstance(initial_user_action, Action), (
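Both headless entry points now follow the same wiring: build the registry and the conversation stats first, then construct the agent, runtime, and controller from them. Condensed into one place, the flow looks like this (sketch; create_agent, create_runtime, and create_controller are the setup helpers updated in the following hunks, and the third argument to create_registry_and_convo_stats is passed as None here exactly as in the hunks above):

    from openhands.utils.utils import create_registry_and_convo_stats

    llm_registry, convo_stats, config = create_registry_and_convo_stats(config, sid, None)

    agent = create_agent(config, llm_registry)
    runtime = create_runtime(config, llm_registry, sid=sid, headless_mode=True, agent=agent)
    controller, initial_state = create_controller(agent, runtime, config, convo_stats)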
@@ -21,12 +21,13 @@ from openhands.integrations.provider import (
    ProviderToken,
    ProviderType,
)
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
@@ -34,6 +35,7 @@ from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync

def create_runtime(
    config: OpenHandsConfig,
    llm_registry: LLMRegistry,
    sid: str | None = None,
    headless_mode: bool = True,
    agent: Agent | None = None,
@@ -82,6 +84,7 @@ def create_runtime(
        sid=session_id,
        plugins=agent_cls.sandbox_plugins,
        headless_mode=headless_mode,
        llm_registry=llm_registry,
        git_provider_tokens=git_provider_tokens,
    )

@@ -203,16 +206,11 @@ def create_memory(
    return memory


def create_agent(config: OpenHandsConfig) -> Agent:
def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent:
    agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
    agent_config = config.get_agent_config(config.default_agent)
    llm_config = config.get_llm_config_from_agent(config.default_agent)

    agent = agent_cls(
        llm=LLM(config=llm_config),
        config=agent_config,
    )

    config.get_llm_config_from_agent(config.default_agent)
    agent = agent_cls(config=agent_config, llm_registry=llm_registry)
    return agent


@@ -220,6 +218,7 @@ def create_controller(
    agent: Agent,
    runtime: Runtime,
    config: OpenHandsConfig,
    convo_stats: ConversationStats,
    headless_mode: bool = True,
    replay_events: list[Event] | None = None,
) -> tuple[AgentController, State | None]:
@@ -237,6 +236,7 @@ def create_controller(

    controller = AgentController(
        agent=agent,
        convo_stats=convo_stats,
        iteration_delta=config.max_iterations,
        budget_per_task_delta=config.max_budget_per_task,
        agent_to_llm_config=config.get_agent_to_llm_config_map(),
@@ -8,6 +8,7 @@ from typing import Any, Callable
import httpx

from openhands.core.config import LLMConfig
from openhands.llm.metrics import Metrics

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
@@ -34,7 +35,6 @@ from openhands.llm.fn_call_converter import (
    convert_fncall_messages_to_non_fncall_messages,
    convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']
@@ -133,6 +133,7 @@ class LLM(RetryMixin, DebugMixin):
    def __init__(
        self,
        config: LLMConfig,
        service_id: str,
        metrics: Metrics | None = None,
        retry_listener: Callable[[int, int], None] | None = None,
    ) -> None:
@@ -145,11 +146,12 @@ class LLM(RetryMixin, DebugMixin):
            metrics: The metrics to use.
        """
        self._tried_model_info = False
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)
        self.service_id = service_id
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)

        self.model_info: ModelInfo | None = None
        self.retry_listener = retry_listener
@@ -408,8 +410,7 @@ class LLM(RetryMixin, DebugMixin):
        assert self.config.log_completions_folder is not None
        log_file = os.path.join(
            self.config.log_completions_folder,
            # use the metric model name (for draft editor)
            f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
            f'{self.config.model.replace("/", "__")}-{time.time()}.json',
        )

        # set up the dict to be logged
openhands/llm/llm_registry.py (new file, 132 lines)
@@ -0,0 +1,132 @@
import copy
from typing import Any, Callable
from uuid import uuid4

from pydantic import BaseModel, ConfigDict

from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM


class RegistryEvent(BaseModel):
    llm: LLM
    service_id: str

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


class LLMRegistry:
    def __init__(
        self,
        config: OpenHandsConfig,
        agent_cls: str | None = None,
        retry_listener: Callable[[int, int], None] | None = None,
    ):
        self.registry_id = str(uuid4())
        self.config = copy.deepcopy(config)
        self.retry_listner = retry_listener
        self.agent_to_llm_config = self.config.get_agent_to_llm_config_map()
        self.service_to_llm: dict[str, LLM] = {}
        self.subscriber: Callable[[Any], None] | None = None

        selected_agent_cls = self.config.default_agent
        if agent_cls:
            selected_agent_cls = agent_cls

        agent_name = selected_agent_cls if selected_agent_cls is not None else 'agent'
        llm_config = self.config.get_llm_config_from_agent(agent_name)
        self.active_agent_llm: LLM = self.get_llm('agent', llm_config)

    def _create_new_llm(
        self, service_id: str, config: LLMConfig, with_listener: bool = True
    ) -> LLM:
        if with_listener:
            llm = LLM(
                service_id=service_id, config=config, retry_listener=self.retry_listner
            )
        else:
            llm = LLM(service_id=service_id, config=config)
        self.service_to_llm[service_id] = llm
        self.notify(RegistryEvent(llm=llm, service_id=service_id))
        return llm

    def request_extraneous_completion(
        self, service_id: str, llm_config: LLMConfig, messages: list[dict[str, str]]
    ) -> str:
        logger.info(f'extraneous completion: {service_id}')
        if service_id not in self.service_to_llm:
            self._create_new_llm(
                config=llm_config, service_id=service_id, with_listener=False
            )

        llm = self.service_to_llm[service_id]
        response = llm.completion(messages=messages)
        return response.choices[0].message.content.strip()

    def get_llm_from_agent_config(self, service_id: str, agent_config: AgentConfig):
        llm_config = self.config.get_llm_config_from_agent_config(agent_config)
        if service_id in self.service_to_llm:
            if self.service_to_llm[service_id].config != llm_config:
                # TODO: update llm config internally
                # Done when agent delegates has different config, we should reuse the existing LLM
                pass
            return self.service_to_llm[service_id]

        return self._create_new_llm(config=llm_config, service_id=service_id)

    def get_llm(
        self,
        service_id: str,
        config: LLMConfig | None = None,
    ):
        logger.info(
            f'[LLM registry {self.registry_id}]: Registering service for {service_id}'
        )

        # Attempting to switch configs for existing LLM
        if (
            service_id in self.service_to_llm
            and self.service_to_llm[service_id].config != config
        ):
            raise ValueError(
                f'Requesting same service ID {service_id} with different config, use a new service ID'
            )

        if service_id in self.service_to_llm:
            return self.service_to_llm[service_id]

        if not config:
            raise ValueError('Requesting new LLM without specifying LLM config')

        return self._create_new_llm(config=config, service_id=service_id)

    def get_active_llm(self) -> LLM:
        return self.active_agent_llm

    def _set_active_llm(self, service_id) -> None:
        if service_id not in self.service_to_llm:
            raise ValueError(f'Unrecognized service ID: {service_id}')
        self.active_agent_llm = self.service_to_llm[service_id]

    def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
        self.subscriber = callback

        # Subscriptions happen after default llm is initialized
        # Notify service of this llm
        self.notify(
            RegistryEvent(
                llm=self.active_agent_llm, service_id=self.active_agent_llm.service_id
            )
        )

    def notify(self, event: RegistryEvent):
        if self.subscriber:
            try:
                self.subscriber(event)
            except Exception as e:
                logger.warning(f'Failed to emit event: {e}')
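The registry keys every LLM by a service id ('agent', 'condenser', 'draft_editor_llm', and 'resolver' appear in this commit), hands back the same instance for repeated requests under the same id, refuses to silently swap configs under an existing id, and can notify a subscriber whenever a new LLM is created. A minimal usage sketch (the OpenHandsConfig and LLMConfig values are assumed to be built elsewhere; 'title_gen' is a made-up service id for illustration; every registry call shown appears in the new file above):

    from openhands.llm.llm_registry import LLMRegistry, RegistryEvent

    registry = LLMRegistry(config)  # creates and registers the 'agent' LLM immediately

    def on_new_llm(event: RegistryEvent) -> None:
        # called once for the active agent LLM on subscribe, then for each new service
        print(f'service {event.service_id} uses model {event.llm.config.model}')

    registry.subscribe(on_new_llm)

    agent_llm = registry.get_active_llm()
    condenser_llm = registry.get_llm('condenser', condenser_llm_config)
    same_llm = registry.get_llm('condenser', condenser_llm_config)  # reused, not rebuilt
    assert condenser_llm is same_llm
    # registry.get_llm('condenser', a_different_config) would raise ValueError instead

    # one-off completions (e.g. summaries or titles) go through a dedicated service id
    answer = registry.request_extraneous_completion(
        'title_gen',
        some_llm_config,
        [{'role': 'user', 'content': 'Summarize this conversation in five words.'}],
    )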
@@ -10,6 +10,7 @@ from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.view import View

CONDENSER_METADATA_KEY = 'condenser_meta'
@@ -144,7 +145,9 @@ class Condenser(ABC):
        CONDENSER_REGISTRY[configuration_type] = cls

    @classmethod
    def from_config(cls, config: CondenserConfig) -> Condenser:
    def from_config(
        cls, config: CondenserConfig, llm_registry: LLMRegistry
    ) -> Condenser:
        """Create a condenser from a configuration object.

        Args:
@@ -158,7 +161,7 @@ class Condenser(ABC):
        """
        try:
            condenser_class = CONDENSER_REGISTRY[type(config)]
            return condenser_class.from_config(config)
            return condenser_class.from_config(config, llm_registry)
        except KeyError:
            raise ValueError(f'Unknown condenser config: {config}')
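Every condenser's from_config now takes the registry, even when the condenser never talks to an LLM, so the base-class factory can pass it uniformly; the LLM-backed condensers below ask for the shared 'condenser' service instead of constructing their own LLM. A new condenser written against this interface would look roughly like this (KeepEverythingCondenser and its config class are hypothetical, and the config registration boilerplate is omitted):

    from __future__ import annotations

    from openhands.llm.llm_registry import LLMRegistry
    from openhands.memory.condenser.condenser import Condensation, Condenser, View

    class KeepEverythingCondenser(Condenser):  # hypothetical example
        def condense(self, view: View) -> View | Condensation:
            return view

        @classmethod
        def from_config(
            cls, config: 'KeepEverythingCondenserConfig', llm_registry: LLMRegistry
        ) -> KeepEverythingCondenser:
            # the registry is accepted for interface uniformity even though it is unused here
            return KeepEverythingCondenser()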
@@ -2,6 +2,7 @@ from __future__ import annotations

from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
@@ -58,7 +59,9 @@ class AmortizedForgettingCondenser(RollingCondenser):

    @classmethod
    def from_config(
        cls, config: AmortizedForgettingCondenserConfig
        cls,
        config: AmortizedForgettingCondenserConfig,
        llm_registry: LLMRegistry,
    ) -> AmortizedForgettingCondenser:
        return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'}))


@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
from openhands.events.event import Event
from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -40,7 +41,7 @@ class BrowserOutputCondenser(Condenser):

    @classmethod
    def from_config(
        cls, config: BrowserOutputCondenserConfig
        cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry
    ) -> BrowserOutputCondenser:
        return BrowserOutputCondenser(**config.model_dump(exclude={'type'}))


@@ -9,6 +9,7 @@ from openhands.events.action.agent import (
from openhands.events.action.message import MessageAction, SystemMessageAction
from openhands.events.event import EventSource
from openhands.events.observation import Observation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View


@@ -177,7 +178,9 @@ class ConversationWindowCondenser(RollingCondenser):

    @classmethod
    def from_config(
        cls, _config: ConversationWindowCondenserConfig
        cls,
        _config: ConversationWindowCondenserConfig,
        llm_registry: LLMRegistry,
    ) -> ConversationWindowCondenser:
        return ConversationWindowCondenser()


@@ -6,6 +6,7 @@ from pydantic import BaseModel
from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
@@ -22,7 +23,12 @@ class ImportantEventSelection(BaseModel):
class LLMAttentionCondenser(RollingCondenser):
    """Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""

    def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
    def __init__(
        self,
        llm: LLM,
        max_size: int = 100,
        keep_first: int = 1,
    ):
        if keep_first >= max_size // 2:
            raise ValueError(
                f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
@@ -113,15 +119,19 @@ class LLMAttentionCondenser(RollingCondenser):
        return len(view) > self.max_size

    @classmethod
    def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser:
    def from_config(
        cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry
    ) -> LLMAttentionCondenser:
        # This condenser cannot take advantage of prompt caching. If it happens
        # to be set, we'll pay for the cache writes but never get a chance to
        # save on a read.
        llm_config = config.llm_config.model_copy()
        llm_config.caching_prompt = False

        llm = llm_registry.get_llm('condenser', llm_config)

        return LLMAttentionCondenser(
            llm=LLM(config=llm_config),
            llm=llm,
            max_size=config.max_size,
            keep_first=config.keep_first,
        )

@@ -5,7 +5,8 @@ from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
from openhands.llm import LLM
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
@@ -154,16 +155,17 @@ CURRENT_STATE: Last flip: Heads, Haiku count: 15/20"""

    @classmethod
    def from_config(
        cls, config: LLMSummarizingCondenserConfig
        cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry
    ) -> LLMSummarizingCondenser:
        # This condenser cannot take advantage of prompt caching. If it happens
        # to be set, we'll pay for the cache writes but never get a chance to
        # save on a read.
        llm_config = config.llm_config.model_copy()
        llm_config.caching_prompt = False
        llm = llm_registry.get_llm('condenser', llm_config)

        return LLMSummarizingCondenser(
            llm=LLM(config=llm_config),
            llm=llm,
            max_size=config.max_size,
            keep_first=config.keep_first,
            max_event_length=config.max_event_length,
@@ -1,6 +1,7 @@
from __future__ import annotations

from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -12,7 +13,9 @@ class NoOpCondenser(Condenser):
        return view

    @classmethod
    def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
    def from_config(
        cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry
    ) -> NoOpCondenser:
        return NoOpCondenser()


@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserConfig
from openhands.events.event import Event
from openhands.events.observation import Observation
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -28,7 +29,9 @@ class ObservationMaskingCondenser(Condenser):

    @classmethod
    def from_config(
        cls, config: ObservationMaskingCondenserConfig
        cls,
        config: ObservationMaskingCondenserConfig,
        llm_registry: LLMRegistry,
    ) -> ObservationMaskingCondenser:
        return ObservationMaskingCondenser(**config.model_dump(exclude={'type'}))


@@ -4,6 +4,7 @@ from contextlib import contextmanager

from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserPipelineConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser
from openhands.memory.view import View

@@ -39,8 +40,10 @@ class CondenserPipeline(Condenser):
        return result

    @classmethod
    def from_config(cls, config: CondenserPipelineConfig) -> CondenserPipeline:
        condensers = [Condenser.from_config(c) for c in config.condensers]
    def from_config(
        cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry
    ) -> CondenserPipeline:
        condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers]
        return CondenserPipeline(*condensers)
@@ -1,6 +1,7 @@
from __future__ import annotations

from openhands.core.config.condenser_config import RecentEventsCondenserConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -21,7 +22,9 @@ class RecentEventsCondenser(Condenser):
        return View(events=head + tail)

    @classmethod
    def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser:
    def from_config(
        cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry
    ) -> RecentEventsCondenser:
        return RecentEventsCondenser(**config.model_dump(exclude={'type'}))


@@ -13,7 +13,8 @@ from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
from openhands.llm import LLM
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
@@ -180,15 +181,14 @@ class StructuredSummaryCondenser(RollingCondenser):
        if max_size < 1:
            raise ValueError(f'max_size ({max_size}) cannot be non-positive')

        if not llm.is_function_calling_active():
            raise ValueError(
                'LLM must support function calling to use StructuredSummaryCondenser'
            )

        self.max_size = max_size
        self.keep_first = keep_first
        self.max_event_length = max_event_length
        self.llm = llm
        if not self.llm.is_function_calling_active():
            raise ValueError(
                'LLM must support function calling to use StructuredSummaryCondenser'
            )

        super().__init__()

@@ -309,16 +309,17 @@ Capture all relevant information, especially:

    @classmethod
    def from_config(
        cls, config: StructuredSummaryCondenserConfig
        cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry
    ) -> StructuredSummaryCondenser:
        # This condenser cannot take advantage of prompt caching. If it happens
        # to be set, we'll pay for the cache writes but never get a chance to
        # save on a read.
        llm_config = config.llm_config.model_copy()
        llm_config.caching_prompt = False
        llm = llm_registry.get_llm('condenser', llm_config)

        return StructuredSummaryCondenser(
            llm=LLM(config=llm_config),
            llm=llm,
            max_size=config.max_size,
            keep_first=config.keep_first,
            max_event_length=config.max_event_length,
@@ -23,7 +23,7 @@ class ServiceContext:
    def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
        self._strategy = strategy
        if llm_config is not None:
            self.llm = LLM(llm_config)
            self.llm = LLM(llm_config, service_id='resolver')

    def set_strategy(self, strategy: IssueHandlerInterface) -> None:
        self._strategy = strategy

@@ -28,6 +28,7 @@ from openhands.events.observation import (
)
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.llm.llm_registry import LLMRegistry
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.interfaces.issue_definitions import (
    ServiceContextIssue,
@@ -412,7 +413,8 @@ class IssueResolver:
        shutil.rmtree(self.workspace_base)
        shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)

        runtime = create_runtime(self.app_config)
        llm_registry = LLMRegistry(self.app_config)
        runtime = create_runtime(self.app_config, llm_registry)
        await runtime.connect()

        def on_event(evt: Event) -> None:

@@ -463,7 +463,7 @@ def update_existing_pull_request(

    # Summarize with LLM if provided
    if llm_config is not None:
        llm = LLM(llm_config)
        llm = LLM(llm_config, service_id='resolver')
        with open(
            os.path.join(
                os.path.dirname(__file__),
@@ -54,6 +54,7 @@ from openhands.integrations.provider import (
    ProviderType,
)
from openhands.integrations.service_types import AuthenticationError
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent import (
    BaseMicroagent,
    load_microagents_from_dir,
@@ -125,6 +126,7 @@ class Runtime(FileEditRuntimeMixin):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -178,7 +180,9 @@ class Runtime(FileEditRuntimeMixin):

        # Load mixins
        FileEditRuntimeMixin.__init__(
            self, enable_llm_editor=config.get_agent_config().enable_llm_editor
            self,
            enable_llm_editor=config.get_agent_config().enable_llm_editor,
            llm_registry=llm_registry,
        )

        self.user_id = user_id
@@ -43,6 +43,7 @@ from openhands.events.observation import (
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
@@ -68,6 +69,7 @@ class ActionExecutionClient(Runtime):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -85,6 +87,7 @@ class ActionExecutionClient(Runtime):
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,

@@ -46,6 +46,7 @@ from openhands.events.observation import (
    Observation,
)
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime_status import RuntimeStatus
@@ -107,6 +108,7 @@ class CLIRuntime(Runtime):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -119,6 +121,7 @@ class CLIRuntime(Runtime):
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,

@@ -20,6 +20,7 @@ from openhands.core.logger import DEBUG, DEBUG_RUNTIME
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
@@ -90,6 +91,7 @@ class DockerRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -143,6 +145,7 @@ class DockerRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,

@@ -43,6 +43,7 @@ from openhands.core.logger import DEBUG
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
)
@@ -81,6 +82,7 @@ class KubernetesRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -137,6 +139,7 @@ class KubernetesRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,

@@ -26,6 +26,7 @@ from openhands.events.observation import (
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
)
@@ -135,6 +136,7 @@ class LocalRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -186,6 +188,7 @@ class LocalRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
            llm_registry,
            sid,
            plugins,
            env_vars,
@@ -801,12 +804,6 @@ def _create_warm_server_in_background(

def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]:
    from openhands.controller.agent import Agent
    from openhands.llm.llm import LLM

    agent_config = config.get_agent_config(config.default_agent)
    llm = LLM(
        config=config.get_llm_config_from_agent(config.default_agent),
    )
    agent = Agent.get_cls(config.default_agent)(llm, agent_config)
    plugins = agent.sandbox_plugins
    plugins = Agent.get_cls(config.default_agent).sandbox_plugins
    return plugins
@ -19,6 +19,7 @@ from openhands.core.exceptions import (
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
@ -51,6 +52,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self,
|
||||
config: OpenHandsConfig,
|
||||
event_stream: EventStream,
|
||||
llm_registry: LLMRegistry,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
@ -64,6 +66,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
llm_registry,
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
|
||||
@ -23,7 +23,7 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.linter import DefaultLinter
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
|
||||
|
||||
USER_MSG = """
|
||||
@ -128,7 +128,13 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
# This restricts the number of lines we can edit to avoid exceeding the token limit.
|
||||
MAX_LINES_TO_EDIT = 300
|
||||
|
||||
def __init__(self, enable_llm_editor: bool, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
enable_llm_editor: bool,
|
||||
llm_registry: LLMRegistry,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.enable_llm_editor = enable_llm_editor
|
||||
|
||||
@ -138,7 +144,6 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
draft_editor_config = self.config.get_llm_config('draft_editor')
|
||||
|
||||
# manually set the model name for the draft editor LLM to distinguish token costs
|
||||
llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model)
|
||||
if draft_editor_config.caching_prompt:
|
||||
logger.debug(
|
||||
'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
|
||||
@ -146,7 +151,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
)
|
||||
draft_editor_config.caching_prompt = False

self.draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics)
self.draft_editor_llm = llm_registry.get_llm(
'draft_editor_llm', draft_editor_config
)
logger.debug(
f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
)
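Usage sketch (not part of the diff; variable names are illustrative): the draft editor now obtains its LLM from the registry under the 'draft_editor_llm' service id, so the registry can track that service's Metrics alongside the agent's.

# Sketch, assuming llm_registry and draft_editor_config are in scope as above.
draft_editor_llm = llm_registry.get_llm('draft_editor_llm', draft_editor_config)
response = draft_editor_llm.completion(
    messages=[{'role': 'user', 'content': 'Propose an edit for this file.'}]
)
draft = response['choices'][0]['message']['content']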
@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
import socketio
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
@ -136,6 +137,16 @@ class ConversationManager(ABC):
|
||||
) -> list[AgentLoopInfo]:
|
||||
"""Get the AgentLoopInfo for conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def request_llm_completion(
|
||||
self,
|
||||
sid: str,
|
||||
service_id: str,
|
||||
llm_config: LLMConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
"""Request extraneous llm completions for a conversation"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_instance(
|
||||
|
||||
@ -15,13 +15,13 @@ from docker.models.containers import Container
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.nested_event_store import NestedEventStore
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
@ -42,6 +42,7 @@ from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_dir
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.utils import create_registry_and_convo_stats
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def request_llm_completion(
|
||||
self,
|
||||
sid: str,
|
||||
service_id: str,
|
||||
llm_config: LLMConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def send_event_to_conversation(self, sid, data):
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
# This session is created here only because it is the easiest way to get a runtime, which
|
||||
# is the easiest way to create the needed docker container
|
||||
|
||||
# Run experiment manager variant test before creating session
|
||||
config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test(
|
||||
user_id, sid, self.config
|
||||
)
|
||||
|
||||
llm_registry, convo_stats, config = create_registry_and_convo_stats(
|
||||
config, sid, user_id, settings
|
||||
)
|
||||
|
||||
session = Session(
|
||||
sid=sid,
|
||||
llm_registry=llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
file_store=self.file_store,
|
||||
config=config,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
llm_registry.retry_listner = session._notify_on_llm_retry
|
||||
agent_cls = settings.agent or config.default_agent
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
llm = LLM(
|
||||
config=config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=session._notify_on_llm_retry,
|
||||
)
|
||||
llm = session._create_llm(agent_cls)
|
||||
agent_config = config.get_agent_config(agent_cls)
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
agent = Agent.get_cls(agent_cls)(agent_config, llm_registry)
|
||||
|
||||
config = config.model_copy(deep=True)
|
||||
env_vars = config.sandbox.runtime_startup_env_vars
|
||||
@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager):
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
main_module='openhands.server',
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
# Hack - disable setting initial env.
|
||||
|
||||
@ -6,12 +6,14 @@ from typing import Callable, Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import (
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
from openhands.utils.utils import create_registry_and_convo_stats
|
||||
|
||||
from .conversation_manager import ConversationManager
|
||||
|
||||
@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
await self.close_session(oldest_conversation_id)
|
||||
|
||||
config = self.config.model_copy(deep=True)
|
||||
|
||||
llm_registry, convo_stats, config = create_registry_and_convo_stats(
|
||||
self.config, sid, user_id, settings
|
||||
)
|
||||
session = Session(
|
||||
sid=sid,
|
||||
file_store=self.file_store,
|
||||
config=config,
|
||||
llm_registry=llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, sid, settings),
|
||||
self._create_conversation_update_callback(
|
||||
user_id, sid, settings, session.llm_registry
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager):
|
||||
raise RuntimeError(f'no_conversation:{sid}')
await session.dispatch(data)

async def request_llm_completion(
self,
sid: str,
service_id: str,
llm_config: LLMConfig,
messages: list[dict[str, str]],
):
session = self._local_agent_loops_by_sid.get(sid)
if not session:
raise RuntimeError(f'no_conversation:{sid}')
llm_registry = session.llm_registry
return llm_registry.request_extraneous_completion(
service_id, llm_config, messages
)
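Usage sketch (not part of the diff; the service id and prompt are illustrative): a server-side caller can request a one-off completion that is routed through the conversation's LLMRegistry, so ConversationStats accounts for it under its own service id.

async def summarize_conversation(manager: StandaloneConversationManager, sid: str) -> str:
    # Billed to the conversation's registry rather than the agent's own LLM.
    return await manager.request_llm_completion(
        sid=sid,
        service_id='conversation_summary',
        llm_config=config.get_llm_config(),  # `config` assumed available, as elsewhere in this module
        messages=[{'role': 'user', 'content': 'Summarize the latest events.'}],
    )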
async def disconnect_from_session(self, connection_id: str):
|
||||
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
|
||||
logger.info(
|
||||
@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
llm_registry,
|
||||
event,
|
||||
)
|
||||
|
||||
@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
event=None,
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.title == default_title
|
||||
): # attempt to autogenerate if default title is in use
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, self.file_store, settings
|
||||
conversation_id, user_id, self.file_store, settings, llm_registry
|
||||
)
|
||||
if title and not title.isspace():
|
||||
conversation.title = title
|
||||
|
||||
0 openhands/server/conversation_manager/utils.py Normal file
@ -33,7 +33,6 @@ from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import (
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import (
|
||||
ConversationManagerImpl,
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
@ -364,7 +364,7 @@ async def get_prompt(
|
||||
)
|
||||
|
||||
prompt_template = generate_prompt_template(stringified_events)
|
||||
prompt = generate_prompt(llm_config, prompt_template)
|
||||
prompt = generate_prompt(llm_config, prompt_template, conversation_id)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str:
|
||||
return template.render(events=events)
|
||||
|
||||
|
||||
def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
|
||||
llm = LLM(llm_config)
|
||||
def generate_prompt(
|
||||
llm_config: LLMConfig, prompt_template: str, conversation_id: str
|
||||
) -> str:
|
||||
messages = [
|
||||
{
|
||||
'role': 'system',
|
||||
@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
|
||||
},
|
||||
]
|
||||
|
||||
response = llm.completion(messages=messages)
|
||||
raw_prompt = response['choices'][0]['message']['content'].strip()
|
||||
raw_prompt = ConversationManagerImpl.request_llm_completion(
|
||||
'remember_prompt', conversation_id, llm_config, messages
|
||||
)
|
||||
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
|
||||
|
||||
if prompt:
|
||||
|
||||
@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
|
||||
async def create_new_conversation(
|
||||
async def initialize_conversation(
|
||||
user_id: str | None,
|
||||
conversation_id: str | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
git_provider: ProviderType | None = None,
|
||||
) -> ConversationMetadata | None:
|
||||
if conversation_id is None:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
|
||||
if not await conversation_store.exists(conversation_id):
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
convo_metadata = ConversationMetadata(
|
||||
trigger=conversation_trigger,
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
git_provider=git_provider,
|
||||
)
|
||||
|
||||
await conversation_store.save_metadata(convo_metadata)
|
||||
return convo_metadata
|
||||
|
||||
try:
|
||||
convo_metadata = await conversation_store.get_metadata(conversation_id)
|
||||
return convo_metadata
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def start_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
replay_json: str | None,
|
||||
conversation_instructions: str | None = None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_conversation_id: bool = False,
|
||||
git_provider: ProviderType | None = None,
|
||||
conversation_id: str | None = None,
|
||||
conversation_id: str,
|
||||
convo_metadata: ConversationMetadata,
|
||||
conversation_instructions: str | None,
|
||||
mcp_config: MCPConfig | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
@ -52,7 +92,7 @@ async def create_new_conversation(
|
||||
extra={
|
||||
'signal': 'create_conversation',
|
||||
'user_id': user_id,
|
||||
'trigger': conversation_trigger.value,
|
||||
'trigger': convo_metadata.trigger,
|
||||
},
|
||||
)
|
||||
logger.info('Loading settings')
|
||||
@ -79,53 +119,25 @@ async def create_new_conversation(
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['selected_repository'] = convo_metadata.selected_repository
|
||||
session_init_args['custom_secrets'] = custom_secrets
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
session_init_args['git_provider'] = git_provider
|
||||
session_init_args['selected_branch'] = convo_metadata.selected_branch
|
||||
session_init_args['git_provider'] = convo_metadata.git_provider
|
||||
session_init_args['conversation_instructions'] = conversation_instructions
|
||||
if mcp_config:
|
||||
session_init_args['mcp_config'] = mcp_config
|
||||
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('ServerConversation store loaded')
|
||||
|
||||
# For nested runtimes, we allow a single conversation id, passed in on container creation
|
||||
if conversation_id is None:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
|
||||
if not await conversation_store.exists(conversation_id):
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_init_data
|
||||
)
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
trigger=conversation_trigger,
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
git_provider=git_provider,
|
||||
llm_model=conversation_init_data.llm_model,
|
||||
)
|
||||
)
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_init_data
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Starting agent loop for conversation {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
initial_message_action = None
|
||||
if initial_user_msg or image_urls:
|
||||
initial_message_action = MessageAction(
|
||||
@ -133,9 +145,6 @@ async def create_new_conversation(
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
|
||||
if attach_conversation_id:
|
||||
logger.warning('Attaching conversation ID is deprecated, skipping process')
|
||||
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
@ -147,6 +156,47 @@ async def create_new_conversation(
|
||||
return agent_loop_info
|
||||
|
||||
|
||||
async def create_new_conversation(
user_id: str | None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
selected_repository: str | None,
selected_branch: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
replay_json: str | None,
conversation_instructions: str | None = None,
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
git_provider: ProviderType | None = None,
conversation_id: str | None = None,
mcp_config: MCPConfig | None = None,
) -> AgentLoopInfo:
conversation_metadata = await initialize_conversation(
user_id,
conversation_id,
selected_repository,
selected_branch,
conversation_trigger,
git_provider,
)

if not conversation_metadata:
raise Exception('Failed to initialize conversation')

return await start_conversation(
user_id,
git_provider_tokens,
custom_secrets,
initial_user_msg,
image_urls,
replay_json,
conversation_metadata.conversation_id,
conversation_metadata,
conversation_instructions,
mcp_config,
)
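Sketch of the resulting flow (not part of the diff; argument values are placeholders): a caller that already holds a conversation id, such as a nested runtime, can persist metadata first and start the agent loop separately, mirroring the call above.

# Assumed to run inside an async handler.
metadata = await initialize_conversation(
    user_id=None,
    conversation_id='abc123',  # pre-assigned id; a fresh uuid4 hex is generated when None
    selected_repository=None,
    selected_branch=None,
)
if metadata is None:
    raise Exception('Failed to initialize conversation')
loop_info = await start_conversation(
    None,                      # user_id
    None,                      # git_provider_tokens
    None,                      # custom_secrets
    'Fix the failing test',    # initial_user_msg
    None,                      # image_urls
    None,                      # replay_json
    metadata.conversation_id,
    metadata,
    None,                      # conversation_instructions
)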
def create_provider_tokens_object(
|
||||
providers_set: list[ProviderType],
|
||||
) -> PROVIDER_TOKEN_TYPE:
|
||||
|
||||
77 openhands/server/services/conversation_stats.py Normal file
@ -0,0 +1,77 @@
import base64
import pickle
from threading import Lock

from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm_registry import RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_stats_filename


class ConversationStats:
def __init__(
self,
file_store: FileStore | None,
conversation_id: str,
user_id: str | None,
):
self.metrics_path = get_conversation_stats_filename(conversation_id, user_id)
self.file_store = file_store
self.conversation_id = conversation_id
self.user_id = user_id

self._save_lock = Lock()

self.service_to_metrics: dict[str, Metrics] = {}
self.restored_metrics: dict[str, Metrics] = {}

# Always attempt to restore registry if it exists
self.maybe_restore_metrics()

def save_metrics(self):
if not self.file_store:
return

with self._save_lock:
pickled = pickle.dumps(self.service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
self.file_store.write(self.metrics_path, serialized_metrics)

def maybe_restore_metrics(self):
if not self.file_store or not self.conversation_id:
return

try:
encoded = self.file_store.read(self.metrics_path)
pickled = base64.b64decode(encoded)
self.restored_metrics = pickle.loads(pickled)
logger.info(f'restored metrics: {self.conversation_id}')
except FileNotFoundError:
pass

def get_combined_metrics(self) -> Metrics:
total_metrics = Metrics()
for metrics in self.service_to_metrics.values():
total_metrics.merge(metrics)

logger.info(f'metrics by all services: {self.service_to_metrics}')
logger.info(f'combined metrics\n\n{total_metrics}')
return total_metrics

def get_metrics_for_service(self, service_id: str) -> Metrics:
if service_id not in self.service_to_metrics:
raise Exception(f'LLM service does not exist {service_id}')

return self.service_to_metrics[service_id]

def register_llm(self, event: RegistryEvent):
# Listen for llm creations and track their metrics
llm = event.llm
service_id = event.service_id

if service_id in self.restored_metrics:
llm.metrics = self.restored_metrics[service_id].copy()
del self.restored_metrics[service_id]

self.service_to_metrics[service_id] = llm.metrics
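Usage sketch (not part of the diff; the ids are placeholders): a ConversationStats instance subscribes to an LLMRegistry so that every LLM the registry creates reports into one per-conversation store, which can be persisted and later restored.

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store

config = OpenHandsConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
convo_stats = ConversationStats(file_store, conversation_id='abc123', user_id=None)
registry = LLMRegistry(config)
registry.subscribe(convo_stats.register_llm)  # RegistryEvent -> register_llm

# ... after some completions have run ...
combined = convo_stats.get_combined_metrics()  # merged Metrics across all service ids
convo_stats.save_metrics()                     # persisted for restore on the next connect

The create_registry_and_convo_stats helper added in openhands/utils/utils.py below wraps exactly this wiring.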
@ -21,6 +21,7 @@ from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
)
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroagent
|
||||
@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
|
||||
@ -48,6 +50,7 @@ class AgentSession:
|
||||
sid: str
|
||||
user_id: str | None
|
||||
event_stream: EventStream
|
||||
llm_registry: LLMRegistry
|
||||
file_store: FileStore
|
||||
controller: AgentController | None = None
|
||||
runtime: Runtime | None = None
|
||||
@ -63,6 +66,8 @@ class AgentSession:
|
||||
self,
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
llm_registry: LLMRegistry,
|
||||
convo_stats: ConversationStats,
|
||||
status_callback: Callable | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
@ -80,6 +85,8 @@ class AgentSession:
|
||||
self.logger = OpenHandsLoggerAdapter(
|
||||
extra={'session_id': sid, 'user_id': user_id}
|
||||
)
|
||||
self.llm_registry = llm_registry
|
||||
self.convo_stats = convo_stats
|
||||
|
||||
async def start(
|
||||
self,
|
||||
@ -340,6 +347,7 @@ class AgentSession:
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
llm_registry=self.llm_registry,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
@ -360,6 +368,7 @@ class AgentSession:
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
llm_registry=self.llm_registry,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
@ -441,6 +450,7 @@ class AgentSession:
|
||||
user_id=self.user_id,
|
||||
file_store=self.file_store,
|
||||
event_stream=self.event_stream,
|
||||
convo_stats=self.convo_stats,
|
||||
agent=agent,
|
||||
iteration_delta=int(max_iterations),
|
||||
budget_per_task_delta=max_budget_per_task,
|
||||
@ -490,6 +500,15 @@ class AgentSession:
|
||||
)
|
||||
return memory
|
||||
|
||||
def get_state(self) -> AgentState | None:
|
||||
controller = self.controller
|
||||
if controller:
|
||||
return controller.state.agent_state
|
||||
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
|
||||
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
|
||||
return AgentState.ERROR
|
||||
return None
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
"""Helper method to handle state restore logic."""
|
||||
restored_state = None
|
||||
@ -510,14 +529,5 @@ class AgentSession:
|
||||
self.logger.debug('No events found, no state to restore')
|
||||
return restored_state
|
||||
|
||||
def get_state(self) -> AgentState | None:
|
||||
controller = self.controller
|
||||
if controller:
|
||||
return controller.state.agent_state
|
||||
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
|
||||
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
|
||||
return AgentState.ERROR
|
||||
return None
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
@ -45,6 +46,7 @@ class ServerConversation:
|
||||
else:
|
||||
runtime_cls = get_runtime_cls(self.config.runtime)
|
||||
runtime = runtime_cls(
|
||||
llm_registry=LLMRegistry(self.config),
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from logging import LoggerAdapter
|
||||
|
||||
import socketio
|
||||
@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@ -45,6 +45,7 @@ class Session:
|
||||
agent_session: AgentSession
|
||||
loop: asyncio.AbstractEventLoop
|
||||
config: OpenHandsConfig
|
||||
llm_registry: LLMRegistry
|
||||
file_store: FileStore
|
||||
user_id: str | None
|
||||
logger: LoggerAdapter
|
||||
@ -53,6 +54,8 @@ class Session:
|
||||
self,
|
||||
sid: str,
|
||||
config: OpenHandsConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
convo_stats: ConversationStats,
|
||||
file_store: FileStore,
|
||||
sio: socketio.AsyncServer | None,
|
||||
user_id: str | None = None,
|
||||
@ -62,17 +65,21 @@ class Session:
|
||||
self.last_active_ts = int(time.time())
|
||||
self.file_store = file_store
|
||||
self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid})
|
||||
self.llm_registry = llm_registry
|
||||
self.convo_stats = convo_stats
|
||||
self.agent_session = AgentSession(
|
||||
sid,
|
||||
file_store,
|
||||
llm_registry=self.llm_registry,
|
||||
convo_stats=convo_stats,
|
||||
status_callback=self.queue_status_message,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
)
|
||||
# Copying this means that when we update variables they are not applied to the shared global configuration!
|
||||
self.config = deepcopy(config)
|
||||
self.config = config
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
|
||||
@ -140,13 +147,6 @@ class Session:
|
||||
else self.config.max_budget_per_task
|
||||
)
|
||||
|
||||
# This is a shallow copy of the default LLM config, so changes here will
|
||||
# persist if we retrieve the default LLM config again when constructing
|
||||
# the agent
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = settings.llm_model or ''
|
||||
default_llm_config.api_key = settings.llm_api_key
|
||||
default_llm_config.base_url = settings.llm_base_url
|
||||
self.config.search_api_key = settings.search_api_key
|
||||
if settings.sandbox_api_key:
|
||||
self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
|
||||
@ -181,10 +181,9 @@ class Session:
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = self._create_llm(agent_cls)
|
||||
agent_config = self.config.get_agent_config(agent_cls)
|
||||
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
llm_config = self.config.get_llm_config_from_agent(agent_name)
|
||||
if settings.enable_default_condenser:
|
||||
# Default condenser chains three condensers together:
|
||||
# 1. a conversation window condenser that handles explicit
|
||||
@ -200,7 +199,7 @@ class Session:
|
||||
ConversationWindowCondenserConfig(),
|
||||
BrowserOutputCondenserConfig(attention_window=2),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=4, max_size=120
|
||||
llm_config=llm_config, keep_first=4, max_size=120
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -208,12 +207,14 @@ class Session:
|
||||
self.logger.info(
|
||||
f'Enabling pipeline condenser with:'
|
||||
f' browser_output_masking(attention_window=2), '
|
||||
f' llm(model="{llm.config.model}", '
|
||||
f' base_url="{llm.config.base_url}", '
|
||||
f' llm(model="{llm_config.model}", '
|
||||
f' base_url="{llm_config.base_url}", '
|
||||
f' keep_first=4, max_size=80)'
|
||||
)
|
||||
agent_config.condenser = default_condenser_config
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry)
|
||||
|
||||
self.llm_registry.retry_listner = self._notify_on_llm_retry
|
||||
|
||||
git_provider_tokens = None
|
||||
selected_repository = None
|
||||
@ -269,14 +270,6 @@ class Session:
|
||||
)
|
||||
return
|
||||
|
||||
def _create_llm(self, agent_cls: str | None) -> LLM:
|
||||
"""Initialize LLM, extracted for testing."""
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
return LLM(
|
||||
config=self.config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=self._notify_on_llm_retry,
|
||||
)
|
||||
|
||||
def _notify_on_llm_retry(self, retries: int, max: int) -> None:
|
||||
self.queue_status_message(
|
||||
'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}'
|
||||
|
||||
@ -30,5 +30,13 @@ def get_conversation_agent_state_filename(sid: str, user_id: str | None = None)
|
||||
return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'
|
||||
|
||||
|
||||
def get_conversation_llm_registry_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}llm_registry.json'
|
||||
|
||||
|
||||
def get_conversation_stats_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}convo_stats.json'
|
||||
|
||||
|
||||
def get_experiment_config_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}exp_config.json'
|
||||
|
||||
@ -7,13 +7,16 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
async def generate_conversation_title(
|
||||
message: str, llm_config: LLMConfig, max_length: int = 50
|
||||
message: str,
|
||||
llm_config: LLMConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
max_length: int = 50,
|
||||
) -> Optional[str]:
|
||||
"""Generate a concise title for a conversation based on the first user message.
|
||||
|
||||
@ -35,8 +38,6 @@ async def generate_conversation_title(
|
||||
truncated_message = message
|
||||
|
||||
try:
|
||||
llm = LLM(llm_config)
|
||||
|
||||
# Create a simple prompt for the LLM to generate a title
|
||||
messages = [
|
||||
{
|
||||
@ -49,8 +50,9 @@ async def generate_conversation_title(
|
||||
},
|
||||
]
|
||||
|
||||
response = llm.completion(messages=messages)
|
||||
title = response.choices[0].message.content.strip()
|
||||
title = llm_registry.request_extraneous_completion(
|
||||
'convo_title_creator', llm_config, messages
|
||||
)
|
||||
|
||||
# Ensure the title isn't too long
|
||||
if len(title) > max_length:
|
||||
@ -75,7 +77,11 @@ def get_default_conversation_title(conversation_id: str) -> str:
|
||||
|
||||
|
||||
async def auto_generate_title(
|
||||
conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
|
||||
conversation_id: str,
|
||||
user_id: str | None,
|
||||
file_store: FileStore,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> str:
|
||||
"""Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
@ -116,7 +122,7 @@ async def auto_generate_title(
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
first_user_message, llm_config, llm_registry
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
|
||||
37 openhands/utils/utils.py Normal file
@ -0,0 +1,37 @@
from copy import deepcopy

from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings


def setup_llm_config(config: OpenHandsConfig, settings: Settings) -> OpenHandsConfig:
# Copying this means that when we update variables they are not applied to the shared global configuration!
config = deepcopy(config)

llm_config = config.get_llm_config()
llm_config.model = settings.llm_model or ''
llm_config.api_key = settings.llm_api_key
llm_config.base_url = settings.llm_base_url
config.set_llm_config(llm_config)
return config


def create_registry_and_convo_stats(
config: OpenHandsConfig,
sid: str,
user_id: str | None,
user_settings: Settings | None = None,
) -> tuple[LLMRegistry, ConversationStats, OpenHandsConfig]:
user_config = config
if user_settings:
user_config = setup_llm_config(config, user_settings)

agent_cls = user_settings.agent if user_settings else None
llm_registry = LLMRegistry(user_config, agent_cls)
file_store = get_file_store(user_config.file_store, user_config.file_store_path)
convo_stats = ConversationStats(file_store, sid, user_id)
llm_registry.subscribe(convo_stats.register_llm)
return llm_registry, convo_stats, user_config
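Illustrative caller (not part of the diff; sid and settings values are placeholders), mirroring how the conversation managers above wire a Session:

llm_registry, convo_stats, user_config = create_registry_and_convo_stats(
    config, sid='abc123', user_id=None, user_settings=settings  # `config` and `settings` assumed in scope
)
session = Session(
    sid='abc123',
    llm_registry=llm_registry,
    convo_stats=convo_stats,
    file_store=file_store,  # assumed FileStore instance
    config=user_config,
    sio=None,
    user_id=None,
)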
@ -1,3 +1,4 @@
|
||||
[pytest]
|
||||
addopts = -p no:warnings
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
|
||||
0 tests/__init__.py Normal file
@ -10,6 +10,7 @@ from pytest import TempPathFactory
|
||||
from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
|
||||
@ -268,9 +269,13 @@ def _load_runtime(
|
||||
)
|
||||
event_stream = EventStream(sid, file_store)

# Create a LLMRegistry instance for the runtime
llm_registry = LLMRegistry(config=OpenHandsConfig())

runtime = runtime_cls(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid=sid,
plugins=plugins,
)
|
||||
|
||||
0 tests/unit/__init__.py Normal file
@ -20,7 +20,7 @@ def test_llm():
|
||||
|
||||
def _get_llm(type_: type[LLM]):
|
||||
with _patch_http():
|
||||
return type_(config=config.get_llm_config())
|
||||
return type_(config=config.get_llm_config(), service_id='test_service')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -38,7 +38,7 @@ def default_config():
|
||||
|
||||
|
||||
def test_llm_init_with_default_config(default_config):
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
assert llm.config.model == 'gpt-4o'
|
||||
assert llm.config.api_key.get_secret_value() == 'test_key'
|
||||
assert isinstance(llm.metrics, Metrics)
|
||||
@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
|
||||
'max_input_tokens': 8000,
|
||||
'max_output_tokens': 2000,
|
||||
}
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens == 8000
|
||||
assert llm.config.max_output_tokens == 2000
|
||||
@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
|
||||
@patch('openhands.llm.llm.litellm.get_model_info')
|
||||
def test_llm_init_without_model_info(mock_get_model_info, default_config):
|
||||
mock_get_model_info.side_effect = Exception('Model info not available')
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens is None
|
||||
assert llm.config.max_output_tokens is None
|
||||
@ -154,7 +154,7 @@ def test_llm_init_with_custom_config():
|
||||
top_p=0.9,
|
||||
top_k=None,
|
||||
)
|
||||
llm = LLM(custom_config)
|
||||
llm = LLM(custom_config, service_id='test-service')
|
||||
assert llm.config.model == 'custom-model'
|
||||
assert llm.config.api_key.get_secret_value() == 'custom_key'
|
||||
assert llm.config.max_input_tokens == 5000
|
||||
@ -168,7 +168,7 @@ def test_llm_init_with_custom_config():
|
||||
def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
|
||||
# Create a config with top_k set
|
||||
config_with_top_k = LLMConfig(top_k=50)
|
||||
llm = LLM(config_with_top_k)
|
||||
llm = LLM(config_with_top_k, service_id='test-service')
|
||||
|
||||
# Define a side effect function to check top_k
|
||||
def side_effect(*args, **kwargs):
|
||||
@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
|
||||
def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
|
||||
# Create a config with top_k set to None
|
||||
config_without_top_k = LLMConfig(top_k=None)
|
||||
llm = LLM(config_without_top_k)
|
||||
llm = LLM(config_without_top_k, service_id='test-service')
|
||||
|
||||
# Define a side effect function to check top_k
|
||||
def side_effect(*args, **kwargs):
|
||||
@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
|
||||
def test_llm_init_with_metrics():
|
||||
config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
metrics = Metrics()
|
||||
llm = LLM(config, metrics=metrics)
|
||||
llm = LLM(config, metrics=metrics, service_id='test-service')
|
||||
assert llm.metrics is metrics
|
||||
assert (
|
||||
llm.metrics.model_name == 'default'
|
||||
@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):
|
||||
|
||||
# Create LLM instance and make a completion call
|
||||
config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
|
||||
|
||||
# Verify the response latency was tracked correctly
|
||||
@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
|
||||
'max_input_tokens': 7000,
|
||||
'max_output_tokens': 1500,
|
||||
}
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
llm.init_model_info()
|
||||
assert llm.config.max_input_tokens == 7000
|
||||
assert llm.config.max_output_tokens == 1500
|
||||
@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
|
||||
default_config.model = (
|
||||
'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS
|
||||
)
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
tools=[
|
||||
@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
|
||||
|
||||
# Test with Grok-4 model that doesn't support stop parameter
|
||||
default_config.model = 'xai/grok-4-0709'
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
tools=[
|
||||
@ -314,7 +314,7 @@ def test_completion_with_mocked_logger(
|
||||
'choices': [{'message': {'content': 'Test response'}}]
|
||||
}
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -345,7 +345,7 @@ def test_completion_retries(
|
||||
{'choices': [{'message': {'content': 'Retry successful'}}]},
|
||||
]
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
|
||||
{'choices': [{'message': {'content': 'Retry successful'}}]},
|
||||
]
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
|
||||
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
|
||||
mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
with pytest.raises(OperationCancelled):
|
||||
llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
|
||||
|
||||
mock_litellm_completion.side_effect = side_effect
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
with pytest.raises(OperationCancelled):
|
||||
try:
|
||||
llm.completion(
|
||||
@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_
|
||||
|
||||
mock_litellm_completion.side_effect = side_effect
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
result = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp(
|
||||
mock_litellm_completion.side_effect = side_effect
|
||||
|
||||
# Create LLM instance and make a completion call
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp(
|
||||
'LLM did not return a response'
|
||||
)
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
with pytest.raises(LLMNoResponseError):
|
||||
llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info):
|
||||
|
||||
for model_name, expected_support in test_cases:
|
||||
config = LLMConfig(model=model_name, api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
assert llm.is_function_calling_active() == expected_support, (
|
||||
f'Expected function calling support to be {expected_support} for model {model_name}'
|
||||
@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret
|
||||
mock_litellm_completion.side_effect = side_effect
|
||||
|
||||
# Create LLM instance and make a completion call with non-zero temperature
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry(
|
||||
mock_litellm_completion.side_effect = side_effect
|
||||
|
||||
# Create LLM instance and make a completion call with explicit temperature=0
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
|
||||
}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
|
||||
test_llm = LLM(config=default_config)
|
||||
test_llm = LLM(config=default_config, service_id='test-service')
|
||||
response = test_llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
|
||||
}
|
||||
|
||||
# Initialize LLM and call completion
|
||||
llm = LLM(config=gemini_config)
|
||||
llm = LLM(config=gemini_config, service_id='test-service')
|
||||
llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
|
||||
|
||||
# Verify that litellm_completion was called with the 'thinking' parameter
|
||||
@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
|
||||
@patch('openhands.llm.llm.litellm.token_counter')
|
||||
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
|
||||
mock_token_counter.return_value = 42
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
messages = [{'role': 'user', 'content': 'Hello!'}]
|
||||
|
||||
token_count = llm.get_token_count(messages)
|
||||
@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
|
||||
def test_get_token_count_with_message_objects(
|
||||
mock_token_counter, default_config, mock_logger
|
||||
):
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
|
||||
# Create a Message object and its equivalent dict
|
||||
message_obj = Message(role='user', content=[TextContent(text='Hello!')])
|
||||
@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer(
|
||||
|
||||
config = copy.deepcopy(default_config)
|
||||
config.custom_tokenizer = 'custom/tokenizer'
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
messages = [{'role': 'user', 'content': 'Hello!'}]
|
||||
|
||||
token_count = llm.get_token_count(messages)
|
||||
@ -823,7 +823,7 @@ def test_get_token_count_error_handling(
|
||||
mock_token_counter, default_config, mock_logger
|
||||
):
|
||||
mock_token_counter.side_effect = Exception('Token counting failed')
|
||||
llm = LLM(default_config)
|
||||
llm = LLM(default_config, service_id='test-service')
|
||||
messages = [{'role': 'user', 'content': 'Hello!'}]
|
||||
|
||||
token_count = llm.get_token_count(messages)
|
||||
@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
|
||||
# We'll make mock_litellm_completion return these responses in sequence
|
||||
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
|
||||
# First call
|
||||
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
|
||||
@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config):
|
||||
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
|
||||
|
||||
# Create LLM instance
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
|
||||
# First call
|
||||
llm.completion(messages=[{'role': 'user', 'content': 'First message'}])
|
||||
@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config
|
||||
}
|
||||
mock_litellm_completion.return_value = mock_response
|
||||
|
||||
test_llm = LLM(config=default_config)
|
||||
test_llm = LLM(config=default_config, service_id='test-service')
|
||||
response = test_llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get):
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.return_value = {'model': 'fake'}
|
||||
|
||||
llm = LLM(config=config)
|
||||
llm = LLM(config=config, service_id='test-service')
|
||||
llm.init_model_info()
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits():
|
||||
"""Test that models without known token limits get None for both max_output_tokens and max_input_tokens."""
|
||||
# Create LLM instance with a non-existent model to avoid litellm having model info for it
|
||||
config = LLMConfig(model='non-existent-model', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# Verify max_output_tokens and max_input_tokens are initialized to None (default value)
|
||||
assert llm.config.max_output_tokens is None
|
||||
@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info():
|
||||
"""Test that max_output_tokens and max_input_tokens are correctly initialized from model info."""
|
||||
# Create LLM instance with GPT-4 model which has known token limits
|
||||
config = LLMConfig(model='gpt-4', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# GPT-4 has specific token limits
|
||||
# These are the expected values from litellm
|
||||
@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens():
|
||||
"""Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens."""
|
||||
# Create LLM instance with Claude 3.7 Sonnet model
|
||||
config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet
|
||||
assert llm.config.max_output_tokens == 64000
|
||||
@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens():
|
||||
"""Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
|
||||
# Create LLM instance with a Claude Sonnet 4 model
|
||||
config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# Verify max_output_tokens is set to the expected value
|
||||
assert llm.config.max_output_tokens == 64000
|
||||
@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens():
|
||||
"""Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value."""
|
||||
# Create LLM instance with SambaNova DeepSeek model
|
||||
config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key')
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# SambaNova DeepSeek model has specific token limits
|
||||
# This is the expected value from litellm
|
||||
@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config():
|
||||
config = LLMConfig(
|
||||
model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048
|
||||
)
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
|
||||
# Verify the config has the overridden max_output_tokens value
|
||||
assert llm.config.max_output_tokens == 2048
|
||||
@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens():
|
||||
)
|
||||
|
||||
# Create LLM instance with Azure model
|
||||
llm = LLM(azure_config)
|
||||
llm = LLM(azure_config, service_id='test-service')
|
||||
|
||||
# Verify the config has the default max_output_tokens value
|
||||
assert llm.config.max_output_tokens is None # Default value
|
||||
@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion):
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||
}
|
||||
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||
llm.completion(messages=sample_messages)
|
||||
|
||||
@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion):
|
||||
assert config.reasoning_effort is None
|
||||
|
||||
# Create LLM and make completion
|
||||
llm = LLM(config)
|
||||
llm = LLM(config, service_id='test-service')
|
||||
messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
|
||||
|
||||
response = llm.completion(messages=messages)
|
||||
|
||||
@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
|
||||
),
|
||||
]
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
handler = ServiceContextIssue(
|
||||
GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
|
||||
)
|
||||
@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
|
||||
)
|
||||
|
||||
# Initialize LLM and handler
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
handler = ServiceContextPR(
|
||||
GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config
|
||||
)
|
||||
|
||||
@ -463,7 +463,7 @@ async def test_process_issue(
|
||||
[],
|
||||
)
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
handler_instance.llm = LLM(llm_config)
|
||||
handler_instance.llm = LLM(llm_config, service_id='test-service')
|
||||
|
||||
# Mock the runtime and its methods
|
||||
mock_runtime = MagicMock()
|
||||
|
||||
@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
|
||||
),
|
||||
]
|
||||
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
handler = ServiceContextIssue(
|
||||
GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
|
||||
)
|
||||
@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
|
||||
)
|
||||
|
||||
# Initialize LLM and handler
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
handler = ServiceContextPR(
|
||||
GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config
|
||||
)
|
||||
|
||||
@ -500,7 +500,7 @@ async def test_process_issue(
|
||||
[],
|
||||
)
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
handler_instance.llm = LLM(llm_config)
|
||||
handler_instance.llm = LLM(llm_config, service_id='test-service')
|
||||
|
||||
# Create mock runtime and mock run_controller
|
||||
mock_runtime = MagicMock()
|
||||
|
||||
@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import (
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.memory.condenser.condenser import Condensation
|
||||
from openhands.memory.condenser.impl.conversation_window_condenser import (
|
||||
@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@ -61,15 +64,43 @@ def event_loop():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = OpenHandsConfig().get_llm_config()
|
||||
def mock_agent_with_stats():
|
||||
"""Create a mock agent with properly connected LLM registry and conversation stats."""
|
||||
import uuid
|
||||
|
||||
# Add config with enable_mcp attribute
|
||||
agent.config = MagicMock(spec=AgentConfig)
|
||||
agent.config.enable_mcp = True
|
||||
# Create LLM registry
|
||||
config = OpenHandsConfig()
|
||||
llm_registry = LLMRegistry(config=config)
|
||||
|
||||
# Create conversation stats
|
||||
file_store = InMemoryFileStore({})
|
||||
conversation_id = f'test-conversation-{uuid.uuid4()}'
|
||||
conversation_stats = ConversationStats(
|
||||
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
|
||||
)
|
||||
|
||||
# Connect registry to stats (this is the key requirement)
|
||||
llm_registry.subscribe(conversation_stats.register_llm)
|
||||
|
||||
# Create mock agent
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
llm_config = LLMConfig(
|
||||
model='gpt-4o',
|
||||
api_key='test_key',
|
||||
num_retries=2,
|
||||
retry_min_wait=1,
|
||||
retry_max_wait=2,
|
||||
)
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
llm_registry.service_to_llm.clear()
|
||||
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
|
||||
agent.llm = mock_llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
# Add a proper system message mock
|
||||
system_message = SystemMessageAction(
|
||||
@ -79,7 +110,7 @@ def mock_agent():
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
return agent, conversation_stats, llm_registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
async def test_set_agent_state(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
|
||||
async def test_on_event_change_agent_state_action(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
|
||||
async def test_react_to_exception(
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_to_content_policy_violation(
|
||||
mock_agent, mock_event_stream, mock_status_callback
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
"""Test that the controller properly handles content policy violations from the LLM."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
@ -284,15 +336,17 @@ async def test_run_controller_with_fatal_error(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
print(f'state: {state}')
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'event_stream: {events}')
|
||||
@ -312,18 +366,16 @@ async def test_run_controller_with_fatal_error(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
@ -352,15 +404,17 @@ async def test_run_controller_stop_with_stuck(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'state: {state}')
|
||||
for i, event in enumerate(events):
|
||||
@ -391,11 +445,14 @@ async def test_run_controller_stop_with_stuck(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream):
|
||||
# Test with headless_mode=False - should extend max_iterations
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
async def test_step_max_budget(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
sid='test',
|
||||
@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
sid='test',
|
||||
@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream):
|
||||
"""Test that when a user continues after hitting the budget limit:
|
||||
1. Error is thrown when budget cap is exceeded
|
||||
2. LLM budget does not reset when user continues
|
||||
3. Budget is extended by adding the initial budget cap to the current accumulated cost
|
||||
"""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create a real Metrics instance shared between controller state and llm
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 6.0
|
||||
@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
),
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
# Create controller with budget cap
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=initial_budget,
|
||||
sid='test',
|
||||
@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
|
||||
async def test_reset_with_pending_action_no_observation(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action with tool call metadata but no observation."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream):
|
||||
async def test_reset_with_pending_action_stopped_state(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action and agent state is STOPPED."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_existing_observation(
|
||||
mock_agent, mock_event_stream
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action with tool call metadata and an existing observation."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream):
|
||||
"""Test reset() when there's no pending action."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_no_metadata(
|
||||
mock_agent, mock_event_stream, monkeypatch
|
||||
mock_agent_with_stats, mock_event_stream, monkeypatch
|
||||
):
|
||||
"""Test reset() when there's a pending action without tool call metadata."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_max_iterations_has_metrics(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig(
|
||||
max_iterations=3,
|
||||
)
|
||||
event_stream = test_event_stream
|
||||
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
step_count = 0
|
||||
|
||||
@ -833,15 +929,17 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
state.metrics = mock_agent.llm.metrics
|
||||
assert state.iteration_flag.current_value == 3
|
||||
@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
|
||||
async def test_notify_on_llm_retry(
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
|
||||
],
|
||||
)
|
||||
async def test_context_window_exceeded_error_handling(
|
||||
context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory
|
||||
context_window_error,
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
test_event_stream,
|
||||
mock_memory,
|
||||
):
|
||||
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
max_iterations = 5
|
||||
error_after = 2
|
||||
|
||||
@ -973,18 +1084,20 @@ async def test_context_window_exceeded_error_handling(
|
||||
# state is set to error out before then, if this terminates and we have a
|
||||
# record of the error being thrown we can be confident that the controller
|
||||
# handles the truncation correctly.
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Check that the context window exception was thrown and the controller
|
||||
# called the agent's `step` function the right number of times.
|
||||
@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
mock_memory,
|
||||
test_event_stream,
|
||||
):
|
||||
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
@ -1121,18 +1238,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# A timeout error indicates the run_controller entrypoint is not making
|
||||
# progress
|
||||
@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
mock_memory,
|
||||
test_event_stream,
|
||||
):
|
||||
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
@ -1199,18 +1322,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
config = OpenHandsConfig(max_iterations=3)
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# A timeout error indicates the run_controller entrypoint is not making
|
||||
# progress
|
||||
@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
async def test_run_controller_with_memory_error(
|
||||
test_event_stream, mock_agent_with_stats
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
config = OpenHandsConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
@ -1273,15 +1402,17 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
with patch.object(
|
||||
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
|
||||
):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
assert state.iteration_flag.current_value == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_metrics_copy(mock_agent):
|
||||
async def test_action_metrics_copy(mock_agent_with_stats):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
initial_state = State(metrics=metrics, budget_flag=None)
|
||||
|
||||
# Create agent with metrics
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
# Update agent's LLM metrics
|
||||
|
||||
# Add multiple token usages - we should get the last one in the action
|
||||
usage1 = TokenUsage(
|
||||
@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
mock_agent.llm.metrics = metrics
|
||||
|
||||
# Register the metrics with the LLM registry
|
||||
llm_registry.service_to_llm['agent'] = mock_agent.llm
|
||||
# Manually notify the conversation stats about the LLM registration
|
||||
llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))
|
||||
|
||||
# Mock agent step to return an action
|
||||
action = MessageAction(content='Test message')
|
||||
|
||||
@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -1411,12 +1549,13 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream):
|
||||
"""Test that metrics from the condenser's LLM are included in the action metrics."""
|
||||
# Set up agent metrics
|
||||
agent_metrics = Metrics(model_name='agent-model')
|
||||
agent_metrics.accumulated_cost = 0.05
|
||||
agent_metrics._accumulated_token_usage = TokenUsage(
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Set up agent metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = 0.05
|
||||
mock_agent.llm.metrics._accumulated_token_usage = TokenUsage(
|
||||
model='agent-model',
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
cache_write_tokens=10,
|
||||
response_id='agent-accumulated',
|
||||
)
|
||||
# mock_agent.llm.metrics = agent_metrics
|
||||
mock_agent.name = 'TestAgent'
|
||||
|
||||
# Create condenser with its own metrics
|
||||
@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
)
|
||||
condenser.llm.metrics = condenser_metrics
|
||||
|
||||
# Register the condenser metrics with the LLM registry
|
||||
llm_registry.service_to_llm['condenser'] = condenser.llm
|
||||
# Manually notify the conversation stats about the condenser LLM registration
|
||||
llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser'))
|
||||
|
||||
# Attach the condenser to the mock_agent
|
||||
mock_agent.condenser = condenser
|
||||
|
||||
@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=State(metrics=agent_metrics, budget_flag=None),
|
||||
initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None),
|
||||
)
|
||||
|
||||
# Execute one step
|
||||
@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
|
||||
async def test_first_user_message_with_identical_content(
|
||||
test_event_stream, mock_agent_with_stats
|
||||
):
|
||||
"""Test that _first_user_message correctly identifies the first user message.
|
||||
|
||||
This test verifies that messages with identical content but different IDs are properly
|
||||
@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
"""
|
||||
# Create an agent controller
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -1569,11 +1713,15 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_controller_processes_null_observation_with_cause():
|
||||
async def test_agent_controller_processes_null_observation_with_cause(
|
||||
mock_agent_with_stats,
|
||||
):
|
||||
"""Test that AgentController processes NullObservation events with a cause value.
|
||||
|
||||
And that the agent's step method is called as a result.
|
||||
"""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
@ -1581,19 +1729,11 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
# Create a Memory instance - not used directly in this test but needed for setup
|
||||
Memory(event_stream=event_stream, sid='test-session')
|
||||
|
||||
# Create a mock agent with necessary attributes
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.get_system_message = MagicMock(
|
||||
return_value=None,
|
||||
)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
|
||||
|
||||
# Create a controller with the mock agent
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test-session',
|
||||
)
|
||||
@ -1655,8 +1795,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
)
|
||||
|
||||
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero(
|
||||
mock_agent_with_stats,
|
||||
):
|
||||
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create a mock event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test-session',
|
||||
)
|
||||
@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
)
|
||||
|
||||
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
_ = AgentController(
|
||||
agent=mock_agent, event_stream=test_event_stream, iteration_delta=10
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
)
|
||||
|
||||
# Get events from the event stream
|
||||
|
||||
@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import (
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_registry():
|
||||
config = OpenHandsConfig()
|
||||
return LLMRegistry(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_stats():
|
||||
import uuid
|
||||
|
||||
file_store = InMemoryFileStore({})
|
||||
# Use a unique conversation ID for each test to avoid conflicts
|
||||
conversation_id = f'test-conversation-{uuid.uuid4()}'
|
||||
return ConversationStats(
|
||||
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connected_registry_and_stats(llm_registry, conversation_stats):
|
||||
"""Connect the LLMRegistry and ConversationStats properly"""
|
||||
# Subscribe to LLM registry events to track metrics
|
||||
llm_registry.subscribe(conversation_stats.register_llm)
|
||||
return llm_registry, conversation_stats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
"""Creates an event stream in memory."""
|
||||
@ -42,15 +71,17 @@ def mock_event_stream():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parent_agent():
|
||||
def mock_parent_agent(llm_registry):
|
||||
"""Creates a mock parent agent for testing delegation."""
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.name = 'ParentAgent'
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.service_id = 'main_agent'
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
|
||||
|
||||
# Add a proper system message mock
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
@ -61,15 +92,17 @@ def mock_parent_agent():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_child_agent():
|
||||
def mock_child_agent(llm_registry):
|
||||
"""Creates a mock child agent for testing delegation."""
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.name = 'ChildAgent'
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.service_id = 'main_agent'
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
@ -78,15 +111,37 @@ def mock_child_agent():
|
||||
return agent
|
||||
|
||||
|
||||
def create_mock_agent_factory(mock_child_agent, llm_registry):
|
||||
"""Helper function to create a mock agent factory with proper LLM registration."""
|
||||
|
||||
def create_mock_agent(config, llm_registry=None):
|
||||
# Register the mock agent's LLM in the registry so get_combined_metrics() can find it
|
||||
if llm_registry:
|
||||
mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig())
|
||||
mock_child_agent.llm_registry = (
|
||||
llm_registry # Set the llm_registry attribute
|
||||
)
|
||||
return mock_child_agent
|
||||
|
||||
return create_mock_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
|
||||
"""Test that when the parent agent delegates to a child
|
||||
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
|
||||
2. metrics are accumulated globally (delegate is adding to the parents metrics)
|
||||
3. local metrics for the delegate are still accessible
|
||||
async def test_delegation_flow(
|
||||
mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats
|
||||
):
|
||||
"""
|
||||
Test that when the parent agent delegates to a child
|
||||
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
|
||||
2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics)
|
||||
3. global metrics tracking works correctly through the LLM registry
|
||||
"""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
Agent.get_cls = Mock(
|
||||
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
|
||||
)
|
||||
|
||||
step_count = 0
|
||||
|
||||
@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
|
||||
mock_child_agent.step = agent_step_fn
|
||||
|
||||
# Set up the parent agent's LLM with initial cost and register it in the registry
|
||||
# The parent agent's LLM should use the existing registered LLM to ensure proper tracking
|
||||
parent_llm = llm_registry.service_to_llm['agent']
|
||||
parent_llm.metrics.accumulated_cost = 2
|
||||
mock_parent_agent.llm = parent_llm
|
||||
|
||||
parent_metrics = Metrics()
|
||||
parent_metrics.accumulated_cost = 2
|
||||
# Create parent controller
|
||||
@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
for i in range(4):
|
||||
delegate_controller.state.iteration_flag.step()
|
||||
delegate_controller.agent.step(delegate_controller.state)
|
||||
# Update the agent's LLM metrics (not the deprecated state metrics)
|
||||
delegate_controller.agent.llm.metrics.add_cost(1.0)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_step() == 4
|
||||
) # verify local metrics are accessible via snapshot
|
||||
|
||||
# Check that the conversation stats has the combined metrics (parent + delegate)
|
||||
combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics()
|
||||
assert (
|
||||
delegate_controller.state.metrics.accumulated_cost
|
||||
== 6 # Make sure delegate tracks global cost
|
||||
combined_metrics.accumulated_cost
|
||||
== 6 # Make sure delegate tracks global cost (2 from parent + 4 from delegate)
|
||||
)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_metrics().accumulated_cost
|
||||
== 4 # Delegate spent one dollar per step
|
||||
)
|
||||
# Since metrics are now global via LLM registry, local metrics tracking
|
||||
# is handled differently. The delegate's LLM shares the same metrics object
|
||||
# as the parent for global tracking, so we verify the global total is correct.
|
||||
|
||||
delegate_controller.state.outputs = {'delegate_result': 'done'}
|
||||
|
||||
@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
],
|
||||
)
|
||||
async def test_delegate_step_different_states(
|
||||
mock_parent_agent, mock_event_stream, delegate_state
|
||||
mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats
|
||||
):
|
||||
"""Ensure that delegate is closed or remains open based on the delegate's state."""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Create a state with iteration_flag.max_value set to 10
|
||||
state = State(inputs={})
|
||||
state.iteration_flag.max_value = 10
|
||||
controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -292,11 +359,23 @@ async def test_delegate_step_different_states(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegate_hits_global_limits(
|
||||
mock_child_agent, mock_event_stream, mock_parent_agent
|
||||
mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Global limits from control flags should apply to delegates"""
|
||||
"""
|
||||
Global limits from control flags should apply to delegates
|
||||
"""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
Agent.get_cls = Mock(
|
||||
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
|
||||
)
|
||||
|
||||
# Set up the parent agent's LLM with initial cost and register it in the registry
|
||||
mock_parent_agent.llm.metrics.accumulated_cost = 2
|
||||
mock_parent_agent.llm.service_id = 'main_agent'
|
||||
# Register the parent agent's LLM in the registry
|
||||
llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm
|
||||
|
||||
parent_metrics = Metrics()
|
||||
parent_metrics.accumulated_cost = 2
|
||||
@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits(
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
|
||||
@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a properly configured mock agent with all required nested attributes"""
|
||||
# Create the base mocks
|
||||
agent = MagicMock(spec=Agent)
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
def mock_llm_registry():
|
||||
"""Create a mock LLM registry that properly simulates LLM registration"""
|
||||
config = OpenHandsConfig()
|
||||
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
|
||||
return registry
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Configure the agent config
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
@pytest.fixture
|
||||
def mock_conversation_stats():
|
||||
"""Create a mock ConversationStats that properly simulates metrics tracking"""
|
||||
file_store = InMemoryFileStore({})
|
||||
stats = ConversationStats(
|
||||
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
|
||||
)
|
||||
return stats
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
return agent
|
||||
@pytest.fixture
|
||||
def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
|
||||
"""Connect the LLMRegistry and ConversationStats properly"""
|
||||
# Subscribe to LLM registry events to track metrics
|
||||
mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
|
||||
return mock_llm_registry, mock_conversation_stats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_mock_agent():
|
||||
def _make_mock_agent(llm_registry):
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
llm_config = LLMConfig(
|
||||
model='gpt-4o',
|
||||
api_key='test_key',
|
||||
num_retries=2,
|
||||
retry_min_wait=1,
|
||||
retry_max_wait=2,
|
||||
)
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
llm_registry.service_to_llm.clear()
|
||||
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
|
||||
agent.llm = mock_llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
return agent
|
||||
|
||||
return _make_mock_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_no_state(mock_agent):
|
||||
async def test_agent_session_start_with_no_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's no state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
async def test_agent_session_start_with_restored_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's a state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
"""Test that metrics are centralized and shared between controller and agent."""
|
||||
async def test_metrics_centralization_via_conversation_stats(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that metrics are centralized through the ConversationStats service."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify that the agent's LLM metrics and controller's state metrics are the same object
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Verify that the ConversationStats is properly set up
|
||||
assert session.controller.state.convo_stats is mock_conversation_stats
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
# Verify that the cost is reflected in the controller's state metrics
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost
|
||||
# Verify that the cost is reflected in the combined metrics from the conversation stats
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Get the current accumulated cost before merging
|
||||
current_cost = session.controller.state.metrics.accumulated_cost
|
||||
# Verify the combined metrics reflect the total cost
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost + additional_cost
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
|
||||
# Verify that the merged metrics are reflected in both agent and controller
|
||||
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
|
||||
assert (
|
||||
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
|
||||
)
|
||||
|
||||
# Reset the agent and verify that metrics are not reset
|
||||
# Reset the agent and verify that combined metrics are preserved
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Metrics should still be the same after reset
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Combined metrics should still be preserved after agent reset
|
||||
assert (
|
||||
session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
async def test_budget_control_flag_syncs_with_metrics(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that BudgetControlFlag's current value matches the accumulated costs."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
assert session.controller.state.budget_flag.max_value == 1.0
|
||||
assert session.controller.state.budget_flag.current_value == 0.0
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert session.controller.state.budget_flag.current_value == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
# Sync again and verify the budget flag is updated
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
# Verify that the budget control flag's current value is updated to match the new accumulated cost
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
|
||||
# Reset the agent and verify that metrics and budget flag are not reset
|
||||
# Reset the agent and verify that budget flag still reflects the accumulated cost
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Budget control flag should still reflect the accumulated cost after reset
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
def test_override_provider_tokens_with_custom_secret():
|
||||
def test_override_provider_tokens_with_custom_secret(
|
||||
mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that override_provider_tokens_with_custom_secret works correctly.
|
||||
|
||||
This test verifies that the method properly removes provider tokens when
|
||||
@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create test data
|
||||
|
||||
@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import (
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.exceptions import FunctionCallNotExistsError
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
@ -42,10 +43,20 @@ from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser import View
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_llm_registry():
|
||||
def _get_registry(llm_config):
|
||||
config = OpenHandsConfig()
|
||||
config.set_llm_config(llm_config)
|
||||
return LLMRegistry(config=config)
|
||||
|
||||
return _get_registry
|
||||
|
||||
|
||||
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
|
||||
def agent_class(request):
|
||||
if request.param == 'CodeActAgent':
|
||||
@ -57,18 +68,22 @@ def agent_class(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]:
|
||||
def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]:
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
agent = agent_class(llm=LLM(LLMConfig()), config=config)
|
||||
agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
agent.llm = Mock()
|
||||
agent.llm.config = Mock()
|
||||
agent.llm.config.max_message_chars = 1000
|
||||
return agent
|
||||
|
||||
|
||||
def test_agent_with_default_config_has_default_tools():
|
||||
def test_agent_with_default_config_has_default_tools(create_llm_registry):
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config)
|
||||
codeact_agent = CodeActAgent(
|
||||
config=config, llm_registry=create_llm_registry(llm_config)
|
||||
)
|
||||
assert len(codeact_agent.tools) > 0
|
||||
default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools]
|
||||
assert {
|
||||
@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool():
|
||||
readonly_response_to_actions(mock_response)
|
||||
|
||||
|
||||
def test_step_with_no_pending_actions(mock_state: State):
|
||||
def test_step_with_no_pending_actions(mock_state: State, create_llm_registry):
|
||||
# Mock the LLM response
|
||||
mock_response = Mock()
|
||||
mock_response.id = 'mock_id'
|
||||
@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting
|
||||
|
||||
# Create agent with mocked LLM
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
config.enable_prompt_extensions = False
|
||||
agent = CodeActAgent(llm=llm, config=config)
|
||||
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
# Replace the LLM with our mock after creation
|
||||
agent.llm = llm
|
||||
|
||||
# Test step with no pending actions
|
||||
mock_state.latest_user_message = None
|
||||
@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
|
||||
@pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent'])
|
||||
def test_correct_tool_description_loaded_based_on_model_name(
|
||||
agent_type, mock_state: State
|
||||
agent_type, create_llm_registry
|
||||
):
|
||||
"""Tests that the simplified tool descriptions are loaded for specific models."""
|
||||
o3_mock_config = Mock()
|
||||
o3_mock_config.model = 'mock_o3_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = o3_mock_config
|
||||
|
||||
o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key')
|
||||
if agent_type == 'CodeActAgent':
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
|
||||
@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name(
|
||||
|
||||
agent_class = ReadOnlyAgent
|
||||
|
||||
agent = agent_class(llm=llm, config=AgentConfig())
|
||||
agent = agent_class(
|
||||
config=AgentConfig(),
|
||||
llm_registry=create_llm_registry(o3_mock_config),
|
||||
)
|
||||
for tool in agent.tools:
|
||||
# Assert all descriptions have less than 1024 characters
|
||||
assert len(tool['function']['description']) < 1024
|
||||
|
||||
sonnet_mock_config = Mock()
|
||||
sonnet_mock_config.model = 'mock_sonnet_model'
|
||||
|
||||
llm.config = sonnet_mock_config
|
||||
agent = agent_class(llm=llm, config=AgentConfig())
|
||||
sonnet_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
|
||||
agent = agent_class(
|
||||
config=AgentConfig(),
|
||||
llm_registry=create_llm_registry(sonnet_mock_config),
|
||||
)
|
||||
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
|
||||
if agent_type == 'CodeActAgent':
|
||||
# This only holds for CodeActAgent
|
||||
@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
assert isinstance(enhanced_messages[5].content[0], ImageContent)
|
||||
|
||||
|
||||
def test_get_system_message():
|
||||
def test_get_system_message(create_llm_registry):
|
||||
"""Test that the Agent.get_system_message method returns a SystemMessageAction."""
|
||||
# Create a mock agent
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
config = AgentConfig()
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
|
||||
result = agent.get_system_message()
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error(
|
||||
]
|
||||
|
||||
# Create an LLM instance and call completion
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error(
|
||||
]
|
||||
|
||||
# Create an LLM instance and call completion
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
|
||||
# The completion should raise an APIConnectionError after exhausting all retries
|
||||
with pytest.raises(APIConnectionError) as excinfo:
|
||||
|
||||
@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.server.conversation_manager.standalone_conversation_manager import (
|
||||
StandaloneConversationManager,
|
||||
)
|
||||
@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm():
|
||||
"""Test auto-generating a title using LLM."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
@ -46,43 +47,33 @@ async def test_auto_generate_title_with_llm():
|
||||
mock_event_store.search_events.return_value = [user_message]
|
||||
mock_event_store_cls.return_value = mock_event_store
|
||||
|
||||
# Mock the LLM response
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Python Data Analysis Script'
|
||||
mock_llm.completion.return_value = mock_response
|
||||
# Mock the LLM registry response
|
||||
llm_registry.request_extraneous_completion.return_value = (
|
||||
'Python Data Analysis Script'
|
||||
)
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert title == 'Python Data Analysis Script'
|
||||
# Verify the result
|
||||
assert title == 'Python Data Analysis Script'
|
||||
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
# Verify LLM was called with appropriate parameters
|
||||
mock_llm_cls.assert_called_once_with(
|
||||
LLMConfig(
|
||||
model='test-model',
|
||||
api_key='test-key',
|
||||
base_url='test-url',
|
||||
)
|
||||
)
|
||||
mock_llm.completion.assert_called_once()
|
||||
# Verify LLM registry was called with appropriate parameters
|
||||
llm_registry.request_extraneous_completion.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback():
|
||||
"""Test auto-generating a title with fallback to truncation when LLM fails."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
@ -111,31 +103,29 @@ async def test_auto_generate_title_fallback():
|
||||
mock_event_store.search_events.return_value = [user_message]
|
||||
mock_event_store_cls.return_value = mock_event_store
|
||||
|
||||
# Mock the LLM to raise an exception
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_llm.completion.side_effect = Exception('Test error')
|
||||
# Mock the LLM registry to raise an exception
|
||||
llm_registry.request_extraneous_completion.side_effect = Exception('Test error')
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result is a truncated version of the message
|
||||
assert title == 'This is a very long message th...'
|
||||
assert len(title) <= 35
|
||||
# Verify the result is a truncated version of the message
|
||||
assert title == 'This is a very long message th...'
|
||||
assert len(title) <= 35
|
||||
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages():
|
||||
"""Test auto-generating a title when there are no user messages."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with no messages
|
||||
conversation_id = 'test-conversation'
|
||||
@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages():
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result is empty
|
||||
@ -186,6 +177,7 @@ async def test_update_conversation_with_title():
|
||||
sio.emit = AsyncMock()
|
||||
file_store = InMemoryFileStore()
|
||||
server_config = MagicMock()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation
|
||||
conversation_id = 'test-conversation'
|
||||
@ -222,7 +214,9 @@ async def test_update_conversation_with_title():
|
||||
AsyncMock(return_value='Generated Title'),
|
||||
):
|
||||
# Call the method
|
||||
await manager._update_conversation_for_event(user_id, conversation_id, settings)
|
||||
await manager._update_conversation_for_event(
|
||||
user_id, conversation_id, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the title was updated
|
||||
assert mock_metadata.title == 'Generated Title'
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest_asyncio
|
||||
|
||||
from openhands.cli import main as cli
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
@ -124,12 +125,14 @@ def mock_config():
|
||||
'' # Empty string, not starting with 'tvly-'
|
||||
)
|
||||
config.search_api_key = search_api_key_mock
|
||||
config.get_llm_config_from_agent.return_value = LLMConfig(model='model')
|
||||
|
||||
# Mock sandbox with volumes attribute to prevent finalize_config issues
|
||||
config.sandbox = MagicMock()
|
||||
config.sandbox.volumes = (
|
||||
None # This prevents finalize_config from overriding workspace_base
|
||||
)
|
||||
config.model_name = 'model'
|
||||
|
||||
return config
|
||||
|
||||
@ -213,7 +216,11 @@ async def test_run_session_without_initial_action(
|
||||
# Assertions for initialization flow
|
||||
mock_display_runtime_init.assert_called_once_with('local')
|
||||
mock_display_animation.assert_called_once()
|
||||
mock_create_agent.assert_called_once_with(mock_config)
|
||||
# Check that mock_config is the first parameter to create_agent
|
||||
mock_create_agent.assert_called_once()
|
||||
assert mock_create_agent.call_args[0][0] == mock_config, (
|
||||
'First parameter to create_agent should be mock_config'
|
||||
)
|
||||
mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_create_controller.assert_called_once()
|
||||
|
||||
@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from litellm.exceptions import AuthenticationError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.cli import main as cli
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
@ -45,11 +47,10 @@ def mock_config():
|
||||
config.workspace_base = '/test/dir'
|
||||
|
||||
# Set up LLM config to use OpenHands provider
|
||||
llm_config = MagicMock()
|
||||
llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key'))
|
||||
llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model
|
||||
llm_config.api_key = MagicMock()
|
||||
llm_config.api_key.get_secret_value.return_value = 'invalid-api-key'
|
||||
config.llm = llm_config
|
||||
config.get_llm_config.return_value = llm_config
|
||||
config.get_llm_config_from_agent.return_value = llm_config
|
||||
|
||||
# Mock search_api_key with get_secret_value method
|
||||
search_api_key_mock = MagicMock()
|
||||
|
||||
@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import (
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
|
||||
|
||||
@ -23,8 +24,12 @@ class TestCLIRuntimeMCP:
|
||||
"""Set up test fixtures."""
|
||||
self.config = OpenHandsConfig()
|
||||
self.event_stream = MagicMock()
|
||||
llm_registry = LLMRegistry(config=OpenHandsConfig())
|
||||
self.runtime = CLIRuntime(
|
||||
config=self.config, event_stream=self.event_stream, sid='test-session'
|
||||
config=self.config,
|
||||
event_stream=self.event_stream,
|
||||
sid='test-session',
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -7,10 +7,18 @@ import pytest
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.events import EventStream
|
||||
|
||||
# Mock LLMRegistry
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
# Create a mock LLMRegistry class
|
||||
class MockLLMRegistry:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create a temporary directory for testing."""
|
||||
@ -25,7 +33,8 @@ def cli_runtime(temp_dir):
|
||||
event_stream = EventStream('test', file_store)
|
||||
config = OpenHandsConfig()
|
||||
config.workspace_base = temp_dir
|
||||
runtime = CLIRuntime(config, event_stream)
|
||||
llm_registry = MockLLMRegistry(config)
|
||||
runtime = CLIRuntime(config, event_stream, llm_registry)
|
||||
runtime._runtime_initialized = True # Skip initialization
|
||||
return runtime
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import (
|
||||
StructuredSummaryCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.core.schema.action import ActionType
|
||||
from openhands.events.event import Event, EventSource
|
||||
@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser import Condenser
|
||||
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
|
||||
from openhands.memory.condenser.impl import (
|
||||
@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import (
|
||||
StructuredSummaryCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
|
||||
|
||||
def create_test_event(
|
||||
@ -56,12 +59,15 @@ def create_test_event(
|
||||
@pytest.fixture
|
||||
def mock_llm() -> LLM:
|
||||
"""Mocks an LLM object with a utility function for setting and resetting response contents in unit tests."""
|
||||
# Create a real LLMConfig instead of a mock to properly handle SecretStr api_key
|
||||
real_config = LLMConfig(
|
||||
model='gpt-4o', api_key='test_key', custom_llm_provider=None
|
||||
)
|
||||
|
||||
# Create a MagicMock for the LLM object
|
||||
mock_llm = MagicMock(
|
||||
spec=LLM,
|
||||
config=MagicMock(
|
||||
spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None
|
||||
),
|
||||
config=real_config,
|
||||
metrics=MagicMock(),
|
||||
)
|
||||
_mock_content = None
|
||||
@ -95,6 +101,23 @@ def mock_llm() -> LLM:
|
||||
return mock_llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_stats() -> ConversationStats:
|
||||
"""Creates a mock ConversationStats service."""
|
||||
mock_stats = MagicMock(spec=ConversationStats)
|
||||
return mock_stats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry:
|
||||
"""Creates an actual LLMRegistry that returns real LLMs."""
|
||||
# Create an actual LLMRegistry with a basic OpenHandsConfig
|
||||
config = OpenHandsConfig()
|
||||
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
class RollingCondenserTestHarness:
|
||||
"""Test harness for rolling condensers.
|
||||
|
||||
@ -165,10 +188,10 @@ class RollingCondenserTestHarness:
|
||||
return ((index - max_size) // target_size) + 1
|
||||
|
||||
|
||||
def test_noop_condenser_from_config():
|
||||
def test_noop_condenser_from_config(mock_llm_registry):
|
||||
"""Test that the NoOpCondenser objects can be made from config."""
|
||||
config = NoOpCondenserConfig()
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, NoOpCondenser)
|
||||
|
||||
@ -189,11 +212,11 @@ def test_noop_condenser():
|
||||
assert result == View(events=events)
|
||||
|
||||
|
||||
def test_observation_masking_condenser_from_config():
|
||||
def test_observation_masking_condenser_from_config(mock_llm_registry):
|
||||
"""Test that ObservationMaskingCondenser objects can be made from config."""
|
||||
attention_window = 5
|
||||
config = ObservationMaskingCondenserConfig(attention_window=attention_window)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, ObservationMaskingCondenser)
|
||||
assert condenser.attention_window == attention_window
|
||||
@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window():
|
||||
assert event == condensed_event
|
||||
|
||||
|
||||
def test_browser_output_condenser_from_config():
|
||||
def test_browser_output_condenser_from_config(mock_llm_registry):
|
||||
"""Test that BrowserOutputCondenser objects can be made from config."""
|
||||
attention_window = 5
|
||||
config = BrowserOutputCondenserConfig(attention_window=attention_window)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, BrowserOutputCondenser)
|
||||
assert condenser.attention_window == attention_window
|
||||
@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window():
|
||||
assert event == condensed_event
|
||||
|
||||
|
||||
def test_recent_events_condenser_from_config():
|
||||
def test_recent_events_condenser_from_config(mock_llm_registry):
|
||||
"""Test that RecentEventsCondenser objects can be made from config."""
|
||||
max_events = 5
|
||||
keep_first = True
|
||||
config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, RecentEventsCondenser)
|
||||
assert condenser.max_events == max_events
|
||||
@ -334,14 +357,14 @@ def test_recent_events_condenser():
|
||||
assert result[2]._message == 'Event 5' # kept from max_events
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_from_config():
|
||||
def test_llm_summarizing_condenser_from_config(mock_llm_registry):
|
||||
"""Test that LLMSummarizingCondenser objects can be made from config."""
|
||||
config = LLMSummarizingCondenserConfig(
|
||||
max_size=50,
|
||||
keep_first=10,
|
||||
llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True),
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, LLMSummarizingCondenser)
|
||||
assert condenser.llm.config.model == 'gpt-4o'
|
||||
@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config():
|
||||
assert condenser.max_size == 50
|
||||
assert condenser.keep_first == 10
|
||||
|
||||
# Since this condenser can't take advantage of caching, we intercept the
|
||||
# passed config and manually flip the caching prompt to False.
|
||||
assert not condenser.llm.config.caching_prompt
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_invalid_config():
|
||||
def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry):
|
||||
"""Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
LLMSummarizingCondenser,
|
||||
llm=MagicMock(),
|
||||
llm=mock_llm,
|
||||
max_size=4,
|
||||
keep_first=2,
|
||||
)
|
||||
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0)
|
||||
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1)
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
LLMSummarizingCondenser,
|
||||
llm=mock_llm,
|
||||
max_size=0,
|
||||
)
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
LLMSummarizingCondenser,
|
||||
llm=mock_llm,
|
||||
keep_first=-1,
|
||||
)
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
|
||||
def test_llm_summarizing_condenser_gives_expected_view_size(
|
||||
mock_llm, mock_llm_registry
|
||||
):
|
||||
"""Test that LLMSummarizingCondenser maintains the correct view size."""
|
||||
max_size = 10
|
||||
condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
|
||||
@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
|
||||
def test_llm_summarizing_condenser_keeps_first_and_summary_events(
|
||||
mock_llm, mock_llm_registry
|
||||
):
|
||||
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
|
||||
max_size = 10
|
||||
keep_first = 3
|
||||
condenser = LLMSummarizingCondenser(
|
||||
max_size=max_size, keep_first=keep_first, llm=mock_llm
|
||||
max_size=max_size,
|
||||
keep_first=keep_first,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
mock_llm.set_mock_response_content('Summary of forgotten events')
|
||||
@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
|
||||
assert isinstance(view[keep_first], AgentCondensationObservation)
|
||||
|
||||
|
||||
def test_amortized_forgetting_condenser_from_config():
|
||||
def test_amortized_forgetting_condenser_from_config(mock_llm_registry):
|
||||
"""Test that AmortizedForgettingCondenser objects can be made from config."""
|
||||
max_size = 50
|
||||
keep_first = 10
|
||||
config = AmortizedForgettingCondenserConfig(
|
||||
max_size=max_size, keep_first=keep_first
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, AmortizedForgettingCondenser)
|
||||
assert condenser.max_size == max_size
|
||||
@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events():
|
||||
assert view[:keep_first] == events[: min(keep_first, i + 1)]
|
||||
|
||||
|
||||
def test_llm_attention_condenser_from_config():
|
||||
def test_llm_attention_condenser_from_config(mock_llm_registry):
|
||||
"""Test that LLMAttentionCondenser objects can be made from config."""
|
||||
config = LLMAttentionCondenserConfig(
|
||||
max_size=50,
|
||||
@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config():
|
||||
caching_prompt=True,
|
||||
),
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, LLMAttentionCondenser)
|
||||
assert condenser.llm.config.model == 'gpt-4o'
|
||||
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
|
||||
assert condenser.max_size == 50
|
||||
assert condenser.keep_first == 10
|
||||
|
||||
# Since this condenser can't take advantage of caching, we intercept the
|
||||
# passed config and manually flip the caching prompt to False.
|
||||
assert not condenser.llm.config.caching_prompt
|
||||
# Create a mock LLM that doesn't support function calling
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.is_function_calling_active.return_value = False
|
||||
|
||||
# Create a new registry that returns our mock LLM that doesn't support function calling
|
||||
mock_registry = MagicMock(spec=LLMRegistry)
|
||||
mock_registry.get_llm.return_value = mock_llm
|
||||
|
||||
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_invalid_config():
|
||||
"""Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema."""
|
||||
config = LLMAttentionCondenserConfig(
|
||||
max_size=50,
|
||||
keep_first=10,
|
||||
llm_config=LLMConfig(
|
||||
model='claude-2', # Older model that doesn't support response schema
|
||||
api_key='test_key',
|
||||
),
|
||||
)
|
||||
|
||||
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
|
||||
def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry):
|
||||
"""Test that the LLMAttentionCondenser gives views of the expected size."""
|
||||
max_size = 10
|
||||
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
|
||||
condenser = LLMAttentionCondenser(
|
||||
max_size=max_size,
|
||||
keep_first=0,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
|
||||
@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
|
||||
def test_llm_attention_condenser_handles_events_outside_history(
|
||||
mock_llm, mock_llm_registry
|
||||
):
|
||||
"""Test that the LLMAttentionCondenser handles event IDs that aren't from the event history."""
|
||||
max_size = 2
|
||||
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
|
||||
condenser = LLMAttentionCondenser(
|
||||
max_size=max_size,
|
||||
keep_first=0,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
|
||||
@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_handles_too_many_events(mock_llm):
|
||||
def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry):
|
||||
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
|
||||
max_size = 2
|
||||
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
|
||||
condenser = LLMAttentionCondenser(
|
||||
max_size=max_size,
|
||||
keep_first=0,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
|
||||
@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_handles_too_few_events(mock_llm):
|
||||
def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry):
|
||||
"""Test that the LLMAttentionCondenser handles when the response contains too few event IDs."""
|
||||
max_size = 2
|
||||
# Developer note: We must specify keep_first=0 because
|
||||
# keep_first (1) >= max_size//2 (1) is invalid.
|
||||
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
|
||||
condenser = LLMAttentionCondenser(
|
||||
max_size=max_size,
|
||||
keep_first=0,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
|
||||
@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
|
||||
def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry):
|
||||
"""Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size)."""
|
||||
max_size = 12
|
||||
keep_first = 4
|
||||
condenser = LLMAttentionCondenser(
|
||||
max_size=max_size, keep_first=keep_first, llm=mock_llm
|
||||
max_size=max_size,
|
||||
keep_first=keep_first,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
|
||||
assert view[:keep_first] == events[: min(keep_first, i + 1)]
|
||||
|
||||
|
||||
def test_structured_summary_condenser_from_config():
|
||||
def test_structured_summary_condenser_from_config(mock_llm_registry):
|
||||
"""Test that StructuredSummaryCondenser objects can be made from config."""
|
||||
config = StructuredSummaryCondenserConfig(
|
||||
max_size=50,
|
||||
@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config():
|
||||
caching_prompt=True,
|
||||
),
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, StructuredSummaryCondenser)
|
||||
assert condenser.llm.config.model == 'gpt-4o'
|
||||
@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config():
|
||||
assert condenser.max_size == 50
|
||||
assert condenser.keep_first == 10
|
||||
|
||||
# Since this condenser can't take advantage of caching, we intercept the
|
||||
# passed config and manually flip the caching prompt to False.
|
||||
assert not condenser.llm.config.caching_prompt
|
||||
|
||||
|
||||
def test_structured_summary_condenser_invalid_config():
|
||||
def test_structured_summary_condenser_invalid_config(mock_llm):
|
||||
"""Test that StructuredSummaryCondenser raises error when keep_first > max_size."""
|
||||
# Since the condenser only works when function calling is on, we need to
|
||||
# mock up the check for that.
|
||||
llm = MagicMock()
|
||||
llm.is_function_calling_active.return_value = True
|
||||
mock_llm.is_function_calling_active.return_value = True
|
||||
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
StructuredSummaryCondenser,
|
||||
llm=llm,
|
||||
llm=mock_llm,
|
||||
max_size=4,
|
||||
keep_first=2,
|
||||
)
|
||||
|
||||
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0)
|
||||
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1)
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
StructuredSummaryCondenser,
|
||||
llm=mock_llm,
|
||||
max_size=0,
|
||||
)
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
StructuredSummaryCondenser,
|
||||
llm=mock_llm,
|
||||
keep_first=-1,
|
||||
)
|
||||
|
||||
# If all other parameters are good but there's no function calling the
|
||||
# condenser still counts as improperly configured.
|
||||
llm.is_function_calling_active.return_value = False
|
||||
# Create a mock LLM that doesn't support function calling
|
||||
mock_llm_no_func = MagicMock()
|
||||
mock_llm_no_func.is_function_calling_active.return_value = False
|
||||
|
||||
pytest.raises(
|
||||
ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2
|
||||
ValueError,
|
||||
StructuredSummaryCondenser,
|
||||
llm=mock_llm_no_func,
|
||||
max_size=40,
|
||||
keep_first=2,
|
||||
)
|
||||
|
||||
|
||||
def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
|
||||
def test_structured_summary_condenser_gives_expected_view_size(
|
||||
mock_llm, mock_llm_registry
|
||||
):
|
||||
"""Test that StructuredSummaryCondenser maintains the correct view size."""
|
||||
max_size = 10
|
||||
mock_llm.is_function_calling_active.return_value = True
|
||||
condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)
|
||||
|
||||
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
|
||||
@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
|
||||
def test_structured_summary_condenser_keeps_first_and_summary_events(
|
||||
mock_llm, mock_llm_registry
|
||||
):
|
||||
"""Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
|
||||
max_size = 10
|
||||
keep_first = 3
|
||||
mock_llm.is_function_calling_active.return_value = True
|
||||
condenser = StructuredSummaryCondenser(
|
||||
max_size=max_size, keep_first=keep_first, llm=mock_llm
|
||||
max_size=max_size,
|
||||
keep_first=keep_first,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
mock_llm.set_mock_response_content('Summary of forgotten events')
|
||||
@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
|
||||
assert isinstance(view[keep_first], AgentCondensationObservation)
|
||||
|
||||
|
||||
def test_condenser_pipeline_from_config():
|
||||
def test_condenser_pipeline_from_config(mock_llm_registry):
|
||||
"""Test that CondenserPipeline condensers can be created from configuration objects."""
|
||||
config = CondenserPipelineConfig(
|
||||
condensers=[
|
||||
@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config():
|
||||
),
|
||||
]
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, CondenserPipeline)
|
||||
assert len(condenser.condensers) == 3
|
||||
|
||||
490
tests/unit/test_conversation_stats.py
Normal file
@ -0,0 +1,490 @@
|
||||
import base64
|
||||
import pickle
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store():
|
||||
"""Create a mock file store for testing."""
|
||||
return InMemoryFileStore({})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_stats(mock_file_store):
|
||||
"""Create a ConversationStats instance for testing."""
|
||||
return ConversationStats(
|
||||
file_store=mock_file_store,
|
||||
conversation_id='test-conversation-id',
|
||||
user_id='test-user-id',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_registry():
|
||||
"""Create a mock LLM registry that properly simulates LLM registration."""
|
||||
config = OpenHandsConfig()
|
||||
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connected_registry_and_stats(mock_llm_registry, conversation_stats):
|
||||
"""Connect the LLMRegistry and ConversationStats properly."""
|
||||
# Subscribe to LLM registry events to track metrics
|
||||
mock_llm_registry.subscribe(conversation_stats.register_llm)
|
||||
return mock_llm_registry, conversation_stats
|
||||
|
||||
|
||||
def test_conversation_stats_initialization(conversation_stats):
|
||||
"""Test that ConversationStats initializes correctly."""
|
||||
assert conversation_stats.conversation_id == 'test-conversation-id'
|
||||
assert conversation_stats.user_id == 'test-user-id'
|
||||
assert conversation_stats.service_to_metrics == {}
|
||||
assert isinstance(conversation_stats.restored_metrics, dict)
|
||||
|
||||
|
||||
def test_save_metrics(conversation_stats, mock_file_store):
|
||||
"""Test that metrics are saved correctly."""
|
||||
# Add a service with metrics
|
||||
service_id = 'test-service'
|
||||
metrics = Metrics(model_name='gpt-4')
|
||||
metrics.add_cost(0.05)
|
||||
conversation_stats.service_to_metrics[service_id] = metrics
|
||||
|
||||
# Save metrics
|
||||
conversation_stats.save_metrics()
|
||||
|
||||
# Verify that metrics were saved to the file store
|
||||
try:
|
||||
# Verify the saved content can be decoded and unpickled
|
||||
encoded = mock_file_store.read(conversation_stats.metrics_path)
|
||||
pickled = base64.b64decode(encoded)
|
||||
restored = pickle.loads(pickled)
|
||||
|
||||
assert service_id in restored
|
||||
assert restored[service_id].accumulated_cost == 0.05
|
||||
except FileNotFoundError:
|
||||
pytest.fail(f'File not found: {conversation_stats.metrics_path}')
|
||||
|
||||
|
||||
def test_maybe_restore_metrics(mock_file_store):
|
||||
"""Test that metrics are restored correctly."""
|
||||
# Create metrics to save
|
||||
service_id = 'test-service'
|
||||
metrics = Metrics(model_name='gpt-4')
|
||||
metrics.add_cost(0.1)
|
||||
service_to_metrics = {service_id: metrics}
|
||||
|
||||
# Serialize and save metrics
|
||||
pickled = pickle.dumps(service_to_metrics)
|
||||
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
|
||||
|
||||
# Create a new ConversationStats with pre-populated file store
|
||||
conversation_id = 'test-conversation-id'
|
||||
user_id = 'test-user-id'
|
||||
|
||||
# Get the correct path using the same function as ConversationStats
|
||||
from openhands.storage.locations import get_conversation_stats_filename
|
||||
|
||||
metrics_path = get_conversation_stats_filename(conversation_id, user_id)
|
||||
|
||||
# Write to the correct path
|
||||
mock_file_store.write(metrics_path, serialized_metrics)
|
||||
|
||||
# Create ConversationStats which should restore metrics
|
||||
stats = ConversationStats(
|
||||
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Verify metrics were restored
assert service_id in stats.restored_metrics
assert stats.restored_metrics[service_id].accumulated_cost == 0.1


def test_get_combined_metrics(conversation_stats):
"""Test that combined metrics are calculated correctly."""
# Add multiple services with metrics
service1 = 'service1'
metrics1 = Metrics(model_name='gpt-4')
metrics1.add_cost(0.05)
metrics1.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

service2 = 'service2'
metrics2 = Metrics(model_name='gpt-3.5')
metrics2.add_cost(0.02)
metrics2.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)

conversation_stats.service_to_metrics[service1] = metrics1
conversation_stats.service_to_metrics[service2] = metrics2

# Get combined metrics
combined = conversation_stats.get_combined_metrics()

# Verify combined metrics
assert combined.accumulated_cost == 0.07  # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300  # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150  # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
)  # max of 8000 and 4000


def test_get_metrics_for_service(conversation_stats):
"""Test that metrics for a specific service are retrieved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics

# Get metrics for the service
retrieved_metrics = conversation_stats.get_metrics_for_service(service_id)

# Verify metrics
assert retrieved_metrics.accumulated_cost == 0.05
assert retrieved_metrics is metrics  # Should be the same object

# Test getting metrics for non-existent service
# Use a specific exception message pattern instead of a blind Exception
with pytest.raises(Exception, match='LLM service does not exist'):
conversation_stats.get_metrics_for_service('non-existent-service')


def test_register_llm_with_new_service(conversation_stats):
"""Test registering a new LLM service."""
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id='new-service', config=llm_config)

# Create a registry event
service_id = 'new-service'
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM
conversation_stats.register_llm(event)

# Verify the service was registered
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics


def test_register_llm_with_restored_metrics(conversation_stats):
"""Test registering an LLM service with restored metrics."""
# Create restored metrics
service_id = 'restored-service'
restored_metrics = Metrics(model_name='gpt-4')
restored_metrics.add_cost(0.1)
conversation_stats.restored_metrics = {service_id: restored_metrics}

# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)

# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM
conversation_stats.register_llm(event)

# Verify the service was registered with restored metrics
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
assert llm.metrics.accumulated_cost == 0.1  # Restored cost

# Verify the specific service was removed from restored_metrics
assert service_id not in conversation_stats.restored_metrics
assert hasattr(
conversation_stats, 'restored_metrics'
)  # The dict should still exist


def test_llm_registry_notifications(connected_registry_and_stats):
"""Test that LLM registry notifications update conversation stats."""
mock_llm_registry, conversation_stats = connected_registry_and_stats

# Create a new LLM through the registry
service_id = 'test-service'
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Get LLM from registry (this should trigger the notification)
llm = mock_llm_registry.get_llm(service_id, llm_config)

# Verify the service was registered in conversation stats
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics

# Add some metrics to the LLM
llm.metrics.add_cost(0.05)
llm.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

# Verify the metrics are reflected in conversation stats
assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.prompt_tokens
== 100
)
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.completion_tokens
== 50
)

# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.05
assert combined.accumulated_token_usage.prompt_tokens == 100
assert combined.accumulated_token_usage.completion_tokens == 50


def test_multiple_llm_services(connected_registry_and_stats):
"""Test tracking metrics for multiple LLM services."""
mock_llm_registry, conversation_stats = connected_registry_and_stats

# Create multiple LLMs through the registry
service1 = 'service1'
service2 = 'service2'

llm_config1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

llm_config2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Get LLMs from registry (this should trigger notifications)
llm1 = mock_llm_registry.get_llm(service1, llm_config1)
llm2 = mock_llm_registry.get_llm(service2, llm_config2)

# Add different metrics to each LLM
llm1.metrics.add_cost(0.05)
llm1.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

llm2.metrics.add_cost(0.02)
llm2.metrics.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)

# Verify services were registered in conversation stats
assert service1 in conversation_stats.service_to_metrics
assert service2 in conversation_stats.service_to_metrics

# Verify individual metrics
assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05
assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02

# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.07  # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300  # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150  # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
)  # max of 8000 and 4000


def test_register_llm_with_multiple_restored_services_bug(conversation_stats):
"""Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service."""
# Create restored metrics for multiple services
service_id_1 = 'service-1'
service_id_2 = 'service-2'

restored_metrics_1 = Metrics(model_name='gpt-4')
restored_metrics_1.add_cost(0.1)

restored_metrics_2 = Metrics(model_name='gpt-3.5')
restored_metrics_2.add_cost(0.05)

# Set up restored metrics for both services
conversation_stats.restored_metrics = {
service_id_1: restored_metrics_1,
service_id_2: restored_metrics_2,
}

# Create LLM configs
llm_config_1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

llm_config_2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
# Register first LLM
llm_1 = LLM(service_id=service_id_1, config=llm_config_1)
event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1)
conversation_stats.register_llm(event_1)

# Verify first service was registered with restored metrics
assert service_id_1 in conversation_stats.service_to_metrics
assert llm_1.metrics.accumulated_cost == 0.1

# After registering first service, restored_metrics should still contain service_id_2
assert service_id_2 in conversation_stats.restored_metrics

# Register second LLM - this should also work with restored metrics
llm_2 = LLM(service_id=service_id_2, config=llm_config_2)
event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2)
conversation_stats.register_llm(event_2)

# Verify second service was registered with restored metrics
assert service_id_2 in conversation_stats.service_to_metrics
assert llm_2.metrics.accumulated_cost == 0.05

# After both services are registered, restored_metrics should be empty
assert len(conversation_stats.restored_metrics) == 0


def test_save_and_restore_workflow(mock_file_store):
"""Test the full workflow of saving and restoring metrics."""
# Create initial conversation stats
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'

stats1 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)

# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
stats1.service_to_metrics[service_id] = metrics

# Save metrics
stats1.save_metrics()

# Create a new conversation stats instance that should restore the metrics
stats2 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)

# Verify metrics were restored
assert service_id in stats2.restored_metrics
assert stats2.restored_metrics[service_id].accumulated_cost == 0.05
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100
)
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens
== 50
)

# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)

# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM to trigger restoration
stats2.register_llm(event)

# Verify metrics were applied to the LLM
assert llm.metrics.accumulated_cost == 0.05
assert llm.metrics.accumulated_token_usage.prompt_tokens == 100
assert llm.metrics.accumulated_token_usage.completion_tokens == 50
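The aggregation contract these stats tests pin down is simple: costs and token counts are summed across services, while the context window reports the maximum seen. A minimal, self-contained sketch of that behaviour follows; it uses plain dataclasses standing in for Metrics and ConversationStats, and the names other than get_combined_metrics-style aggregation are illustrative rather than the project's API.

from dataclasses import dataclass, field


@dataclass
class TokenUsage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    context_window: int = 0


@dataclass
class ServiceMetrics:
    accumulated_cost: float = 0.0
    usage: TokenUsage = field(default_factory=TokenUsage)


def combine(per_service: dict[str, ServiceMetrics]) -> ServiceMetrics:
    """Sum costs and token counts; keep the largest context window seen."""
    combined = ServiceMetrics()
    for m in per_service.values():
        combined.accumulated_cost += m.accumulated_cost
        combined.usage.prompt_tokens += m.usage.prompt_tokens
        combined.usage.completion_tokens += m.usage.completion_tokens
        combined.usage.context_window = max(
            combined.usage.context_window, m.usage.context_window
        )
    return combined


if __name__ == '__main__':
    stats = {
        'service1': ServiceMetrics(0.05, TokenUsage(100, 50, 8000)),
        'service2': ServiceMetrics(0.02, TokenUsage(200, 100, 4000)),
    }
    total = combine(stats)
    assert round(total.accumulated_cost, 2) == 0.07  # 0.05 + 0.02
    assert total.usage.context_window == 8000  # max of 8000 and 4000

This mirrors the expected values asserted in test_get_combined_metrics and test_multiple_llm_services above.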
@ -1,6 +1,6 @@
"""Tests for the conversation summary generator."""

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest

@ -11,55 +11,51 @@ from openhands.utils.conversation_summary import generate_conversation_title
@pytest.mark.asyncio
async def test_generate_conversation_title_empty_message():
"""Test that an empty message returns None."""
result = await generate_conversation_title('', MagicMock())
mock_llm_registry = MagicMock()
mock_llm_config = LLMConfig(model='test-model')

result = await generate_conversation_title('', mock_llm_config, mock_llm_registry)
assert result is None

result = await generate_conversation_title(' ', MagicMock())
result = await generate_conversation_title(
' ', mock_llm_config, mock_llm_registry
)
assert result is None


@pytest.mark.asyncio
async def test_generate_conversation_title_success():
"""Test successful title generation."""
# Create a proper mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Generated Title'
# Create a mock LLM registry that returns a title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title'

# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)

assert result == 'Generated Title'
# Verify the mock was called with the expected arguments
mock_llm.completion.assert_called_once()
mock_llm_registry.request_extraneous_completion.assert_called_once()


@pytest.mark.asyncio
async def test_generate_conversation_title_long_title():
"""Test that long titles are truncated."""
# Create a proper mock response with a long title
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[
0
].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length'
# Create a mock LLM registry that returns a long title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length'

# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30
)
result = await generate_conversation_title(
'Can you help me with Python?',
mock_llm_config,
mock_llm_registry,
max_length=30,
)

# Verify the title is truncated correctly
assert len(result) <= 30
@ -69,15 +65,17 @@ async def test_generate_conversation_title_long_title():
@pytest.mark.asyncio
async def test_generate_conversation_title_exception():
"""Test that exceptions are handled gracefully."""
# Create a mock LLM instance with a synchronous completion method that raises an exception
mock_llm = MagicMock()
mock_llm.completion = MagicMock(side_effect=Exception('Test error'))
# Create a mock LLM registry that raises an exception
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.side_effect = Exception(
'Test error'
)

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
mock_llm_config = LLMConfig(model='test-model')

result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)

# Verify that None is returned when an exception occurs
assert result is None

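Taken together, the summary tests above fix the title helper's contract: blank or whitespace-only input yields None, registry failures degrade to None instead of raising, the registry's completion text is used as the title, and anything beyond max_length is truncated. A rough sketch of that control flow, assuming only the request_extraneous_completion method the tests mock; the service id, prompt text, and truncation style are illustrative, not the project's implementation.

async def generate_conversation_title_sketch(
    message: str, llm_config, llm_registry, max_length: int = 50
) -> str | None:
    # Blank or whitespace-only input: nothing to summarize.
    if not message or not message.strip():
        return None
    try:
        title = llm_registry.request_extraneous_completion(
            service_id='conversation_title',  # illustrative service id
            llm_config=llm_config,
            messages=[
                {'role': 'user', 'content': f'Give a short title for: {message}'}
            ],
        )
    except Exception:
        # Any registry/LLM failure degrades to "no title" rather than raising.
        return None
    title = title.strip()
    if len(title) > max_length:
        title = title[: max_length - 3] + '...'
    return title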
@ -4,6 +4,7 @@ import pytest

from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime


@ -40,12 +41,17 @@ def event_stream():
return MagicMock(spec=EventStream)


@pytest.fixture
def llm_registry():
return MagicMock(spec=LLMRegistry)


@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_stopped_when_keep_runtime_alive_false(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value

# Act
@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false(

@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_not_stopped_when_keep_runtime_alive_true(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
config.sandbox.keep_runtime_alive = True
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value

# Act

178 tests/unit/test_llm_registry.py Normal file
@ -0,0 +1,178 @@
from __future__ import annotations

import unittest
from unittest.mock import MagicMock, patch

from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent


class TestLLMRegistry(unittest.TestCase):
def setUp(self):
"""Set up test environment before each test."""
# Create a basic LLM config for testing
self.llm_config = LLMConfig(model='test-model')

# Create a basic OpenHands config for testing
self.config = OpenHandsConfig(
llms={'llm': self.llm_config}, default_agent='CodeActAgent'
)

# Create a registry for testing
self.registry = LLMRegistry(config=self.config)

def test_get_llm_creates_new_llm(self):
"""Test that get_llm creates a new LLM when service doesn't exist."""
service_id = 'test-service'

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm

# Get LLM for the first time
llm = self.registry.get_llm(service_id, self.llm_config)

# Verify LLM was created and stored
self.assertEqual(llm, mock_llm)
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id
)

def test_get_llm_returns_existing_llm(self):
"""Test that get_llm returns existing LLM when service already exists."""
service_id = 'test-service'

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm

# Get LLM for the first time
llm1 = self.registry.get_llm(service_id, self.llm_config)

# Manually add to registry to simulate existing LLM
self.registry.service_to_llm[service_id] = mock_llm

# Get LLM for the second time - should return the same instance
llm2 = self.registry.get_llm(service_id, self.llm_config)

# Verify same LLM instance is returned
self.assertEqual(llm1, llm2)
self.assertEqual(llm1, mock_llm)

# Verify _create_new_llm was only called once
mock_create.assert_called_once()

def test_get_llm_with_different_config_raises_error(self):
"""Test that requesting same service ID with different config raises an error."""
service_id = 'test-service'
different_config = LLMConfig(model='different-model')

# Manually add an LLM to the registry to simulate existing service
mock_llm = MagicMock()
mock_llm.config = self.llm_config
self.registry.service_to_llm[service_id] = mock_llm

# Attempt to get LLM with different config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, different_config)

self.assertIn('Requesting same service ID', str(context.exception))
self.assertIn('with different config', str(context.exception))

def test_get_llm_without_config_raises_error(self):
"""Test that requesting new LLM without config raises an error."""
service_id = 'test-service'

# Attempt to get LLM without providing config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, None)

self.assertIn(
'Requesting new LLM without specifying LLM config', str(context.exception)
)

def test_request_extraneous_completion(self):
"""Test that requesting an extraneous completion creates a new LLM if needed."""
service_id = 'extraneous-service'
messages = [{'role': 'user', 'content': 'Hello, world!'}]

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = ' Hello from the LLM! '
mock_llm.completion.return_value = mock_response
mock_create.return_value = mock_llm

# Mock the side effect to add the LLM to the registry
def side_effect(*args, **kwargs):
self.registry.service_to_llm[service_id] = mock_llm
return mock_llm

mock_create.side_effect = side_effect

# Request a completion
response = self.registry.request_extraneous_completion(
service_id=service_id,
llm_config=self.llm_config,
messages=messages,
)

# Verify the response (should be stripped)
self.assertEqual(response, 'Hello from the LLM!')

# Verify that _create_new_llm was called with correct parameters
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id, with_listener=False
)

# Verify completion was called with correct messages
mock_llm.completion.assert_called_once_with(messages=messages)

def test_get_active_llm(self):
"""Test that get_active_llm returns the active agent LLM."""
active_llm = self.registry.get_active_llm()
self.assertEqual(active_llm, self.registry.active_agent_llm)

def test_subscribe_and_notify(self):
"""Test the subscription and notification system."""
events_received = []

def callback(event: RegistryEvent):
events_received.append(event)

# Subscribe to events
self.registry.subscribe(callback)

# Should receive notification for the active agent LLM
self.assertEqual(len(events_received), 1)
self.assertEqual(events_received[0].llm, self.registry.active_agent_llm)
self.assertEqual(
events_received[0].service_id, self.registry.active_agent_llm.service_id
)

# Test that the subscriber is set correctly
self.assertIsNotNone(self.registry.subscriber)

# Test notify method directly with a mock event
with patch.object(self.registry, 'subscriber') as mock_subscriber:
mock_event = MagicMock()
self.registry.notify(mock_event)
mock_subscriber.assert_called_once_with(mock_event)

def test_registry_has_unique_id(self):
"""Test that each registry instance has a unique ID."""
registry2 = LLMRegistry(config=self.config)
self.assertNotEqual(self.registry.registry_id, registry2.registry_id)
self.assertTrue(len(self.registry.registry_id) > 0)
self.assertTrue(len(registry2.registry_id) > 0)


if __name__ == '__main__':
unittest.main()
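The new test file above exercises a get-or-create cache keyed by service id: a cached LLM is returned as-is, asking for the same id with a different config is an error, and creating a new LLM requires a config. A condensed, self-contained sketch of that behaviour follows; attribute and error messages follow the tests, everything else (the FakeLLM stand-in, where the cache write happens) is illustrative rather than the real implementation.

from dataclasses import dataclass


@dataclass
class FakeLLM:
    config: object
    service_id: str


class RegistrySketch:
    """Get-or-create cache keyed by service id (illustrative only)."""

    def __init__(self) -> None:
        self.service_to_llm: dict[str, FakeLLM] = {}

    def _create_new_llm(self, config, service_id, with_listener=True) -> FakeLLM:
        # The real registry builds an openhands LLM and notifies subscribers;
        # here we just record a stand-in object in the cache.
        llm = FakeLLM(config=config, service_id=service_id)
        self.service_to_llm[service_id] = llm
        return llm

    def get_llm(self, service_id: str, config=None) -> FakeLLM:
        if service_id in self.service_to_llm:
            existing = self.service_to_llm[service_id]
            # Same service id must mean same config.
            if config is not None and config != existing.config:
                raise ValueError('Requesting same service ID with different config')
            return existing
        if config is None:
            raise ValueError('Requesting new LLM without specifying LLM config')
        return self._create_new_llm(config=config, service_id=service_id)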
@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import (
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore
@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch):
file_store=InMemoryFileStore({}),
config=config,
sio=AsyncMock(),
llm_registry=LLMRegistry(config=OpenHandsConfig()),
convo_stats=ConversationStats(None, 'test-sid', None),
)

# Create empty settings

@ -8,7 +8,8 @@ import pytest
from mcp import McpError

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController, AgentState
from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import EventSource
@ -17,6 +18,8 @@ from openhands.events.stream import EventStream
from openhands.mcp.client import MCPClient
from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import call_tool_mcp
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore


class MockConfig:
@ -34,6 +37,11 @@ class MockLLM:
self.config = MockConfig()


@pytest.fixture
def convo_stats():
return ConversationStats(None, 'convo-id', None)


class MockAgent(Agent):
"""Mock agent for testing."""

@ -53,7 +61,7 @@ class MockAgent(Agent):


@pytest.mark.asyncio
async def test_mcp_tool_timeout_error_handling():
async def test_mcp_tool_timeout_error_handling(convo_stats):
"""Test that verifies MCP tool timeout errors are properly handled and returned as observations."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling():
mock_client.tool_map = {'test_tool': mock_tool}

# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})

# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling():

# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)

# Set up the agent state
@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling():


@pytest.mark.asyncio
async def test_mcp_tool_timeout_agent_continuation():
async def test_mcp_tool_timeout_agent_continuation(convo_stats):
"""Test that verifies the agent can continue processing after an MCP tool timeout."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation():
mock_client.tool_map = {'test_tool': mock_tool}

# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})

# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation():

# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)

# Set up the agent state

@ -21,11 +21,13 @@ from openhands.events.observation.agent import (
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import (
@ -42,6 +44,12 @@ def file_store():
return InMemoryFileStore({})


@pytest.fixture
def mock_llm_registry(file_store):
"""Create a mock LLMRegistry for testing."""
return MagicMock(spec=LLMRegistry)


@pytest.fixture
def event_stream(file_store):
"""Create a test event stream."""
@ -90,24 +98,29 @@ def mock_agent():


@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
async def test_memory_on_event_exception_handling(
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream

# Mock Memory method to raise an exception
with patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
with (
patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)

# Verify that the controller's last error was set
@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age

@pytest.mark.asyncio
async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream, mock_agent
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a mock runtime
@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling(
runtime.event_stream = event_stream

# Mock Memory._on_workspace_context_recall to raise an exception
with patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
with (
patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)

# Verify that the controller's last error was set
@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.

@pytest.mark.asyncio
async def test_conversation_instructions_plumbed_to_memory(
mock_agent, event_stream, file_store
mock_agent, event_stream, file_store, mock_llm_registry
):
# Setup
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=ConversationStats(file_store, 'test-session', None),
)

# Create a mock runtime and set it up

@ -3,26 +3,30 @@ from litellm import ModelResponse

from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry


@pytest.fixture
def mock_llm():
llm = LLM(
LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
def llm_config():
return LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
return llm


@pytest.fixture
def codeact_agent(mock_llm):
def llm_registry():
registry = LLMRegistry(config=OpenHandsConfig())
return registry


@pytest.fixture
def codeact_agent(llm_registry):
config = AgentConfig()
agent = CodeActAgent(mock_llm, config)
agent = CodeActAgent(config, llm_registry)
return agent

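The fixture changes just above illustrate the constructor migration this PR applies throughout the test suite: callers no longer build an LLM and hand it to the agent, they hand over a registry and let the agent obtain its LLM per service. A minimal construction sketch mirroring those fixtures, using only calls that appear in the diff; it needs the openhands package installed and a usable default config, so treat it as illustrative rather than a ready-to-run recipe.

from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

registry = LLMRegistry(config=OpenHandsConfig())  # hands out LLMs keyed by service id
agent = CodeActAgent(AgentConfig(), registry)     # new argument order: config first, registry second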
@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation
from openhands.events.stream import EventStream
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store


class TestRuntime(Runtime):
class MockRuntime(Runtime):
"""A concrete implementation of Runtime for testing"""

def __init__(self, *args, **kwargs):
# Ensure llm_registry is provided if not already in kwargs
if 'llm_registry' not in kwargs and len(args) < 3:
# Create a mock LLMRegistry if not provided
config = (
kwargs.get('config')
if 'config' in kwargs
else args[0]
if args
else OpenHandsConfig()
)
kwargs['llm_registry'] = LLMRegistry(config=config)
super().__init__(*args, **kwargs)
self.run_action_calls = []
self._execute_shell_fn_git_handler = MagicMock(
@ -89,9 +101,11 @@ def runtime(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
llm_registry = LLMRegistry(config=config)
runtime = MockRuntime(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
user_id='test_user',
git_provider_tokens=git_provider_tokens,
@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(config=config, event_stream=event_stream, sid='test')
runtime = MockRuntime(config=config, event_stream=event_stream, sid='test')

# Create a command that would normally trigger token export
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir):
config.init_git_in_empty_workspace = True
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)

@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di
config.workspace_base = '/some/path'  # Set workspace_base
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)

@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
)

runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)

runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
)

runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)

runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)


@ -9,10 +9,12 @@ import pytest
from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events import EventStream
from openhands.integrations.service_types import ProviderType, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent.microagent import (
RepoMicroagent,
)
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store


class MockRuntime(Runtime):
@ -24,12 +26,21 @@ class MockRuntime(Runtime):
config.workspace_mount_path_in_sandbox = str(workspace_root)
config.sandbox = SandboxConfig()

# Create a mock event stream
# Create a mock event stream and file store
file_store = get_file_store('local', str(workspace_root))
event_stream = MagicMock(spec=EventStream)
event_stream.file_store = file_store

# Create a mock LLM registry
llm_registry = LLMRegistry(config)

# Initialize the parent class properly
super().__init__(
config=config, event_stream=event_stream, sid='test', git_provider_tokens={}
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
git_provider_tokens={},
)

self._workspace_root = workspace_root

@ -595,7 +595,7 @@ async def test_check_usertask(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(MessageAction(usertask), EventSource.USER),
@ -657,7 +657,7 @@ async def test_check_fillaction(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),

@ -7,7 +9,7 @@ from litellm.exceptions import (

from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore

@ -33,10 +35,28 @@ def default_llm_config():
)


@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)


@pytest.fixture
def conversation_stats():
file_store = InMemoryFileStore({})
return ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)


@pytest.mark.asyncio
@patch('openhands.llm.llm.litellm_completion')
async def test_notify_on_llm_retry(
mock_litellm_completion, mock_sio, default_llm_config
mock_litellm_completion,
mock_sio,
default_llm_config,
llm_registry,
conversation_stats,
):
config = OpenHandsConfig()
config.set_llm_config(default_llm_config)
@ -44,6 +64,8 @@ async def test_notify_on_llm_retry(
sid='..sid..',
file_store=InMemoryFileStore({}),
config=config,
llm_registry=llm_registry,
convo_stats=conversation_stats,
sio=mock_sio,
user_id='..uid..',
)
@ -56,12 +78,20 @@ async def test_notify_on_llm_retry(
),
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = session._create_llm('..cls..')

llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)
# Set the retry listener on the registry
llm_registry.retry_listner = session._notify_on_llm_retry

# Create an LLM through the registry
llm = llm_registry.get_llm(
service_id='test_service',
config=default_llm_config,
)

llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)

assert mock_litellm_completion.call_count == 2
session.queue_status_message.assert_called_once_with(