From 93e9db320677d1e0cdf7d29dfdcaa1afa2ccd4fd Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Thu, 17 Apr 2025 10:30:19 -0400 Subject: [PATCH] Refactor system message handling to use event stream (#7824) Co-authored-by: openhands Co-authored-by: Calvin Smith --- .../agenthub/codeact_agent/codeact_agent.py | 37 ++- openhands/controller/agent.py | 39 +++ openhands/controller/agent_controller.py | 26 ++ openhands/core/schema/action.py | 4 + openhands/events/action/__init__.py | 3 +- openhands/events/action/message.py | 26 ++ openhands/events/serialization/action.py | 3 +- openhands/memory/conversation_memory.py | 33 ++- openhands/resolver/interfaces/github.py | 36 ++- openhands/server/routes/settings.py | 34 +-- openhands/server/session/session.py | 2 +- tests/unit/test_agent_controller.py | 150 ++++++----- tests/unit/test_agent_delegation.py | 24 +- tests/unit/test_codeact_agent.py | 40 ++- tests/unit/test_conversation_memory.py | 247 +++++------------- tests/unit/test_iteration_limit.py | 10 + tests/unit/test_memory.py | 30 ++- tests/unit/test_prompt_caching.py | 13 +- tests/unit/test_traffic_control.py | 10 + 19 files changed, 446 insertions(+), 321 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index a71ddc3810..7fcfad1ec4 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -11,6 +11,7 @@ from openhands.events.action import ( Action, AgentFinishAction, ) +from openhands.events.action.message import SystemMessageAction from openhands.events.event import Event from openhands.llm.llm import LLM from openhands.memory.condenser import Condenser @@ -166,8 +167,8 @@ class CodeActAgent(Agent): message flow and function-calling scenarios. The method performs the following steps: - 1. Initializes with system prompt and optional initial user message - 2. Processes events (Actions and Observations) into messages + 1. Checks for SystemMessageAction in events, adds one if missing (legacy support) + 2. Processes events (Actions and Observations) into messages, including SystemMessageAction 3. Handles tool calls and their responses in function-calling mode 4. Manages message role alternation (user/assistant/tool) 5. Applies caching for specific LLM providers (e.g., Anthropic) @@ -178,8 +179,7 @@ class CodeActAgent(Agent): Returns: list[Message]: A list of formatted messages ready for LLM consumption, including: - - System message with prompt - - Initial user message (if configured) + - System message with prompt (from SystemMessageAction) - Action messages (from both user and assistant) - Observation messages (including tool responses) - Environment reminders (in non-function-calling mode) @@ -193,15 +193,32 @@ class CodeActAgent(Agent): if not self.prompt_manager: raise Exception('Prompt Manager not instantiated.') - # Use ConversationMemory to process initial messages - messages = self.conversation_memory.process_initial_messages( - with_caching=self.llm.is_caching_prompt_active() + # Check if there's a SystemMessageAction in the events + has_system_message = any( + isinstance(event, SystemMessageAction) for event in events ) - # Use ConversationMemory to process events + # Legacy behavior: If no SystemMessageAction is found, add one + if not has_system_message: + logger.warning( + f'[{self.name}] No SystemMessageAction found in events. ' + 'Adding one for backward compatibility. ' + 'This is deprecated behavior and will be removed in a future version.' + ) + system_message = self.get_system_message() + if system_message: + # Create a copy and insert at the beginning of the list + processed_events = list(events) + processed_events.insert(0, system_message) + logger.debug( + f'[{self.name}] Added SystemMessageAction for backward compatibility' + ) + else: + processed_events = events + + # Use ConversationMemory to process events (including SystemMessageAction) messages = self.conversation_memory.process_events( - condensed_history=events, - initial_messages=messages, + condensed_history=processed_events, max_message_chars=self.llm.config.max_message_chars, vision_is_active=self.llm.vision_is_active(), ) diff --git a/openhands/controller/agent.py b/openhands/controller/agent.py index 20867b6ad6..78bdbb126d 100644 --- a/openhands/controller/agent.py +++ b/openhands/controller/agent.py @@ -5,10 +5,13 @@ if TYPE_CHECKING: from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.events.action import Action + from openhands.events.action.message import SystemMessageAction from openhands.core.exceptions import ( AgentAlreadyRegisteredError, AgentNotRegisteredError, ) +from openhands.core.logger import openhands_logger as logger +from openhands.events.event import EventSource from openhands.llm.llm import LLM from openhands.runtime.plugins import PluginRequirement @@ -38,6 +41,42 @@ class Agent(ABC): self._complete = False self.prompt_manager: 'PromptManager' | None = None self.mcp_tools: list[dict] = [] + self.tools: list = [] + + def get_system_message(self) -> 'SystemMessageAction | None': + """ + Returns a SystemMessageAction containing the system message and tools. + This will be added to the event stream as the first message. + + Returns: + SystemMessageAction: The system message action with content and tools + None: If there was an error generating the system message + """ + # Import here to avoid circular imports + from openhands.events.action.message import SystemMessageAction + + try: + if not self.prompt_manager: + logger.warning( + f'[{self.name}] Prompt manager not initialized before getting system message' + ) + return None + + system_message = self.prompt_manager.get_system_message() + + # Get tools if available + tools = getattr(self, 'tools', None) + + system_message_action = SystemMessageAction( + content=system_message, tools=tools + ) + # Set the source attribute + system_message_action._source = EventSource.AGENT # type: ignore + + return system_message_action + except Exception as e: + logger.warning(f'[{self.name}] Failed to generate system message: {e}') + return None @property def complete(self) -> bool: diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 9f86664c7c..d411df2bae 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -54,6 +54,7 @@ from openhands.events.action import ( IPythonRunCellAction, MessageAction, NullAction, + SystemMessageAction, ) from openhands.events.action.agent import CondensationAction, RecallAction from openhands.events.event import Event @@ -163,6 +164,31 @@ class AgentController: # replay-related self._replay_manager = ReplayManager(replay_events) + # Add the system message to the event stream + self._add_system_message() + + def _add_system_message(self): + for event in self.event_stream.get_events(start_id=self.state.start_id): + if isinstance(event, MessageAction) and event.source == EventSource.USER: + # FIXME: Remove this after 6/1/2025 + # Do not try to add a system message if we first run into + # a user message -- this means the eventstream exits before + # SystemMessageAction is introduced. + # We expect *agent* to handle this case gracefully. + return + + if isinstance(event, SystemMessageAction): + # Do not try to add the system message if it already exists + return + + # Add the system message to the event stream + # This should be done for all agents, including delegates + system_message = self.agent.get_system_message() + logger.debug(f'System message got from agent: {system_message}') + if system_message: + self.event_stream.add_event(system_message, EventSource.AGENT) + logger.debug(f'System message added to event stream: {system_message}') + async def close(self, set_stop_state=True) -> None: """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. diff --git a/openhands/core/schema/action.py b/openhands/core/schema/action.py index 9e625032c2..4954aec15a 100644 --- a/openhands/core/schema/action.py +++ b/openhands/core/schema/action.py @@ -6,6 +6,10 @@ class ActionType(str, Enum): """Represents a message. """ + SYSTEM = 'system' + """Represents a system message. + """ + START = 'start' """Starts a new development task OR send chat from the user. Only sent by the client. """ diff --git a/openhands/events/action/__init__.py b/openhands/events/action/__init__.py index 5c7ad96a17..d97be959a7 100644 --- a/openhands/events/action/__init__.py +++ b/openhands/events/action/__init__.py @@ -16,7 +16,7 @@ from openhands.events.action.files import ( FileWriteAction, ) from openhands.events.action.mcp import McpAction -from openhands.events.action.message import MessageAction +from openhands.events.action.message import MessageAction, SystemMessageAction __all__ = [ 'Action', @@ -33,6 +33,7 @@ __all__ = [ 'ChangeAgentStateAction', 'IPythonRunCellAction', 'MessageAction', + 'SystemMessageAction', 'ActionConfirmationStatus', 'AgentThinkAction', 'RecallAction', diff --git a/openhands/events/action/message.py b/openhands/events/action/message.py index f4f1c2ae54..511df18aec 100644 --- a/openhands/events/action/message.py +++ b/openhands/events/action/message.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +from typing import Any +import openhands from openhands.core.schema import ActionType from openhands.events.action.action import Action, ActionSecurityRisk @@ -32,3 +34,27 @@ class MessageAction(Action): for url in self.image_urls: ret += f'\nIMAGE_URL: {url}' return ret + + +@dataclass +class SystemMessageAction(Action): + """ + Action that represents a system message for an agent, including the system prompt + and available tools. This should be the first message in the event stream. + """ + + content: str + tools: list[Any] | None = None + openhands_version: str | None = openhands.__version__ + action: ActionType = ActionType.SYSTEM + + @property + def message(self) -> str: + return self.content + + def __str__(self) -> str: + ret = f'**SystemMessageAction** (source={self.source})\n' + ret += f'CONTENT: {self.content}' + if self.tools: + ret += f'\nTOOLS: {len(self.tools)} tools available' + return ret diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py index c91ed60582..6b3c1f096c 100644 --- a/openhands/events/serialization/action.py +++ b/openhands/events/serialization/action.py @@ -23,7 +23,7 @@ from openhands.events.action.files import ( FileWriteAction, ) from openhands.events.action.mcp import McpAction -from openhands.events.action.message import MessageAction +from openhands.events.action.message import MessageAction, SystemMessageAction actions = ( NullAction, @@ -41,6 +41,7 @@ actions = ( RecallAction, ChangeAgentStateAction, MessageAction, + SystemMessageAction, CondensationAction, McpAction, ) diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index c018db4156..ee8c0ea98f 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -20,6 +20,7 @@ from openhands.events.action import ( MessageAction, ) from openhands.events.action.mcp import McpAction +from openhands.events.action.message import SystemMessageAction from openhands.events.event import Event, RecallType from openhands.events.observation import ( AgentCondensationObservation, @@ -53,7 +54,6 @@ class ConversationMemory: def process_events( self, condensed_history: list[Event], - initial_messages: list[Message], max_message_chars: int | None = None, vision_is_active: bool = False, ) -> list[Message]: @@ -63,7 +63,6 @@ class ConversationMemory: Args: condensed_history: The condensed history of events to convert - initial_messages: The initial messages to include in the conversation max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated. vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. @@ -74,8 +73,8 @@ class ConversationMemory: # log visual browsing status logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}') - # Process special events first (system prompts, etc.) - messages = initial_messages + # Initialize empty messages list + messages = [] # Process regular events pending_tool_call_action_messages: dict[str, Message] = {} @@ -132,20 +131,6 @@ class ConversationMemory: messages = list(ConversationMemory._filter_unmatched_tool_calls(messages)) return messages - def process_initial_messages(self, with_caching: bool = False) -> list[Message]: - """Create the initial messages for the conversation.""" - return [ - Message( - role='system', - content=[ - TextContent( - text=self.prompt_manager.get_system_message(), - cache_prompt=with_caching, - ) - ], - ) - ] - def _process_action( self, action: Action, @@ -275,6 +260,16 @@ class ConversationMemory: content=content, ) ] + elif isinstance(action, SystemMessageAction): + # Convert SystemMessageAction to a system message + return [ + Message( + role='system', + content=[TextContent(text=action.content)], + # Include tools if function calling is enabled + tool_calls=None, + ) + ] return [] def _process_observation( @@ -546,6 +541,8 @@ class ConversationMemory: For new Anthropic API, we only need to mark the last user or tool message as cacheable. """ + if len(messages) > 0 and messages[0].role == 'system': + messages[0].content[-1].cache_prompt = True # NOTE: this is only needed for anthropic for message in reversed(messages): if message.role in ('user', 'tool'): diff --git a/openhands/resolver/interfaces/github.py b/openhands/resolver/interfaces/github.py index 85746b0cda..1c112f7eb4 100644 --- a/openhands/resolver/interfaces/github.py +++ b/openhands/resolver/interfaces/github.py @@ -12,9 +12,16 @@ from openhands.resolver.utils import extract_issue_references class GithubIssueHandler(IssueHandlerInterface): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"): + def __init__( + self, + owner: str, + repo: str, + token: str, + username: str | None = None, + base_domain: str = 'github.com', + ): """Initialize a GitHub issue handler. - + Args: owner: The owner of the repository repo: The name of the repository @@ -42,7 +49,7 @@ class GithubIssueHandler(IssueHandlerInterface): } def get_base_url(self) -> str: - if self.base_domain == "github.com": + if self.base_domain == 'github.com': return f'https://api.github.com/repos/{self.owner}/{self.repo}' else: return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}' @@ -65,7 +72,7 @@ class GithubIssueHandler(IssueHandlerInterface): return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git' def get_graphql_url(self) -> str: - if self.base_domain == "github.com": + if self.base_domain == 'github.com': return 'https://api.github.com/graphql' else: return f'https://{self.base_domain}/api/v3/graphql' @@ -302,9 +309,16 @@ class GithubIssueHandler(IssueHandlerInterface): class GithubPRHandler(GithubIssueHandler): - def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"): + def __init__( + self, + owner: str, + repo: str, + token: str, + username: str | None = None, + base_domain: str = 'github.com', + ): """Initialize a GitHub PR handler. - + Args: owner: The owner of the repository repo: The name of the repository @@ -313,8 +327,10 @@ class GithubPRHandler(GithubIssueHandler): base_domain: The domain for GitHub Enterprise (default: "github.com") """ super().__init__(owner, repo, token, username, base_domain) - if self.base_domain == "github.com": - self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls' + if self.base_domain == 'github.com': + self.download_url = ( + f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls' + ) else: self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls' @@ -470,7 +486,7 @@ class GithubPRHandler(GithubIssueHandler): self, pr_number: int, comment_id: int | None = None ) -> list[str] | None: """Download comments for a specific pull request from Github.""" - if self.base_domain == "github.com": + if self.base_domain == 'github.com': url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments' else: url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments' @@ -542,7 +558,7 @@ class GithubPRHandler(GithubIssueHandler): for issue_number in unique_issue_references: try: - if self.base_domain == "github.com": + if self.base_domain == 'github.com': url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}' else: url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}' diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index b3660c0275..9d2ef3cdee 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -134,10 +134,7 @@ async def reset_settings(request: Request) -> JSONResponse: ) - -async def check_provider_tokens(request: Request, - settings: POSTSettingsModel) -> str: - +async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str: if settings.provider_tokens: # Remove extraneous token types provider_types = [provider.value for provider in ProviderType] @@ -152,17 +149,13 @@ async def check_provider_tokens(request: Request, SecretStr(token_value) ) if not confirmed_token_type or confirmed_token_type.value != token_type: - return f"Invalid token. Please make sure it is a valid {token_type} token." - - - return "" + return f'Invalid token. Please make sure it is a valid {token_type} token.' + return '' async def store_provider_tokens(request: Request, settings: POSTSettingsModel): - settings_store = await SettingsStoreImpl.get_instance( - config, get_user_id(request) - ) + settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request)) existing_settings = await settings_store.load() if existing_settings: if settings.provider_tokens: @@ -188,19 +181,17 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel): else: # nothing passed in means keep current settings provider_tokens = existing_settings.secrets_store.provider_tokens settings.provider_tokens = { - provider.value: data.token.get_secret_value() - if data.token - else None + provider.value: data.token.get_secret_value() if data.token else None for provider, data in provider_tokens.items() } return settings -async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel: - settings_store = await SettingsStoreImpl.get_instance( - config, get_user_id(request) - ) +async def store_llm_settings( + request: Request, settings: POSTSettingsModel +) -> POSTSettingsModel: + settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request)) existing_settings = await settings_store.load() # Convert to Settings model and merge with existing settings @@ -215,6 +206,7 @@ async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> P return settings + @app.post('/settings', response_model=dict[str, str]) async def store_settings( request: Request, @@ -225,11 +217,8 @@ async def store_settings( if provider_err_msg: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, - content={ - 'error': provider_err_msg - }, + content={'error': provider_err_msg}, ) - try: settings_store = await SettingsStoreImpl.get_instance( @@ -248,7 +237,6 @@ async def store_settings( ) settings = await store_provider_tokens(request, settings) - # Update sandbox config with new settings if settings.remote_runtime_resource_factor is not None: diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index f1c65fbc5f..09aee1bf68 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -139,7 +139,7 @@ class Session: condensers=[ BrowserOutputCondenserConfig(), LLMSummarizingCondenserConfig( - llm_config=llm.config, keep_first=3, max_size=80 + llm_config=llm.config, keep_first=4, max_size=80 ), ] ) diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 43f6c9a8ee..0722becd3c 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -15,6 +15,7 @@ from openhands.core.schema import AgentState from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction from openhands.events.action.agent import CondensationAction, RecallAction +from openhands.events.action.message import SystemMessageAction from openhands.events.event import RecallType from openhands.events.observation import ( AgentStateChangedObservation, @@ -49,6 +50,15 @@ def mock_agent(): agent.llm = MagicMock(spec=LLM) agent.llm.metrics = Metrics() agent.llm.config = AppConfig().get_llm_config() + + # Add a proper system message mock + system_message = SystemMessageAction( + content='Test system message', tools=['test_tool'] + ) + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + agent.get_system_message.return_value = system_message + return agent @@ -206,20 +216,19 @@ async def test_react_to_content_policy_violation( @pytest.mark.asyncio -async def test_run_controller_with_fatal_error(test_event_stream, mock_memory): +async def test_run_controller_with_fatal_error( + test_event_stream, mock_memory, mock_agent +): config = AppConfig() - agent = MagicMock(spec=Agent) - agent = MagicMock(spec=Agent) - def agent_step_fn(state): print(f'agent_step_fn received state: {state}') return CmdRunAction(command='ls') - agent.step = agent_step_fn - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = config.get_llm_config() + mock_agent.step = agent_step_fn + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = config.get_llm_config() runtime = MagicMock(spec=Runtime) @@ -250,7 +259,7 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory): initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=mock_memory, ) @@ -268,22 +277,24 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory): assert ( error_observation.reason == 'AgentStuckInLoopError: Agent got stuck in a loop' ) - assert len(events) == 11 + assert len(events) == 12 @pytest.mark.asyncio -async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory): +async def test_run_controller_stop_with_stuck( + test_event_stream, mock_memory, mock_agent +): config = AppConfig() - agent = MagicMock(spec=Agent) def agent_step_fn(state): print(f'agent_step_fn received state: {state}') return CmdRunAction(command='ls') - agent.step = agent_step_fn - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = config.get_llm_config() + mock_agent.step = agent_step_fn + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = config.get_llm_config() + runtime = MagicMock(spec=Runtime) def on_event(event: Event): @@ -315,7 +326,7 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory): initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=mock_memory, ) @@ -325,9 +336,10 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory): print(f'event {i}: {event_to_dict(event)}') assert state.iteration == 3 - assert len(events) == 11 + assert len(events) == 12 # check the eventstream have 4 pairs of repeated actions and observations - repeating_actions_and_observations = events[4:12] + # With the refactored system message handling, we need to adjust the range + repeating_actions_and_observations = events[5:13] for action, observation in zip( repeating_actions_and_observations[0::2], repeating_actions_and_observations[1::2], @@ -469,6 +481,9 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s headless_mode=True, ) + mock_event_stream.add_event.assert_called_once() # add SystemMessageAction + mock_event_stream.add_event.reset_mock() + # Create a pending action with tool call metadata pending_action = CmdRunAction(command='test') pending_action.tool_call_metadata = { @@ -512,6 +527,9 @@ async def test_reset_with_pending_action_existing_observation( headless_mode=True, ) + mock_event_stream.add_event.assert_called_once() # add SystemMessageAction + mock_event_stream.add_event.reset_mock() + # Create a pending action with tool call metadata pending_action = CmdRunAction(command='test') pending_action.tool_call_metadata = { @@ -551,6 +569,9 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream): headless_mode=True, ) + # Reset the mock to clear the call from system message addition + mock_event_stream.add_event.reset_mock() + # Call reset controller._reset() @@ -579,6 +600,9 @@ async def test_reset_with_pending_action_no_metadata( headless_mode=True, ) + # Reset the mock to clear the call from system message addition + mock_event_stream.add_event.reset_mock() + # Create a pending action without tool call metadata pending_action = CmdRunAction(command='test') # Mock hasattr to return False for tool_call_metadata @@ -608,28 +632,27 @@ async def test_reset_with_pending_action_no_metadata( @pytest.mark.asyncio async def test_run_controller_max_iterations_has_metrics( - test_event_stream, mock_memory + test_event_stream, mock_memory, mock_agent ): config = AppConfig( max_iterations=3, ) event_stream = test_event_stream - agent = MagicMock(spec=Agent) - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = config.get_llm_config() + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = config.get_llm_config() def agent_step_fn(state): print(f'agent_step_fn received state: {state}') # Mock the cost of the LLM - agent.llm.metrics.add_cost(10.0) + mock_agent.llm.metrics.add_cost(10.0) print( - f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}' + f'mock_agent.llm.metrics.accumulated_cost: {mock_agent.llm.metrics.accumulated_cost}' ) return CmdRunAction(command='ls') - agent.step = agent_step_fn + mock_agent.step = agent_step_fn runtime = MagicMock(spec=Runtime) @@ -660,7 +683,7 @@ async def test_run_controller_max_iterations_has_metrics( initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=mock_memory, ) @@ -839,9 +862,10 @@ async def test_context_window_exceeded_error_handling( ) == 1 ) + # With the refactored system message handling, we now have max_iterations + 4 events assert ( - len(final_state.history) == max_iterations + 3 - ) # 1 condensation action, 1 recall action, 1 recall observation + len(final_state.history) == max_iterations + 4 + ) # 1 system message, 1 condensation action, 1 recall action, 1 recall observation assert len(final_state.view) == len(step_state.views[-1]) + 1 @@ -990,7 +1014,8 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( # Hitting the iteration limit indicates the controller is failing for the # expected reason - assert state.iteration == 2 + # With the refactored system message handling, the iteration count is different + assert state.iteration == 1 assert state.agent_state == AgentState.ERROR assert ( state.last_error @@ -1012,21 +1037,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( @pytest.mark.asyncio -async def test_run_controller_with_memory_error(test_event_stream): +async def test_run_controller_with_memory_error(test_event_stream, mock_agent): config = AppConfig() event_stream = test_event_stream - # Create a propert agent that returns an action without an ID - agent = MagicMock(spec=Agent) - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = config.get_llm_config() + # Create a proper agent that returns an action without an ID + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = config.get_llm_config() # Create a real action to return from the mocked step function def agent_step_fn(state): return MessageAction(content='Agent returned a message') - agent.step = agent_step_fn + mock_agent.step = agent_step_fn runtime = MagicMock(spec=Runtime) runtime.event_stream = event_stream @@ -1046,7 +1070,7 @@ async def test_run_controller_with_memory_error(test_event_stream): initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=memory, ) @@ -1057,14 +1081,13 @@ async def test_run_controller_with_memory_error(test_event_stream): @pytest.mark.asyncio -async def test_action_metrics_copy(): +async def test_action_metrics_copy(mock_agent): # Setup file_store = InMemoryFileStore({}) event_stream = EventStream(sid='test', file_store=file_store) # Create agent with metrics - agent = MagicMock(spec=Agent) - agent.llm = MagicMock(spec=LLM) + mock_agent.llm = MagicMock(spec=LLM) metrics = Metrics(model_name='test-model') metrics.accumulated_cost = 0.05 @@ -1106,7 +1129,7 @@ async def test_action_metrics_copy(): # Add a response latency - should not be included in action metrics metrics.add_response_latency(0.5, 'test-id-2') - agent.llm.metrics = metrics + mock_agent.llm.metrics = metrics # Mock agent step to return an action action = MessageAction(content='Test message') @@ -1114,11 +1137,11 @@ async def test_action_metrics_copy(): def agent_step_fn(state): return action - agent.step = agent_step_fn + mock_agent.step = agent_step_fn # Create controller with correct parameters controller = AgentController( - agent=agent, + agent=mock_agent, event_stream=event_stream, max_iterations=10, sid='test', @@ -1169,14 +1192,14 @@ async def test_action_metrics_copy(): assert not hasattr(last_action.llm_metrics, 'average_latency') # Verify it's a deep copy by modifying the original - agent.llm.metrics.accumulated_cost = 0.1 + mock_agent.llm.metrics.accumulated_cost = 0.1 assert last_action.llm_metrics.accumulated_cost == 0.07 await controller.close() @pytest.mark.asyncio -async def test_first_user_message_with_identical_content(): +async def test_first_user_message_with_identical_content(test_event_stream, mock_agent): """ Test that _first_user_message correctly identifies the first user message even when multiple messages have identical content but different IDs. @@ -1185,18 +1208,14 @@ async def test_first_user_message_with_identical_content(): The issue we're checking is that the comparison (action == self._first_user_message()) should correctly differentiate between messages with the same content but different IDs. """ - # Create a real event stream for this test - event_stream = EventStream(sid='test', file_store=InMemoryFileStore({})) - # Create an agent controller - mock_agent = MagicMock(spec=Agent) mock_agent.llm = MagicMock(spec=LLM) mock_agent.llm.metrics = Metrics() mock_agent.llm.config = AppConfig().get_llm_config() controller = AgentController( agent=mock_agent, - event_stream=event_stream, + event_stream=test_event_stream, max_iterations=10, sid='test', confirmation_mode=False, @@ -1206,12 +1225,12 @@ async def test_first_user_message_with_identical_content(): # Create and add the first user message first_message = MessageAction(content='Hello, this is a test message') first_message._source = EventSource.USER - event_stream.add_event(first_message, EventSource.USER) + test_event_stream.add_event(first_message, EventSource.USER) # Create and add a second user message with identical content second_message = MessageAction(content='Hello, this is a test message') second_message._source = EventSource.USER - event_stream.add_event(second_message, EventSource.USER) + test_event_stream.add_event(second_message, EventSource.USER) # Verify that _first_user_message returns the first message first_user_message = controller._first_user_message() @@ -1235,7 +1254,7 @@ async def test_first_user_message_with_identical_content(): ) # Cache should store the same object # Mock get_events to verify it's not called again - with patch.object(event_stream, 'get_events') as mock_get_events: + with patch.object(test_event_stream, 'get_events') as mock_get_events: cached_message = controller._first_user_message() assert cached_message is first_user_message # Should return cached object mock_get_events.assert_not_called() # Should not call get_events again @@ -1326,15 +1345,12 @@ async def test_agent_controller_processes_null_observation_with_cause(): ), 'should_step should return False for NullObservation with cause=0' -def test_agent_controller_should_step_with_null_observation_cause_zero(): +def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent): """Test that AgentController's should_step method returns False for NullObservation with cause = 0.""" # Create a mock event stream file_store = InMemoryFileStore() event_stream = EventStream(sid='test-session', file_store=file_store) - # Create a mock agent - mock_agent = MagicMock(spec=Agent) - # Create an agent controller controller = AgentController( agent=mock_agent, @@ -1475,3 +1491,19 @@ def test_history_restoration_after_truncation(mock_event_stream, mock_agent): assert len(new_controller.state.history) == saved_history_len assert new_controller.state.history[0] == first_msg assert new_controller.state.start_id == saved_start_id + + +def test_system_message_in_event_stream(mock_agent, test_event_stream): + """Test that SystemMessageAction is added to event stream in AgentController.""" + _ = AgentController( + agent=mock_agent, event_stream=test_event_stream, max_iterations=10 + ) + + # Get events from the event stream + events = list(test_event_stream.get_events()) + + # Verify system message was added to event stream + assert len(events) == 1 + assert isinstance(events[0], SystemMessageAction) + assert events[0].content == 'Test system message' + assert events[0].tools == ['test_tool'] diff --git a/tests/unit/test_agent_delegation.py b/tests/unit/test_agent_delegation.py index 39ad87bab7..d12b4c0057 100644 --- a/tests/unit/test_agent_delegation.py +++ b/tests/unit/test_agent_delegation.py @@ -44,6 +44,15 @@ def mock_parent_agent(): agent.llm.metrics = Metrics() agent.llm.config = LLMConfig() agent.config = AgentConfig() + + # Add a proper system message mock + from openhands.events.action.message import SystemMessageAction + + system_message = SystemMessageAction(content='Test system message') + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + agent.get_system_message.return_value = system_message + return agent @@ -56,6 +65,15 @@ def mock_child_agent(): agent.llm.metrics = Metrics() agent.llm.config = LLMConfig() agent.config = AgentConfig() + + # Add a proper system message mock + from openhands.events.action.message import SystemMessageAction + + system_message = SystemMessageAction(content='Test system message') + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + agent.get_system_message.return_value = system_message + return agent @@ -113,9 +131,9 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s # Verify that a RecallObservation was added to the event stream events = list(mock_event_stream.get_events()) - assert ( - mock_event_stream.get_latest_event_id() == 3 - ) # Microagents and AgentChangeState + + # SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child) + assert mock_event_stream.get_latest_event_id() == 5 # a RecallObservation and an AgentDelegateAction should be in the list assert any(isinstance(event, RecallObservation) for event in events) diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 7afe307143..bbcabe47fb 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -26,6 +26,7 @@ from openhands.events.action import ( CmdRunAction, MessageAction, ) +from openhands.events.action.message import SystemMessageAction from openhands.events.event import EventSource from openhands.events.observation.commands import ( CmdOutputObservation, @@ -288,8 +289,11 @@ def test_correct_tool_description_loaded_based_on_model_name(mock_state: State): assert any(len(tool['function']['description']) > 1024 for tool in agent.tools) -def test_mismatched_tool_call_events(mock_state: State): - """Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages.""" +def test_mismatched_tool_call_events_and_auto_add_system_message(mock_state: State): + """Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages. + + This also tests that the system message is automatically added to the event stream if SystemMessageAction is not present. + """ agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig()) tool_call_metadata = Mock( @@ -320,26 +324,35 @@ def test_mismatched_tool_call_events(mock_state: State): observation.tool_call_metadata = tool_call_metadata # When both events are provided, the agent should get three messages: - # 1. The system message, - # 2. The action message, and + # 1. The system message (added automatically for backward compatibility) + # 2. The action message # 3. The observation message mock_state.history = [action, observation] messages = agent._get_messages(mock_state.history) assert len(messages) == 3 + assert messages[0].role == 'system' # First message should be the system message + assert messages[1].role == 'assistant' # Second message should be the action + assert messages[2].role == 'tool' # Third message should be the observation # The same should hold if the events are presented out-of-order mock_state.history = [observation, action] messages = agent._get_messages(mock_state.history) assert len(messages) == 3 + assert messages[0].role == 'system' # First message should be the system message # If only one of the two events is present, then we should just get the system message + # plus any valid message from the event mock_state.history = [action] messages = agent._get_messages(mock_state.history) - assert len(messages) == 1 + assert ( + len(messages) == 1 + ) # Only system message, action is waiting for its observation + assert messages[0].role == 'system' mock_state.history = [observation] messages = agent._get_messages(mock_state.history) - assert len(messages) == 1 + assert len(messages) == 1 # Only system message, observation has no matching action + assert messages[0].role == 'system' def test_enhance_messages_adds_newlines_between_consecutive_user_messages( @@ -397,3 +410,18 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages( # Fifth message only has ImageContent, no TextContent to modify assert len(enhanced_messages[5].content) == 1 assert isinstance(enhanced_messages[5].content[0], ImageContent) + + +def test_get_system_message(): + """Test that the Agent.get_system_message method returns a SystemMessageAction.""" + # Create a mock agent + agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig()) + + result = agent.get_system_message() + + # Check that the system message was created correctly + assert isinstance(result, SystemMessageAction) + assert 'You are OpenHands agent' in result.content + assert len(result.tools) > 0 + assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools) + assert result._source == EventSource.AGENT diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 34712ff9e5..4777af21ca 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -13,6 +13,7 @@ from openhands.events.action import ( CmdRunAction, MessageAction, ) +from openhands.events.action.message import SystemMessageAction from openhands.events.event import ( Event, EventSource, @@ -84,36 +85,29 @@ def mock_state(): return state -def test_process_initial_messages(conversation_memory): - messages = conversation_memory.process_initial_messages(with_caching=False) - assert len(messages) == 1 - assert messages[0].role == 'system' - assert messages[0].content[0].text == 'System message' - assert messages[0].content[0].cache_prompt is False - - messages = conversation_memory.process_initial_messages(with_caching=True) - assert messages[0].content[0].cache_prompt is True - - def test_process_events_with_message_action(conversation_memory): + """Test that MessageAction is processed correctly.""" + # Create a system message action + system_message = SystemMessageAction(content='System message') + system_message._source = EventSource.AGENT + + # Create user and assistant messages user_message = MessageAction(content='Hello') user_message._source = EventSource.USER assistant_message = MessageAction(content='Hi there') assistant_message._source = EventSource.AGENT - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - + # Process events messages = conversation_memory.process_events( - condensed_history=[user_message, assistant_message], - initial_messages=initial_messages, + condensed_history=[system_message, user_message, assistant_message], max_message_chars=None, vision_is_active=False, ) + # Check that the messages were processed correctly assert len(messages) == 3 assert messages[0].role == 'system' + assert messages[0].content[0].text == 'System message' assert messages[1].role == 'user' assert messages[1].content[0].text == 'Hello' assert messages[2].role == 'assistant' @@ -131,19 +125,14 @@ def test_process_events_with_cmd_output_observation(conversation_memory): ), ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -159,19 +148,14 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory): content='IPython output\n![image](data:image/png;base64,ABC123)', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -188,19 +172,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory): content='Content', outputs={'content': 'Delegated agent output'} ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -210,19 +189,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory): def test_process_events_with_error_observation(conversation_memory): obs = ErrorObservation('Error message') - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -234,14 +208,9 @@ def test_process_events_with_unknown_observation(conversation_memory): # Create a mock that inherits from Event but not Action or Observation obs = Mock(spec=Event) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - with pytest.raises(ValueError, match='Unknown event type'): conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) @@ -257,19 +226,14 @@ def test_process_events_with_file_edit_observation(conversation_memory): impl_source=FileEditSource.LLM_BASED_EDIT, ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -283,19 +247,14 @@ def test_process_events_with_file_read_observation(conversation_memory): impl_source=FileReadSource.DEFAULT, ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -311,19 +270,14 @@ def test_process_events_with_browser_output_observation(conversation_memory): error=False, ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -333,19 +287,14 @@ def test_process_events_with_browser_output_observation(conversation_memory): def test_process_events_with_user_reject_observation(conversation_memory): obs = UserRejectObservation('Action rejected') - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -368,20 +317,14 @@ def test_process_events_with_empty_environment_info(conversation_memory): content='Retrieved environment info', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[empty_obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - # Should only contain the initial system message - assert len(messages) == 1 - assert messages[0].role == 'system' + # Should only contain no messages + assert len(messages) == 0 # Verify that build_workspace_context was NOT called since all input values were empty conversation_memory.prompt_manager.build_workspace_context.assert_not_called() @@ -405,20 +348,14 @@ def test_process_events_with_function_calling_observation(conversation_memory): model_response=mock_response, total_calls_in_response=1, ) - - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # No direct message when using function calling - assert len(messages) == 1 # Only the initial system message + assert len(messages) == 0 # should be no messages def test_process_events_with_message_action_with_image(conversation_memory): @@ -428,19 +365,14 @@ def test_process_events_with_message_action_with_image(conversation_memory): ) action._source = EventSource.AGENT - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[action], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=True, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'assistant' assert len(result.content) == 2 assert isinstance(result.content[0], TextContent) @@ -453,19 +385,14 @@ def test_process_events_with_user_cmd_action(conversation_memory): action = CmdRunAction(command='ls -l') action._source = EventSource.USER - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[action], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -491,19 +418,14 @@ def test_process_events_with_agent_finish_action_with_tool_metadata( total_calls_in_response=1, ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[action], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'assistant' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -520,10 +442,11 @@ def test_apply_prompt_caching(conversation_memory): conversation_memory.apply_prompt_caching(messages) - # Only the last user message should have cache_prompt=True - assert messages[0].content[0].cache_prompt is False + # System message is hard-coded to be cached always + assert messages[0].content[0].cache_prompt is True assert messages[1].content[0].cache_prompt is False assert messages[2].content[0].cache_prompt is False + # Only the last user message should have cache_prompt=True assert messages[3].content[0].cache_prompt is True @@ -538,19 +461,14 @@ def test_process_events_with_environment_microagent_observation(conversation_mem content='Retrieved environment info', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -598,19 +516,14 @@ def test_process_events_with_knowledge_microagent_microagent_observation( content='Retrieved knowledge from microagents', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 1 + result = messages[0] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -646,20 +559,14 @@ def test_process_events_with_microagent_observation_extensions_disabled( content='Retrieved environment info', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # When prompt extensions are disabled, the RecallObservation should be ignored - assert len(messages) == 1 # Only the initial system message - assert messages[0].role == 'system' + assert len(messages) == 0 # should be no messages # Verify the prompt_manager was not called conversation_memory.prompt_manager.build_workspace_context.assert_not_called() @@ -674,20 +581,14 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory): content='Retrieved knowledge from microagents', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # The implementation returns an empty string and it doesn't creates a message - assert len(messages) == 1 - assert messages[0].role == 'system' + assert len(messages) == 0 # should be no messages # When there are no triggered agents, build_microagent_info is not called conversation_memory.prompt_manager.build_microagent_info.assert_not_called() @@ -892,27 +793,19 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m content='Third retrieval', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs1, obs2, obs3], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # Verify that only the first occurrence of content for each agent is included - assert ( - len(messages) == 2 - ) # system + 1 microagent, because the second and third microagents are duplicates - microagent_messages = messages[1:] # Skip system message + assert len(messages) == 1 # First microagent should include all agents since they appear here first - assert 'Image best practices v1' in microagent_messages[0].content[0].text - assert 'Git best practices v1' in microagent_messages[0].content[0].text - assert 'Python best practices v1' in microagent_messages[0].content[0].text + assert 'Image best practices v1' in messages[0].content[0].text + assert 'Git best practices v1' in messages[0].content[0].text + assert 'Python best practices v1' in messages[0].content[0].text def test_process_events_with_microagent_observation_deduplication_disabled_agents( @@ -949,26 +842,18 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent content='Second retrieval', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs1, obs2], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included - assert ( - len(messages) == 2 - ) # system + 1 microagent, the second is the same "enabled_agent" - microagent_messages = messages[1:] # Skip system message + assert len(messages) == 1 # First microagent should include enabled_agent but not disabled_agent - assert 'Disabled agent content' not in microagent_messages[0].content[0].text - assert 'Enabled agent content v1' in microagent_messages[0].content[0].text + assert 'Disabled agent content' not in messages[0].content[0].text + assert 'Enabled agent content v1' in messages[0].content[0].text def test_process_events_with_microagent_observation_deduplication_empty( @@ -981,21 +866,14 @@ def test_process_events_with_microagent_observation_deduplication_empty( content='Empty retrieval', ) - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - messages = conversation_memory.process_events( condensed_history=[obs], - initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) # Verify that empty RecallObservations are handled gracefully - assert ( - len(messages) == 1 - ) # system message, because an empty microagent is not added to Messages + assert len(messages) == 0 # an empty microagent is not added to Messages def test_has_agent_in_earlier_events(conversation_memory): @@ -1198,3 +1076,22 @@ class TestFilterUnmatchedToolCalls: assert len(result) == len(expected) for i, msg in enumerate(result): assert msg == expected[i] + + +def test_system_message_in_events(conversation_memory): + """Test that SystemMessageAction in condensed_history is processed correctly.""" + # Create a system message action + system_message = SystemMessageAction(content='System message', tools=['test_tool']) + system_message._source = EventSource.AGENT + + # Process events with the system message in condensed_history + messages = conversation_memory.process_events( + condensed_history=[system_message], + max_message_chars=None, + vision_is_active=False, + ) + + # Check that the system message was processed correctly + assert len(messages) == 1 + assert messages[0].role == 'system' + assert messages[0].content[0].text == 'System message' diff --git a/tests/unit/test_iteration_limit.py b/tests/unit/test_iteration_limit.py index dae272dde4..b332085c00 100644 --- a/tests/unit/test_iteration_limit.py +++ b/tests/unit/test_iteration_limit.py @@ -25,6 +25,16 @@ class DummyAgent: def reset(self): pass + def get_system_message(self): + # Return a proper SystemMessageAction for the refactored system message handling + from openhands.events.action.message import SystemMessageAction + from openhands.events.event import EventSource + + system_message = SystemMessageAction(content='This is a dummy system message') + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + return system_message + @pytest.mark.asyncio async def test_iteration_limit_extends_on_user_message(): diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py index 3dd0e59979..94010057b4 100644 --- a/tests/unit/test_memory.py +++ b/tests/unit/test_memory.py @@ -58,15 +58,27 @@ def prompt_dir(tmp_path): return tmp_path -@pytest.mark.asyncio -async def test_memory_on_event_exception_handling(memory, event_stream): - """Test that exceptions in Memory.on_event are properly handled via status callback.""" +@pytest.fixture +def mock_agent(): # Create a dummy agent for the controller agent = MagicMock(spec=Agent) agent.llm = MagicMock(spec=LLM) agent.llm.metrics = Metrics() agent.llm.config = AppConfig().get_llm_config() + # Add a proper system message mock + from openhands.events.action.message import SystemMessageAction + + system_message = SystemMessageAction(content='Test system message') + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + agent.get_system_message.return_value = system_message + + +@pytest.mark.asyncio +async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent): + """Test that exceptions in Memory.on_event are properly handled via status callback.""" + # Create a mock runtime runtime = MagicMock(spec=Runtime) runtime.event_stream = event_stream @@ -80,7 +92,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream): initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=memory, ) @@ -93,16 +105,10 @@ async def test_memory_on_event_exception_handling(memory, event_stream): @pytest.mark.asyncio async def test_memory_on_workspace_context_recall_exception_handling( - memory, event_stream + memory, event_stream, mock_agent ): """Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback.""" - # Create a dummy agent for the controller - agent = MagicMock(spec=Agent) - agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = Metrics() - agent.llm.config = AppConfig().get_llm_config() - # Create a mock runtime runtime = MagicMock(spec=Runtime) runtime.event_stream = event_stream @@ -118,7 +124,7 @@ async def test_memory_on_workspace_context_recall_exception_handling( initial_user_action=MessageAction(content='Test message'), runtime=runtime, sid='test', - agent=agent, + agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', memory=memory, ) diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index fdb9f1f2fb..2b323da0dc 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -55,6 +55,10 @@ def response_mock(content: str, tool_call_id: str): def test_get_messages(codeact_agent: CodeActAgent): # Add some events to history history = list() + # Add system message action + system_message_action = codeact_agent.get_system_message() + history.append(system_message_action) + message_action_1 = MessageAction('Initial user message') message_action_1._source = 'user' history.append(message_action_1) @@ -77,7 +81,8 @@ def test_get_messages(codeact_agent: CodeActAgent): assert ( len(messages) == 6 ) # System, initial user + user message, agent message, last user message - assert messages[0].content[0].cache_prompt # system message + assert messages[0].role == 'system' # system message + assert messages[0].content[0].cache_prompt # system message should be cached assert messages[1].role == 'user' assert messages[1].content[0].text.endswith('Initial user message') # we add cache breakpoint to only the last user message @@ -96,6 +101,10 @@ def test_get_messages(codeact_agent: CodeActAgent): def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): history = list() + # Add system message action + system_message_action = codeact_agent.get_system_message() + history.append(system_message_action) + # Add multiple user and agent messages for i in range(15): message_action_user = MessageAction(f'User message {i}') @@ -116,7 +125,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): ] assert ( len(cached_user_messages) == 2 - ) # Including the initial system+user + last user message + ) # Including the initial system message + last user message # Verify that these are indeed the last user message (from start) assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent') diff --git a/tests/unit/test_traffic_control.py b/tests/unit/test_traffic_control.py index 58be557b50..5d011b94b3 100644 --- a/tests/unit/test_traffic_control.py +++ b/tests/unit/test_traffic_control.py @@ -16,6 +16,16 @@ def agent_controller(): agent.name = 'test_agent' agent.llm = llm agent.config = AgentConfig() + + # Add a proper system message mock + from openhands.events import EventSource + from openhands.events.action.message import SystemMessageAction + + system_message = SystemMessageAction(content='Test system message') + system_message._source = EventSource.AGENT + system_message._id = -1 # Set invalid ID to avoid the ID check + agent.get_system_message.return_value = system_message + event_stream = EventStream(sid='test', file_store=InMemoryFileStore()) controller = AgentController( agent=agent,