Refactor system message handling to use event stream (#7824)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Xingyao Wang 2025-04-17 10:30:19 -04:00 committed by GitHub
parent caf34d83bd
commit 93e9db3206
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 446 additions and 321 deletions

View File

@ -11,6 +11,7 @@ from openhands.events.action import (
Action,
AgentFinishAction,
)
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
@ -166,8 +167,8 @@ class CodeActAgent(Agent):
message flow and function-calling scenarios.
The method performs the following steps:
1. Initializes with system prompt and optional initial user message
2. Processes events (Actions and Observations) into messages
1. Checks for SystemMessageAction in events, adds one if missing (legacy support)
2. Processes events (Actions and Observations) into messages, including SystemMessageAction
3. Handles tool calls and their responses in function-calling mode
4. Manages message role alternation (user/assistant/tool)
5. Applies caching for specific LLM providers (e.g., Anthropic)
@ -178,8 +179,7 @@ class CodeActAgent(Agent):
Returns:
list[Message]: A list of formatted messages ready for LLM consumption, including:
- System message with prompt
- Initial user message (if configured)
- System message with prompt (from SystemMessageAction)
- Action messages (from both user and assistant)
- Observation messages (including tool responses)
- Environment reminders (in non-function-calling mode)
@ -193,15 +193,32 @@ class CodeActAgent(Agent):
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
# Use ConversationMemory to process initial messages
messages = self.conversation_memory.process_initial_messages(
with_caching=self.llm.is_caching_prompt_active()
# Check if there's a SystemMessageAction in the events
has_system_message = any(
isinstance(event, SystemMessageAction) for event in events
)
# Use ConversationMemory to process events
# Legacy behavior: If no SystemMessageAction is found, add one
if not has_system_message:
logger.warning(
f'[{self.name}] No SystemMessageAction found in events. '
'Adding one for backward compatibility. '
'This is deprecated behavior and will be removed in a future version.'
)
system_message = self.get_system_message()
if system_message:
# Create a copy and insert at the beginning of the list
processed_events = list(events)
processed_events.insert(0, system_message)
logger.debug(
f'[{self.name}] Added SystemMessageAction for backward compatibility'
)
else:
processed_events = events
# Use ConversationMemory to process events (including SystemMessageAction)
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_messages=messages,
condensed_history=processed_events,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
)

View File

@ -5,10 +5,13 @@ if TYPE_CHECKING:
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action
from openhands.events.action.message import SystemMessageAction
from openhands.core.exceptions import (
AgentAlreadyRegisteredError,
AgentNotRegisteredError,
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import EventSource
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement
@ -38,6 +41,42 @@ class Agent(ABC):
self._complete = False
self.prompt_manager: 'PromptManager' | None = None
self.mcp_tools: list[dict] = []
self.tools: list = []
def get_system_message(self) -> 'SystemMessageAction | None':
    """Build the SystemMessageAction carrying this agent's system prompt and tools.

    The prompt text comes from ``self.prompt_manager.get_system_message()`` and
    the tool list from ``self.tools`` (``None`` when the attribute is absent).
    The resulting action is tagged with ``EventSource.AGENT`` as its source.

    Returns:
        SystemMessageAction: The action holding the prompt content and tools.
        None: If the prompt manager is not set, or if anything raised while
            the action was being built (the failure is logged as a warning).
    """
    # Imported lazily: a module-level import would be circular.
    from openhands.events.action.message import SystemMessageAction

    try:
        manager = self.prompt_manager
        if not manager:
            logger.warning(
                f'[{self.name}] Prompt manager not initialized before getting system message'
            )
            return None

        action = SystemMessageAction(
            content=manager.get_system_message(),
            # Not every agent defines `tools`; fall back to None.
            tools=getattr(self, 'tools', None),
        )
        # `_source` is normally assigned by the event stream; set it directly here.
        action._source = EventSource.AGENT  # type: ignore
        return action
    except Exception as e:
        # Best effort: callers treat None as "no system message available".
        logger.warning(f'[{self.name}] Failed to generate system message: {e}')
        return None
@property
def complete(self) -> bool:

View File

@ -54,6 +54,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
NullAction,
SystemMessageAction,
)
from openhands.events.action.agent import CondensationAction, RecallAction
from openhands.events.event import Event
@ -163,6 +164,31 @@ class AgentController:
# replay-related
self._replay_manager = ReplayManager(replay_events)
# Add the system message to the event stream
self._add_system_message()
def _add_system_message(self):
    """Add the agent's system message to the event stream unless it is not wanted.

    Events are scanned starting from ``self.state.start_id``. Nothing is added
    when the first relevant event found is either a user ``MessageAction`` or
    an existing ``SystemMessageAction``. Otherwise the agent is asked for its
    system message, which is appended with ``EventSource.AGENT`` as its source
    if the agent returned one.
    """
    for past_event in self.event_stream.get_events(start_id=self.state.start_id):
        # FIXME: Remove this after 6/1/2025
        # A user message seen first means the event stream exists from before
        # SystemMessageAction was introduced; do not add one here.
        # We expect *agent* to handle this case gracefully.
        is_user_message = (
            isinstance(past_event, MessageAction)
            and past_event.source == EventSource.USER
        )
        # An existing SystemMessageAction must not be duplicated either.
        if is_user_message or isinstance(past_event, SystemMessageAction):
            return

    # No earlier user/system message found: add the system message now.
    # This should be done for all agents, including delegates.
    system_message = self.agent.get_system_message()
    logger.debug(f'System message got from agent: {system_message}')
    if not system_message:
        return
    self.event_stream.add_event(system_message, EventSource.AGENT)
    logger.debug(f'System message added to event stream: {system_message}')
async def close(self, set_stop_state=True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.

View File

@ -6,6 +6,10 @@ class ActionType(str, Enum):
"""Represents a message.
"""
SYSTEM = 'system'
"""Represents a system message.
"""
START = 'start'
"""Starts a new development task OR send chat from the user. Only sent by the client.
"""

View File

@ -16,7 +16,7 @@ from openhands.events.action.files import (
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
from openhands.events.action.message import MessageAction, SystemMessageAction
__all__ = [
'Action',
@ -33,6 +33,7 @@ __all__ = [
'ChangeAgentStateAction',
'IPythonRunCellAction',
'MessageAction',
'SystemMessageAction',
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',

View File

@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Any
import openhands
from openhands.core.schema import ActionType
from openhands.events.action.action import Action, ActionSecurityRisk
@ -32,3 +34,27 @@ class MessageAction(Action):
for url in self.image_urls:
ret += f'\nIMAGE_URL: {url}'
return ret
@dataclass
class SystemMessageAction(Action):
    """System message for an agent: the system prompt plus the available tools.

    This is intended to be the first message in the event stream.
    """

    # Text of the system prompt.
    content: str
    # Tool definitions made available to the agent, if any.
    tools: list[Any] | None = None
    # Defaults to the installed openhands version, read when the class is defined.
    openhands_version: str | None = openhands.__version__
    action: ActionType = ActionType.SYSTEM

    @property
    def message(self) -> str:
        """Return the system prompt text."""
        return self.content

    def __str__(self) -> str:
        """Render the source, the content, and (when tools are set) the tool count."""
        header = f'**SystemMessageAction** (source={self.source})\n'
        body = f'CONTENT: {self.content}'
        if not self.tools:
            return header + body
        return header + body + f'\nTOOLS: {len(self.tools)} tools available'

View File

@ -23,7 +23,7 @@ from openhands.events.action.files import (
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
from openhands.events.action.message import MessageAction, SystemMessageAction
actions = (
NullAction,
@ -41,6 +41,7 @@ actions = (
RecallAction,
ChangeAgentStateAction,
MessageAction,
SystemMessageAction,
CondensationAction,
McpAction,
)

View File

@ -20,6 +20,7 @@ from openhands.events.action import (
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
AgentCondensationObservation,
@ -53,7 +54,6 @@ class ConversationMemory:
def process_events(
self,
condensed_history: list[Event],
initial_messages: list[Message],
max_message_chars: int | None = None,
vision_is_active: bool = False,
) -> list[Message]:
@ -63,7 +63,6 @@ class ConversationMemory:
Args:
condensed_history: The condensed history of events to convert
initial_messages: The initial messages to include in the conversation
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
@ -74,8 +73,8 @@ class ConversationMemory:
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
# Process special events first (system prompts, etc.)
messages = initial_messages
# Initialize empty messages list
messages = []
# Process regular events
pending_tool_call_action_messages: dict[str, Message] = {}
@ -132,20 +131,6 @@ class ConversationMemory:
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
return messages
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
"""Create the initial messages for the conversation."""
return [
Message(
role='system',
content=[
TextContent(
text=self.prompt_manager.get_system_message(),
cache_prompt=with_caching,
)
],
)
]
def _process_action(
self,
action: Action,
@ -275,6 +260,16 @@ class ConversationMemory:
content=content,
)
]
elif isinstance(action, SystemMessageAction):
# Convert SystemMessageAction to a system message
return [
Message(
role='system',
content=[TextContent(text=action.content)],
# Include tools if function calling is enabled
tool_calls=None,
)
]
return []
def _process_observation(
@ -546,6 +541,8 @@ class ConversationMemory:
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
"""
if len(messages) > 0 and messages[0].role == 'system':
messages[0].content[-1].cache_prompt = True
# NOTE: this is only needed for anthropic
for message in reversed(messages):
if message.role in ('user', 'tool'):

View File

@ -12,9 +12,16 @@ from openhands.resolver.utils import extract_issue_references
class GithubIssueHandler(IssueHandlerInterface):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'github.com',
):
"""Initialize a GitHub issue handler.
Args:
owner: The owner of the repository
repo: The name of the repository
@ -42,7 +49,7 @@ class GithubIssueHandler(IssueHandlerInterface):
}
def get_base_url(self) -> str:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
else:
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
@ -65,7 +72,7 @@ class GithubIssueHandler(IssueHandlerInterface):
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
def get_graphql_url(self) -> str:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
return 'https://api.github.com/graphql'
else:
return f'https://{self.base_domain}/api/v3/graphql'
@ -302,9 +309,16 @@ class GithubIssueHandler(IssueHandlerInterface):
class GithubPRHandler(GithubIssueHandler):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'github.com',
):
"""Initialize a GitHub PR handler.
Args:
owner: The owner of the repository
repo: The name of the repository
@ -313,8 +327,10 @@ class GithubPRHandler(GithubIssueHandler):
base_domain: The domain for GitHub Enterprise (default: "github.com")
"""
super().__init__(owner, repo, token, username, base_domain)
if self.base_domain == "github.com":
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
if self.base_domain == 'github.com':
self.download_url = (
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
)
else:
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
@ -470,7 +486,7 @@ class GithubPRHandler(GithubIssueHandler):
self, pr_number: int, comment_id: int | None = None
) -> list[str] | None:
"""Download comments for a specific pull request from Github."""
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
@ -542,7 +558,7 @@ class GithubPRHandler(GithubIssueHandler):
for issue_number in unique_issue_references:
try:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'

View File

@ -134,10 +134,7 @@ async def reset_settings(request: Request) -> JSONResponse:
)
async def check_provider_tokens(request: Request,
settings: POSTSettingsModel) -> str:
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
@ -152,17 +149,13 @@ async def check_provider_tokens(request: Request,
SecretStr(token_value)
)
if not confirmed_token_type or confirmed_token_type.value != token_type:
return f"Invalid token. Please make sure it is a valid {token_type} token."
return ""
return f'Invalid token. Please make sure it is a valid {token_type} token.'
return ''
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
if existing_settings:
if settings.provider_tokens:
@ -188,19 +181,17 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
else: # nothing passed in means keep current settings
provider_tokens = existing_settings.secrets_store.provider_tokens
settings.provider_tokens = {
provider.value: data.token.get_secret_value()
if data.token
else None
provider.value: data.token.get_secret_value() if data.token else None
for provider, data in provider_tokens.items()
}
return settings
async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
async def store_llm_settings(
request: Request, settings: POSTSettingsModel
) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
@ -215,6 +206,7 @@ async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> P
return settings
@app.post('/settings', response_model=dict[str, str])
async def store_settings(
request: Request,
@ -225,11 +217,8 @@ async def store_settings(
if provider_err_msg:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': provider_err_msg
},
content={'error': provider_err_msg},
)
try:
settings_store = await SettingsStoreImpl.get_instance(
@ -248,7 +237,6 @@ async def store_settings(
)
settings = await store_provider_tokens(request, settings)
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:

View File

@ -139,7 +139,7 @@ class Session:
condensers=[
BrowserOutputCondenserConfig(),
LLMSummarizingCondenserConfig(
llm_config=llm.config, keep_first=3, max_size=80
llm_config=llm.config, keep_first=4, max_size=80
),
]
)

View File

@ -15,6 +15,7 @@ from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.action.agent import CondensationAction, RecallAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import RecallType
from openhands.events.observation import (
AgentStateChangedObservation,
@ -49,6 +50,15 @@ def mock_agent():
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = AppConfig().get_llm_config()
# Add a proper system message mock
system_message = SystemMessageAction(
content='Test system message', tools=['test_tool']
)
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
return agent
@ -206,20 +216,19 @@ async def test_react_to_content_policy_violation(
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
async def test_run_controller_with_fatal_error(
test_event_stream, mock_memory, mock_agent
):
config = AppConfig()
agent = MagicMock(spec=Agent)
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
@ -250,7 +259,7 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
@ -268,22 +277,24 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
assert (
error_observation.reason == 'AgentStuckInLoopError: Agent got stuck in a loop'
)
assert len(events) == 11
assert len(events) == 12
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
async def test_run_controller_stop_with_stuck(
test_event_stream, mock_memory, mock_agent
):
config = AppConfig()
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
@ -315,7 +326,7 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
@ -325,9 +336,10 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
print(f'event {i}: {event_to_dict(event)}')
assert state.iteration == 3
assert len(events) == 11
assert len(events) == 12
# check that the event stream has 4 pairs of repeated actions and observations
repeating_actions_and_observations = events[4:12]
# With the refactored system message handling, we need to adjust the range
repeating_actions_and_observations = events[5:13]
for action, observation in zip(
repeating_actions_and_observations[0::2],
repeating_actions_and_observations[1::2],
@ -469,6 +481,9 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
headless_mode=True,
)
mock_event_stream.add_event.assert_called_once() # add SystemMessageAction
mock_event_stream.add_event.reset_mock()
# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
@ -512,6 +527,9 @@ async def test_reset_with_pending_action_existing_observation(
headless_mode=True,
)
mock_event_stream.add_event.assert_called_once() # add SystemMessageAction
mock_event_stream.add_event.reset_mock()
# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
@ -551,6 +569,9 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
headless_mode=True,
)
# Reset the mock to clear the call from system message addition
mock_event_stream.add_event.reset_mock()
# Call reset
controller._reset()
@ -579,6 +600,9 @@ async def test_reset_with_pending_action_no_metadata(
headless_mode=True,
)
# Reset the mock to clear the call from system message addition
mock_event_stream.add_event.reset_mock()
# Create a pending action without tool call metadata
pending_action = CmdRunAction(command='test')
# Mock hasattr to return False for tool_call_metadata
@ -608,28 +632,27 @@ async def test_reset_with_pending_action_no_metadata(
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics(
test_event_stream, mock_memory
test_event_stream, mock_memory, mock_agent
):
config = AppConfig(
max_iterations=3,
)
event_stream = test_event_stream
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
# Mock the cost of the LLM
agent.llm.metrics.add_cost(10.0)
mock_agent.llm.metrics.add_cost(10.0)
print(
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
f'mock_agent.llm.metrics.accumulated_cost: {mock_agent.llm.metrics.accumulated_cost}'
)
return CmdRunAction(command='ls')
agent.step = agent_step_fn
mock_agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
@ -660,7 +683,7 @@ async def test_run_controller_max_iterations_has_metrics(
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
@ -839,9 +862,10 @@ async def test_context_window_exceeded_error_handling(
)
== 1
)
# With the refactored system message handling, we now have max_iterations + 4 events
assert (
len(final_state.history) == max_iterations + 3
) # 1 condensation action, 1 recall action, 1 recall observation
len(final_state.history) == max_iterations + 4
) # 1 system message, 1 condensation action, 1 recall action, 1 recall observation
assert len(final_state.view) == len(step_state.views[-1]) + 1
@ -990,7 +1014,8 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
# Hitting the iteration limit indicates the controller is failing for the
# expected reason
assert state.iteration == 2
# With the refactored system message handling, the iteration count is different
assert state.iteration == 1
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
@ -1012,21 +1037,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
@pytest.mark.asyncio
async def test_run_controller_with_memory_error(test_event_stream):
async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
config = AppConfig()
event_stream = test_event_stream
# Create a propert agent that returns an action without an ID
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
# Create a proper agent that returns an action without an ID
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
# Create a real action to return from the mocked step function
def agent_step_fn(state):
return MessageAction(content='Agent returned a message')
agent.step = agent_step_fn
mock_agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
runtime.event_stream = event_stream
@ -1046,7 +1070,7 @@ async def test_run_controller_with_memory_error(test_event_stream):
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
)
@ -1057,14 +1081,13 @@ async def test_run_controller_with_memory_error(test_event_stream):
@pytest.mark.asyncio
async def test_action_metrics_copy():
async def test_action_metrics_copy(mock_agent):
# Setup
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
# Create agent with metrics
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
mock_agent.llm = MagicMock(spec=LLM)
metrics = Metrics(model_name='test-model')
metrics.accumulated_cost = 0.05
@ -1106,7 +1129,7 @@ async def test_action_metrics_copy():
# Add a response latency - should not be included in action metrics
metrics.add_response_latency(0.5, 'test-id-2')
agent.llm.metrics = metrics
mock_agent.llm.metrics = metrics
# Mock agent step to return an action
action = MessageAction(content='Test message')
@ -1114,11 +1137,11 @@ async def test_action_metrics_copy():
def agent_step_fn(state):
return action
agent.step = agent_step_fn
mock_agent.step = agent_step_fn
# Create controller with correct parameters
controller = AgentController(
agent=agent,
agent=mock_agent,
event_stream=event_stream,
max_iterations=10,
sid='test',
@ -1169,14 +1192,14 @@ async def test_action_metrics_copy():
assert not hasattr(last_action.llm_metrics, 'average_latency')
# Verify it's a deep copy by modifying the original
agent.llm.metrics.accumulated_cost = 0.1
mock_agent.llm.metrics.accumulated_cost = 0.1
assert last_action.llm_metrics.accumulated_cost == 0.07
await controller.close()
@pytest.mark.asyncio
async def test_first_user_message_with_identical_content():
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
"""
Test that _first_user_message correctly identifies the first user message
even when multiple messages have identical content but different IDs.
@ -1185,18 +1208,14 @@ async def test_first_user_message_with_identical_content():
The issue we're checking is that the comparison (action == self._first_user_message())
should correctly differentiate between messages with the same content but different IDs.
"""
# Create a real event stream for this test
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
# Create an agent controller
mock_agent = MagicMock(spec=Agent)
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = AppConfig().get_llm_config()
controller = AgentController(
agent=mock_agent,
event_stream=event_stream,
event_stream=test_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
@ -1206,12 +1225,12 @@ async def test_first_user_message_with_identical_content():
# Create and add the first user message
first_message = MessageAction(content='Hello, this is a test message')
first_message._source = EventSource.USER
event_stream.add_event(first_message, EventSource.USER)
test_event_stream.add_event(first_message, EventSource.USER)
# Create and add a second user message with identical content
second_message = MessageAction(content='Hello, this is a test message')
second_message._source = EventSource.USER
event_stream.add_event(second_message, EventSource.USER)
test_event_stream.add_event(second_message, EventSource.USER)
# Verify that _first_user_message returns the first message
first_user_message = controller._first_user_message()
@ -1235,7 +1254,7 @@ async def test_first_user_message_with_identical_content():
) # Cache should store the same object
# Mock get_events to verify it's not called again
with patch.object(event_stream, 'get_events') as mock_get_events:
with patch.object(test_event_stream, 'get_events') as mock_get_events:
cached_message = controller._first_user_message()
assert cached_message is first_user_message # Should return cached object
mock_get_events.assert_not_called() # Should not call get_events again
@ -1326,15 +1345,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
), 'should_step should return False for NullObservation with cause=0'
def test_agent_controller_should_step_with_null_observation_cause_zero():
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
# Create a mock event stream
file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store)
# Create a mock agent
mock_agent = MagicMock(spec=Agent)
# Create an agent controller
controller = AgentController(
agent=mock_agent,
@ -1475,3 +1491,19 @@ def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
assert len(new_controller.state.history) == saved_history_len
assert new_controller.state.history[0] == first_msg
assert new_controller.state.start_id == saved_start_id
def test_system_message_in_event_stream(mock_agent, test_event_stream):
    """Constructing an AgentController puts exactly one SystemMessageAction in the stream."""
    # Building the controller is the action under test; the instance itself is unused.
    _controller = AgentController(
        agent=mock_agent,
        event_stream=test_event_stream,
        max_iterations=10,
    )

    stream_events = list(test_event_stream.get_events())

    # The system message must be the only event present.
    assert len(stream_events) == 1
    (system_event,) = stream_events
    assert isinstance(system_event, SystemMessageAction)
    # Content and tools mirror what the mock_agent fixture's get_system_message returns.
    assert system_event.content == 'Test system message'
    assert system_event.tools == ['test_tool']

View File

@ -44,6 +44,15 @@ def mock_parent_agent():
agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig()
agent.config = AgentConfig()
# Add a proper system message mock
from openhands.events.action.message import SystemMessageAction
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
return agent
@ -56,6 +65,15 @@ def mock_child_agent():
agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig()
agent.config = AgentConfig()
# Add a proper system message mock
from openhands.events.action.message import SystemMessageAction
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
return agent
@ -113,9 +131,9 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
# Verify that a RecallObservation was added to the event stream
events = list(mock_event_stream.get_events())
assert (
mock_event_stream.get_latest_event_id() == 3
) # Microagents and AgentChangeState
# SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child)
assert mock_event_stream.get_latest_event_id() == 5
# a RecallObservation and an AgentDelegateAction should be in the list
assert any(isinstance(event, RecallObservation) for event in events)

View File

@ -26,6 +26,7 @@ from openhands.events.action import (
CmdRunAction,
MessageAction,
)
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import EventSource
from openhands.events.observation.commands import (
CmdOutputObservation,
@ -288,8 +289,11 @@ def test_correct_tool_description_loaded_based_on_model_name(mock_state: State):
assert any(len(tool['function']['description']) > 1024 for tool in agent.tools)
def test_mismatched_tool_call_events(mock_state: State):
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
def test_mismatched_tool_call_events_and_auto_add_system_message(mock_state: State):
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages.
This also tests that the system message is automatically added to the event stream if SystemMessageAction is not present.
"""
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
tool_call_metadata = Mock(
@ -320,26 +324,35 @@ def test_mismatched_tool_call_events(mock_state: State):
observation.tool_call_metadata = tool_call_metadata
# When both events are provided, the agent should get three messages:
# 1. The system message,
# 2. The action message, and
# 1. The system message (added automatically for backward compatibility)
# 2. The action message
# 3. The observation message
mock_state.history = [action, observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
assert messages[0].role == 'system' # First message should be the system message
assert messages[1].role == 'assistant' # Second message should be the action
assert messages[2].role == 'tool' # Third message should be the observation
# The same should hold if the events are presented out-of-order
mock_state.history = [observation, action]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
assert messages[0].role == 'system' # First message should be the system message
# If only one of the two events is present, then we should just get the system message
# plus any valid message from the event
mock_state.history = [action]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 1
assert (
len(messages) == 1
) # Only system message, action is waiting for its observation
assert messages[0].role == 'system'
mock_state.history = [observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 1
assert len(messages) == 1 # Only system message, observation has no matching action
assert messages[0].role == 'system'
def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
@ -397,3 +410,18 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
# Fifth message only has ImageContent, no TextContent to modify
assert len(enhanced_messages[5].content) == 1
assert isinstance(enhanced_messages[5].content[0], ImageContent)
def test_get_system_message():
    """Check that CodeActAgent.get_system_message() yields a populated SystemMessageAction."""
    # A real CodeActAgent built from default LLM and agent configuration.
    codeact = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())

    system_action = codeact.get_system_message()

    # Type and prompt text.
    assert isinstance(system_action, SystemMessageAction)
    assert 'You are OpenHands agent' in system_action.content

    # Tools must be present and include the bash execution tool.
    assert len(system_action.tools) > 0
    has_bash_tool = any(
        tool['function']['name'] == 'execute_bash' for tool in system_action.tools
    )
    assert has_bash_tool

    # The action must be attributed to the agent.
    assert system_action._source == EventSource.AGENT

View File

@ -13,6 +13,7 @@ from openhands.events.action import (
CmdRunAction,
MessageAction,
)
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import (
Event,
EventSource,
@ -84,36 +85,29 @@ def mock_state():
return state
def test_process_initial_messages(conversation_memory):
    """process_initial_messages returns one system message whose cache flag follows with_caching."""
    # Without caching: a single, uncached system message.
    uncached = conversation_memory.process_initial_messages(with_caching=False)
    assert len(uncached) == 1
    system_msg = uncached[0]
    assert system_msg.role == 'system'
    assert system_msg.content[0].text == 'System message'
    assert system_msg.content[0].cache_prompt is False

    # With caching: the system message content is flagged for prompt caching.
    cached = conversation_memory.process_initial_messages(with_caching=True)
    assert cached[0].content[0].cache_prompt is True
def test_process_events_with_message_action(conversation_memory):
"""Test that MessageAction is processed correctly."""
# Create a system message action
system_message = SystemMessageAction(content='System message')
system_message._source = EventSource.AGENT
# Create user and assistant messages
user_message = MessageAction(content='Hello')
user_message._source = EventSource.USER
assistant_message = MessageAction(content='Hi there')
assistant_message._source = EventSource.AGENT
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
# Process events
messages = conversation_memory.process_events(
condensed_history=[user_message, assistant_message],
initial_messages=initial_messages,
condensed_history=[system_message, user_message, assistant_message],
max_message_chars=None,
vision_is_active=False,
)
# Check that the messages were processed correctly
assert len(messages) == 3
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System message'
assert messages[1].role == 'user'
assert messages[1].content[0].text == 'Hello'
assert messages[2].role == 'assistant'
@ -131,19 +125,14 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
),
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -159,19 +148,14 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
content='IPython output\n![image]()',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -188,19 +172,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
content='Content', outputs={'content': 'Delegated agent output'}
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -210,19 +189,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
def test_process_events_with_error_observation(conversation_memory):
obs = ErrorObservation('Error message')
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -234,14 +208,9 @@ def test_process_events_with_unknown_observation(conversation_memory):
# Create a mock that inherits from Event but not Action or Observation
obs = Mock(spec=Event)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
with pytest.raises(ValueError, match='Unknown event type'):
conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
@ -257,19 +226,14 @@ def test_process_events_with_file_edit_observation(conversation_memory):
impl_source=FileEditSource.LLM_BASED_EDIT,
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -283,19 +247,14 @@ def test_process_events_with_file_read_observation(conversation_memory):
impl_source=FileReadSource.DEFAULT,
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -311,19 +270,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
error=False,
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -333,19 +287,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
def test_process_events_with_user_reject_observation(conversation_memory):
obs = UserRejectObservation('Action rejected')
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -368,20 +317,14 @@ def test_process_events_with_empty_environment_info(conversation_memory):
content='Retrieved environment info',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[empty_obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# Should only contain the initial system message
assert len(messages) == 1
assert messages[0].role == 'system'
# Should contain no messages
assert len(messages) == 0
# Verify that build_workspace_context was NOT called since all input values were empty
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@ -405,20 +348,14 @@ def test_process_events_with_function_calling_observation(conversation_memory):
model_response=mock_response,
total_calls_in_response=1,
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# No direct message when using function calling
assert len(messages) == 1 # Only the initial system message
assert len(messages) == 0 # should be no messages
def test_process_events_with_message_action_with_image(conversation_memory):
@ -428,19 +365,14 @@ def test_process_events_with_message_action_with_image(conversation_memory):
)
action._source = EventSource.AGENT
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[action],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=True,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'assistant'
assert len(result.content) == 2
assert isinstance(result.content[0], TextContent)
@ -453,19 +385,14 @@ def test_process_events_with_user_cmd_action(conversation_memory):
action = CmdRunAction(command='ls -l')
action._source = EventSource.USER
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[action],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -491,19 +418,14 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
total_calls_in_response=1,
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[action],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'assistant'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -520,10 +442,11 @@ def test_apply_prompt_caching(conversation_memory):
conversation_memory.apply_prompt_caching(messages)
# Only the last user message should have cache_prompt=True
assert messages[0].content[0].cache_prompt is False
# System message is hard-coded to be cached always
assert messages[0].content[0].cache_prompt is True
assert messages[1].content[0].cache_prompt is False
assert messages[2].content[0].cache_prompt is False
# Only the last user message should have cache_prompt=True
assert messages[3].content[0].cache_prompt is True
@ -538,19 +461,14 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
content='Retrieved environment info',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -598,19 +516,14 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
content='Retrieved knowledge from microagents',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 1
result = messages[0]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@ -646,20 +559,14 @@ def test_process_events_with_microagent_observation_extensions_disabled(
content='Retrieved environment info',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# When prompt extensions are disabled, the RecallObservation should be ignored
assert len(messages) == 1 # Only the initial system message
assert messages[0].role == 'system'
assert len(messages) == 0 # should be no messages
# Verify the prompt_manager was not called
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@ -674,20 +581,14 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
content='Retrieved knowledge from microagents',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# The implementation returns an empty string and doesn't create a message
assert len(messages) == 1
assert messages[0].role == 'system'
assert len(messages) == 0 # should be no messages
# When there are no triggered agents, build_microagent_info is not called
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
@ -892,27 +793,19 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
content='Third retrieval',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2, obs3],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# Verify that only the first occurrence of content for each agent is included
assert (
len(messages) == 2
) # system + 1 microagent, because the second and third microagents are duplicates
microagent_messages = messages[1:] # Skip system message
assert len(messages) == 1
# First microagent should include all agents since they appear here first
assert 'Image best practices v1' in microagent_messages[0].content[0].text
assert 'Git best practices v1' in microagent_messages[0].content[0].text
assert 'Python best practices v1' in microagent_messages[0].content[0].text
assert 'Image best practices v1' in messages[0].content[0].text
assert 'Git best practices v1' in messages[0].content[0].text
assert 'Python best practices v1' in messages[0].content[0].text
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
@ -949,26 +842,18 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
content='Second retrieval',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
assert (
len(messages) == 2
) # system + 1 microagent, the second is the same "enabled_agent"
microagent_messages = messages[1:] # Skip system message
assert len(messages) == 1
# First microagent should include enabled_agent but not disabled_agent
assert 'Disabled agent content' not in microagent_messages[0].content[0].text
assert 'Enabled agent content v1' in microagent_messages[0].content[0].text
assert 'Disabled agent content' not in messages[0].content[0].text
assert 'Enabled agent content v1' in messages[0].content[0].text
def test_process_events_with_microagent_observation_deduplication_empty(
@ -981,21 +866,14 @@ def test_process_events_with_microagent_observation_deduplication_empty(
content='Empty retrieval',
)
initial_messages = [
Message(role='system', content=[TextContent(text='System message')])
]
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_messages=initial_messages,
max_message_chars=None,
vision_is_active=False,
)
# Verify that empty RecallObservations are handled gracefully
assert (
len(messages) == 1
) # system message, because an empty microagent is not added to Messages
assert len(messages) == 0 # an empty microagent is not added to Messages
def test_has_agent_in_earlier_events(conversation_memory):
@ -1198,3 +1076,22 @@ class TestFilterUnmatchedToolCalls:
assert len(result) == len(expected)
for i, msg in enumerate(result):
assert msg == expected[i]
def test_system_message_in_events(conversation_memory):
    """A SystemMessageAction in condensed_history is converted into one 'system' message."""
    # Build the system event and attribute it to the agent.
    sys_event = SystemMessageAction(content='System message', tools=['test_tool'])
    sys_event._source = EventSource.AGENT

    # The history consists of the system event alone.
    converted = conversation_memory.process_events(
        condensed_history=[sys_event],
        max_message_chars=None,
        vision_is_active=False,
    )

    # One message comes back, carrying the system role and the original text.
    assert len(converted) == 1
    assert converted[0].role == 'system'
    assert converted[0].content[0].text == 'System message'

View File

@ -25,6 +25,16 @@ class DummyAgent:
def reset(self):
pass
def get_system_message(self):
    """Return a fixed SystemMessageAction attributed to the agent."""
    # Imports are kept local to this method.
    from openhands.events.action.message import SystemMessageAction
    from openhands.events.event import EventSource

    dummy_message = SystemMessageAction(content='This is a dummy system message')
    # -1 is intentionally an invalid ID so the ID check is avoided.
    dummy_message._id = -1
    dummy_message._source = EventSource.AGENT
    return dummy_message
@pytest.mark.asyncio
async def test_iteration_limit_extends_on_user_message():

View File

@ -58,15 +58,27 @@ def prompt_dir(tmp_path):
return tmp_path
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
@pytest.fixture
def mock_agent():
    """Provide a mocked Agent whose get_system_message() returns a SystemMessageAction.

    Returns:
        A ``MagicMock(spec=Agent)`` with a mocked LLM (real ``Metrics`` and LLM
        config attached) and ``get_system_message`` stubbed to return a
        ``SystemMessageAction`` sourced from the agent.
    """
    # Create a dummy agent for the controller
    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = AppConfig().get_llm_config()

    # Add a proper system message mock
    from openhands.events.action.message import SystemMessageAction

    system_message = SystemMessageAction(content='Test system message')
    system_message._source = EventSource.AGENT
    system_message._id = -1  # Set invalid ID to avoid the ID check
    agent.get_system_message.return_value = system_message

    # The fixture must hand the configured mock back; without this return,
    # tests requesting ``mock_agent`` would receive None instead of the agent.
    return agent
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=Runtime)
runtime.event_stream = event_stream
@ -80,7 +92,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
)
@ -93,16 +105,10 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
@pytest.mark.asyncio
async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream
memory, event_stream, mock_agent
):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a dummy agent for the controller
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = AppConfig().get_llm_config()
# Create a mock runtime
runtime = MagicMock(spec=Runtime)
runtime.event_stream = event_stream
@ -118,7 +124,7 @@ async def test_memory_on_workspace_context_recall_exception_handling(
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
)

View File

@ -55,6 +55,10 @@ def response_mock(content: str, tool_call_id: str):
def test_get_messages(codeact_agent: CodeActAgent):
# Add some events to history
history = list()
# Add system message action
system_message_action = codeact_agent.get_system_message()
history.append(system_message_action)
message_action_1 = MessageAction('Initial user message')
message_action_1._source = 'user'
history.append(message_action_1)
@ -77,7 +81,8 @@ def test_get_messages(codeact_agent: CodeActAgent):
assert (
len(messages) == 6
) # System, initial user + user message, agent message, last user message
assert messages[0].content[0].cache_prompt # system message
assert messages[0].role == 'system' # system message
assert messages[0].content[0].cache_prompt # system message should be cached
assert messages[1].role == 'user'
assert messages[1].content[0].text.endswith('Initial user message')
# we add cache breakpoint to only the last user message
@ -96,6 +101,10 @@ def test_get_messages(codeact_agent: CodeActAgent):
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
history = list()
# Add system message action
system_message_action = codeact_agent.get_system_message()
history.append(system_message_action)
# Add multiple user and agent messages
for i in range(15):
message_action_user = MessageAction(f'User message {i}')
@ -116,7 +125,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
]
assert (
len(cached_user_messages) == 2
) # Including the initial system+user + last user message
) # Including the initial system message + last user message
# Verify that these are indeed the last user message (from start)
assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')

View File

@ -16,6 +16,16 @@ def agent_controller():
agent.name = 'test_agent'
agent.llm = llm
agent.config = AgentConfig()
# Add a proper system message mock
from openhands.events import EventSource
from openhands.events.action.message import SystemMessageAction
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
event_stream = EventStream(sid='test', file_store=InMemoryFileStore())
controller = AgentController(
agent=agent,