mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor system message handling to use event stream (#7824)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
parent
caf34d83bd
commit
93e9db3206
@ -11,6 +11,7 @@ from openhands.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
)
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import Condenser
|
||||
@ -166,8 +167,8 @@ class CodeActAgent(Agent):
|
||||
message flow and function-calling scenarios.
|
||||
|
||||
The method performs the following steps:
|
||||
1. Initializes with system prompt and optional initial user message
|
||||
2. Processes events (Actions and Observations) into messages
|
||||
1. Checks for SystemMessageAction in events, adds one if missing (legacy support)
|
||||
2. Processes events (Actions and Observations) into messages, including SystemMessageAction
|
||||
3. Handles tool calls and their responses in function-calling mode
|
||||
4. Manages message role alternation (user/assistant/tool)
|
||||
5. Applies caching for specific LLM providers (e.g., Anthropic)
|
||||
@ -178,8 +179,7 @@ class CodeActAgent(Agent):
|
||||
|
||||
Returns:
|
||||
list[Message]: A list of formatted messages ready for LLM consumption, including:
|
||||
- System message with prompt
|
||||
- Initial user message (if configured)
|
||||
- System message with prompt (from SystemMessageAction)
|
||||
- Action messages (from both user and assistant)
|
||||
- Observation messages (including tool responses)
|
||||
- Environment reminders (in non-function-calling mode)
|
||||
@ -193,15 +193,32 @@ class CodeActAgent(Agent):
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use ConversationMemory to process initial messages
|
||||
messages = self.conversation_memory.process_initial_messages(
|
||||
with_caching=self.llm.is_caching_prompt_active()
|
||||
# Check if there's a SystemMessageAction in the events
|
||||
has_system_message = any(
|
||||
isinstance(event, SystemMessageAction) for event in events
|
||||
)
|
||||
|
||||
# Use ConversationMemory to process events
|
||||
# Legacy behavior: If no SystemMessageAction is found, add one
|
||||
if not has_system_message:
|
||||
logger.warning(
|
||||
f'[{self.name}] No SystemMessageAction found in events. '
|
||||
'Adding one for backward compatibility. '
|
||||
'This is deprecated behavior and will be removed in a future version.'
|
||||
)
|
||||
system_message = self.get_system_message()
|
||||
if system_message:
|
||||
# Create a copy and insert at the beginning of the list
|
||||
processed_events = list(events)
|
||||
processed_events.insert(0, system_message)
|
||||
logger.debug(
|
||||
f'[{self.name}] Added SystemMessageAction for backward compatibility'
|
||||
)
|
||||
else:
|
||||
processed_events = events
|
||||
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_messages=messages,
|
||||
condensed_history=processed_events,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
@ -5,10 +5,13 @@ if TYPE_CHECKING:
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentNotRegisteredError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
|
||||
@ -38,6 +41,42 @@ class Agent(ABC):
|
||||
self._complete = False
|
||||
self.prompt_manager: 'PromptManager' | None = None
|
||||
self.mcp_tools: list[dict] = []
|
||||
self.tools: list = []
|
||||
|
||||
def get_system_message(self) -> 'SystemMessageAction | None':
    """Build the SystemMessageAction carrying this agent's system prompt and tools.

    The prompt text comes from ``self.prompt_manager.get_system_message()`` and the
    tool list from ``self.tools`` (``None`` if the attribute is absent). The
    resulting action has its source set to ``EventSource.AGENT``.

    Returns:
        SystemMessageAction: The action holding the system prompt and tools.
        None: If the prompt manager is not set, or if building the action raised
            an exception (the failure is logged as a warning, not propagated).
    """
    # Imported lazily here to avoid circular imports
    from openhands.events.action.message import SystemMessageAction

    try:
        prompt_manager = self.prompt_manager
        if prompt_manager:
            action = SystemMessageAction(
                content=prompt_manager.get_system_message(),
                # Tools are optional; fall back to None when the agent has none
                tools=getattr(self, 'tools', None),
            )
            # Mark the action as originating from the agent
            action._source = EventSource.AGENT  # type: ignore
            return action

        logger.warning(
            f'[{self.name}] Prompt manager not initialized before getting system message'
        )
        return None
    except Exception as e:
        # Best effort: callers treat None as "no system message available"
        logger.warning(f'[{self.name}] Failed to generate system message: {e}')
        return None
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
|
||||
@ -54,6 +54,7 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.event import Event
|
||||
@ -163,6 +164,31 @@ class AgentController:
|
||||
# replay-related
|
||||
self._replay_manager = ReplayManager(replay_events)
|
||||
|
||||
# Add the system message to the event stream
|
||||
self._add_system_message()
|
||||
|
||||
def _add_system_message(self):
    """Add the agent's system message to the event stream unless one is already there.

    Events are scanned starting at ``self.state.start_id``. Nothing is added when a
    SystemMessageAction is found, or when a user MessageAction is encountered
    first. Otherwise the agent is asked for its system message and, if it returns
    one, that message is added to the event stream with source AGENT.
    """
    for past_event in self.event_stream.get_events(start_id=self.state.start_id):
        # FIXME: Remove this after 6/1/2025
        # Running into a user message first means the event stream exists from
        # before SystemMessageAction was introduced; do not add one here.
        # We expect the *agent* to handle that case gracefully.
        is_user_message = (
            isinstance(past_event, MessageAction)
            and past_event.source == EventSource.USER
        )
        # A SystemMessageAction that already exists must not be duplicated.
        if is_user_message or isinstance(past_event, SystemMessageAction):
            return

    # This should be done for all agents, including delegates
    agent_system_message = self.agent.get_system_message()
    logger.debug(f'System message got from agent: {agent_system_message}')
    if not agent_system_message:
        return

    self.event_stream.add_event(agent_system_message, EventSource.AGENT)
    logger.debug(f'System message added to event stream: {agent_system_message}')
|
||||
|
||||
async def close(self, set_stop_state=True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
|
||||
@ -6,6 +6,10 @@ class ActionType(str, Enum):
|
||||
"""Represents a message.
|
||||
"""
|
||||
|
||||
SYSTEM = 'system'
|
||||
"""Represents a system message.
|
||||
"""
|
||||
|
||||
START = 'start'
|
||||
"""Starts a new development task OR send chat from the user. Only sent by the client.
|
||||
"""
|
||||
|
||||
@ -16,7 +16,7 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
__all__ = [
|
||||
'Action',
|
||||
@ -33,6 +33,7 @@ __all__ = [
|
||||
'ChangeAgentStateAction',
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
'SystemMessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import openhands
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
|
||||
@ -32,3 +34,27 @@ class MessageAction(Action):
|
||||
for url in self.image_urls:
|
||||
ret += f'\nIMAGE_URL: {url}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
class SystemMessageAction(Action):
    """Action holding an agent's system message: the system prompt plus its tools.

    It is intended to be the first message in the event stream.
    """

    # The system prompt text
    content: str
    # Tool definitions available to the agent, if any
    tools: list[Any] | None = None
    # Defaults to the installed openhands package version
    openhands_version: str | None = openhands.__version__
    action: ActionType = ActionType.SYSTEM

    @property
    def message(self) -> str:
        """Return the system prompt text (same value as ``content``)."""
        return self.content

    def __str__(self) -> str:
        """Render the source, the content and, when tools are present, their count."""
        header = f'**SystemMessageAction** (source={self.source})\n'
        body = f'CONTENT: {self.content}'
        tools_line = (
            f'\nTOOLS: {len(self.tools)} tools available' if self.tools else ''
        )
        return header + body + tools_line
|
||||
|
||||
@ -23,7 +23,7 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
actions = (
|
||||
NullAction,
|
||||
@ -41,6 +41,7 @@ actions = (
|
||||
RecallAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
SystemMessageAction,
|
||||
CondensationAction,
|
||||
McpAction,
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@ -53,7 +54,6 @@ class ConversationMemory:
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_messages: list[Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
@ -63,7 +63,6 @@ class ConversationMemory:
|
||||
|
||||
Args:
|
||||
condensed_history: The condensed history of events to convert
|
||||
initial_messages: The initial messages to include in the conversation
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
@ -74,8 +73,8 @@ class ConversationMemory:
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
messages = initial_messages
|
||||
# Initialize empty messages list
|
||||
messages = []
|
||||
|
||||
# Process regular events
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
@ -132,20 +131,6 @@ class ConversationMemory:
|
||||
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
|
||||
return messages
|
||||
|
||||
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
    """Create the initial messages for the conversation.

    Args:
        with_caching: Value passed as ``cache_prompt`` on the system prompt content.

    Returns:
        A one-element list holding a 'system' role Message whose single text
        content is the prompt manager's system message.
    """
    system_text = TextContent(
        text=self.prompt_manager.get_system_message(),
        cache_prompt=with_caching,
    )
    system_message = Message(role='system', content=[system_text])
    return [system_message]
|
||||
|
||||
def _process_action(
|
||||
self,
|
||||
action: Action,
|
||||
@ -275,6 +260,16 @@ class ConversationMemory:
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, SystemMessageAction):
|
||||
# Convert SystemMessageAction to a system message
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[TextContent(text=action.content)],
|
||||
# Include tools if function calling is enabled
|
||||
tool_calls=None,
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def _process_observation(
|
||||
@ -546,6 +541,8 @@ class ConversationMemory:
|
||||
|
||||
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
|
||||
"""
|
||||
if len(messages) > 0 and messages[0].role == 'system':
|
||||
messages[0].content[-1].cache_prompt = True
|
||||
# NOTE: this is only needed for anthropic
|
||||
for message in reversed(messages):
|
||||
if message.role in ('user', 'tool'):
|
||||
|
||||
@ -12,9 +12,16 @@ from openhands.resolver.utils import extract_issue_references
|
||||
|
||||
|
||||
class GithubIssueHandler(IssueHandlerInterface):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'github.com',
|
||||
):
|
||||
"""Initialize a GitHub issue handler.
|
||||
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
@ -42,7 +49,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
|
||||
@ -65,7 +72,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self) -> str:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
return 'https://api.github.com/graphql'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/graphql'
|
||||
@ -302,9 +309,16 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
|
||||
|
||||
class GithubPRHandler(GithubIssueHandler):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'github.com',
|
||||
):
|
||||
"""Initialize a GitHub PR handler.
|
||||
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
@ -313,8 +327,10 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
base_domain: The domain for GitHub Enterprise (default: "github.com")
|
||||
"""
|
||||
super().__init__(owner, repo, token, username, base_domain)
|
||||
if self.base_domain == "github.com":
|
||||
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
if self.base_domain == 'github.com':
|
||||
self.download_url = (
|
||||
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
)
|
||||
else:
|
||||
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
|
||||
|
||||
@ -470,7 +486,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
self, pr_number: int, comment_id: int | None = None
|
||||
) -> list[str] | None:
|
||||
"""Download comments for a specific pull request from Github."""
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
@ -542,7 +558,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
|
||||
for issue_number in unique_issue_references:
|
||||
try:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
|
||||
@ -134,10 +134,7 @@ async def reset_settings(request: Request) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def check_provider_tokens(request: Request,
|
||||
settings: POSTSettingsModel) -> str:
|
||||
|
||||
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
@ -152,17 +149,13 @@ async def check_provider_tokens(request: Request,
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return f"Invalid token. Please make sure it is a valid {token_type} token."
|
||||
|
||||
|
||||
return ""
|
||||
return f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
@ -188,19 +181,17 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
provider.value: data.token.get_secret_value() if data.token else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
async def store_llm_settings(
|
||||
request: Request, settings: POSTSettingsModel
|
||||
) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
@ -215,6 +206,7 @@ async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> P
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
@ -225,11 +217,8 @@ async def store_settings(
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': provider_err_msg
|
||||
},
|
||||
content={'error': provider_err_msg},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
@ -248,7 +237,6 @@ async def store_settings(
|
||||
)
|
||||
|
||||
settings = await store_provider_tokens(request, settings)
|
||||
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@ -139,7 +139,7 @@ class Session:
|
||||
condensers=[
|
||||
BrowserOutputCondenserConfig(),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=3, max_size=80
|
||||
llm_config=llm.config, keep_first=4, max_size=80
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@ -15,6 +15,7 @@ from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
@ -49,6 +50,15 @@ def mock_agent():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
# Add a proper system message mock
|
||||
system_message = SystemMessageAction(
|
||||
content='Test system message', tools=['test_tool']
|
||||
)
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@ -206,20 +216,19 @@ async def test_react_to_content_policy_violation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
async def test_run_controller_with_fatal_error(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
):
|
||||
config = AppConfig()
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
@ -250,7 +259,7 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
@ -268,22 +277,24 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
assert (
|
||||
error_observation.reason == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
||||
)
|
||||
assert len(events) == 11
|
||||
assert len(events) == 12
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
async def test_run_controller_stop_with_stuck(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
):
|
||||
config = AppConfig()
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
def on_event(event: Event):
|
||||
@ -315,7 +326,7 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
@ -325,9 +336,10 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
print(f'event {i}: {event_to_dict(event)}')
|
||||
|
||||
assert state.iteration == 3
|
||||
assert len(events) == 11
|
||||
assert len(events) == 12
|
||||
# check the eventstream have 4 pairs of repeated actions and observations
|
||||
repeating_actions_and_observations = events[4:12]
|
||||
# With the refactored system message handling, we need to adjust the range
|
||||
repeating_actions_and_observations = events[5:13]
|
||||
for action, observation in zip(
|
||||
repeating_actions_and_observations[0::2],
|
||||
repeating_actions_and_observations[1::2],
|
||||
@ -469,6 +481,9 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
mock_event_stream.add_event.assert_called_once() # add SystemMessageAction
|
||||
mock_event_stream.add_event.reset_mock()
|
||||
|
||||
# Create a pending action with tool call metadata
|
||||
pending_action = CmdRunAction(command='test')
|
||||
pending_action.tool_call_metadata = {
|
||||
@ -512,6 +527,9 @@ async def test_reset_with_pending_action_existing_observation(
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
mock_event_stream.add_event.assert_called_once() # add SystemMessageAction
|
||||
mock_event_stream.add_event.reset_mock()
|
||||
|
||||
# Create a pending action with tool call metadata
|
||||
pending_action = CmdRunAction(command='test')
|
||||
pending_action.tool_call_metadata = {
|
||||
@ -551,6 +569,9 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Reset the mock to clear the call from system message addition
|
||||
mock_event_stream.add_event.reset_mock()
|
||||
|
||||
# Call reset
|
||||
controller._reset()
|
||||
|
||||
@ -579,6 +600,9 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Reset the mock to clear the call from system message addition
|
||||
mock_event_stream.add_event.reset_mock()
|
||||
|
||||
# Create a pending action without tool call metadata
|
||||
pending_action = CmdRunAction(command='test')
|
||||
# Mock hasattr to return False for tool_call_metadata
|
||||
@ -608,28 +632,27 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_max_iterations_has_metrics(
|
||||
test_event_stream, mock_memory
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
):
|
||||
config = AppConfig(
|
||||
max_iterations=3,
|
||||
)
|
||||
event_stream = test_event_stream
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
# Mock the cost of the LLM
|
||||
agent.llm.metrics.add_cost(10.0)
|
||||
mock_agent.llm.metrics.add_cost(10.0)
|
||||
print(
|
||||
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
|
||||
f'mock_agent.llm.metrics.accumulated_cost: {mock_agent.llm.metrics.accumulated_cost}'
|
||||
)
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
@ -660,7 +683,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
@ -839,9 +862,10 @@ async def test_context_window_exceeded_error_handling(
|
||||
)
|
||||
== 1
|
||||
)
|
||||
# With the refactored system message handling, we now have max_iterations + 4 events
|
||||
assert (
|
||||
len(final_state.history) == max_iterations + 3
|
||||
) # 1 condensation action, 1 recall action, 1 recall observation
|
||||
len(final_state.history) == max_iterations + 4
|
||||
) # 1 system message, 1 condensation action, 1 recall action, 1 recall observation
|
||||
|
||||
assert len(final_state.view) == len(step_state.views[-1]) + 1
|
||||
|
||||
@ -990,7 +1014,8 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
|
||||
# Hitting the iteration limit indicates the controller is failing for the
|
||||
# expected reason
|
||||
assert state.iteration == 2
|
||||
# With the refactored system message handling, the iteration count is different
|
||||
assert state.iteration == 1
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
@ -1012,21 +1037,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_memory_error(test_event_stream):
|
||||
async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
config = AppConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
# Create a propert agent that returns an action without an ID
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
# Create a proper agent that returns an action without an ID
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
# Create a real action to return from the mocked step function
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
@ -1046,7 +1070,7 @@ async def test_run_controller_with_memory_error(test_event_stream):
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
@ -1057,14 +1081,13 @@ async def test_run_controller_with_memory_error(test_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_metrics_copy():
|
||||
async def test_action_metrics_copy(mock_agent):
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
# Create agent with metrics
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
metrics = Metrics(model_name='test-model')
|
||||
metrics.accumulated_cost = 0.05
|
||||
|
||||
@ -1106,7 +1129,7 @@ async def test_action_metrics_copy():
|
||||
# Add a response latency - should not be included in action metrics
|
||||
metrics.add_response_latency(0.5, 'test-id-2')
|
||||
|
||||
agent.llm.metrics = metrics
|
||||
mock_agent.llm.metrics = metrics
|
||||
|
||||
# Mock agent step to return an action
|
||||
action = MessageAction(content='Test message')
|
||||
@ -1114,11 +1137,11 @@ async def test_action_metrics_copy():
|
||||
def agent_step_fn(state):
|
||||
return action
|
||||
|
||||
agent.step = agent_step_fn
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
# Create controller with correct parameters
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
@ -1169,14 +1192,14 @@ async def test_action_metrics_copy():
|
||||
assert not hasattr(last_action.llm_metrics, 'average_latency')
|
||||
|
||||
# Verify it's a deep copy by modifying the original
|
||||
agent.llm.metrics.accumulated_cost = 0.1
|
||||
mock_agent.llm.metrics.accumulated_cost = 0.1
|
||||
assert last_action.llm_metrics.accumulated_cost == 0.07
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_user_message_with_identical_content():
|
||||
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
|
||||
"""
|
||||
Test that _first_user_message correctly identifies the first user message
|
||||
even when multiple messages have identical content but different IDs.
|
||||
@ -1185,18 +1208,14 @@ async def test_first_user_message_with_identical_content():
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
"""
|
||||
# Create a real event stream for this test
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
|
||||
|
||||
# Create an agent controller
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
event_stream=test_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@ -1206,12 +1225,12 @@ async def test_first_user_message_with_identical_content():
|
||||
# Create and add the first user message
|
||||
first_message = MessageAction(content='Hello, this is a test message')
|
||||
first_message._source = EventSource.USER
|
||||
event_stream.add_event(first_message, EventSource.USER)
|
||||
test_event_stream.add_event(first_message, EventSource.USER)
|
||||
|
||||
# Create and add a second user message with identical content
|
||||
second_message = MessageAction(content='Hello, this is a test message')
|
||||
second_message._source = EventSource.USER
|
||||
event_stream.add_event(second_message, EventSource.USER)
|
||||
test_event_stream.add_event(second_message, EventSource.USER)
|
||||
|
||||
# Verify that _first_user_message returns the first message
|
||||
first_user_message = controller._first_user_message()
|
||||
@ -1235,7 +1254,7 @@ async def test_first_user_message_with_identical_content():
|
||||
) # Cache should store the same object
|
||||
|
||||
# Mock get_events to verify it's not called again
|
||||
with patch.object(event_stream, 'get_events') as mock_get_events:
|
||||
with patch.object(test_event_stream, 'get_events') as mock_get_events:
|
||||
cached_message = controller._first_user_message()
|
||||
assert cached_message is first_user_message # Should return cached object
|
||||
mock_get_events.assert_not_called() # Should not call get_events again
|
||||
@ -1326,15 +1345,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
), 'should_step should return False for NullObservation with cause=0'
|
||||
|
||||
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero():
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
|
||||
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
|
||||
# Create a mock event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create a mock agent
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
|
||||
# Create an agent controller
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
@ -1475,3 +1491,19 @@ def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
|
||||
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
_ = AgentController(
|
||||
agent=mock_agent, event_stream=test_event_stream, max_iterations=10
|
||||
)
|
||||
|
||||
# Get events from the event stream
|
||||
events = list(test_event_stream.get_events())
|
||||
|
||||
# Verify system message was added to event stream
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], SystemMessageAction)
|
||||
assert events[0].content == 'Test system message'
|
||||
assert events[0].tools == ['test_tool']
|
||||
|
||||
@ -44,6 +44,15 @@ def mock_parent_agent():
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@ -56,6 +65,15 @@ def mock_child_agent():
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@ -113,9 +131,9 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
|
||||
# Verify that a RecallObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
assert (
|
||||
mock_event_stream.get_latest_event_id() == 3
|
||||
) # Microagents and AgentChangeState
|
||||
|
||||
# SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child)
|
||||
assert mock_event_stream.get_latest_event_id() == 5
|
||||
|
||||
# a RecallObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, RecallObservation) for event in events)
|
||||
|
||||
@ -26,6 +26,7 @@ from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
@ -288,8 +289,11 @@ def test_correct_tool_description_loaded_based_on_model_name(mock_state: State):
|
||||
assert any(len(tool['function']['description']) > 1024 for tool in agent.tools)
|
||||
|
||||
|
||||
def test_mismatched_tool_call_events(mock_state: State):
|
||||
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
|
||||
def test_mismatched_tool_call_events_and_auto_add_system_message(mock_state: State):
|
||||
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages.
|
||||
|
||||
This also tests that the system message is automatically added to the event stream if SystemMessageAction is not present.
|
||||
"""
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
|
||||
tool_call_metadata = Mock(
|
||||
@ -320,26 +324,35 @@ def test_mismatched_tool_call_events(mock_state: State):
|
||||
observation.tool_call_metadata = tool_call_metadata
|
||||
|
||||
# When both events are provided, the agent should get three messages:
|
||||
# 1. The system message,
|
||||
# 2. The action message, and
|
||||
# 1. The system message (added automatically for backward compatibility)
|
||||
# 2. The action message
|
||||
# 3. The observation message
|
||||
mock_state.history = [action, observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
assert messages[1].role == 'assistant' # Second message should be the action
|
||||
assert messages[2].role == 'tool' # Third message should be the observation
|
||||
|
||||
# The same should hold if the events are presented out-of-order
|
||||
mock_state.history = [observation, action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
|
||||
# If only one of the two events is present, then we should just get the system message
|
||||
# plus any valid message from the event
|
||||
mock_state.history = [action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 1
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # Only system message, action is waiting for its observation
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
mock_state.history = [observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 1
|
||||
assert len(messages) == 1 # Only system message, observation has no matching action
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
|
||||
def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
@ -397,3 +410,18 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
# Fifth message only has ImageContent, no TextContent to modify
|
||||
assert len(enhanced_messages[5].content) == 1
|
||||
assert isinstance(enhanced_messages[5].content[0], ImageContent)
|
||||
|
||||
|
||||
def test_get_system_message():
|
||||
"""Test that the Agent.get_system_message method returns a SystemMessageAction."""
|
||||
# Create a mock agent
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
|
||||
result = agent.get_system_message()
|
||||
|
||||
# Check that the system message was created correctly
|
||||
assert isinstance(result, SystemMessageAction)
|
||||
assert 'You are OpenHands agent' in result.content
|
||||
assert len(result.tools) > 0
|
||||
assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools)
|
||||
assert result._source == EventSource.AGENT
|
||||
|
||||
@ -13,6 +13,7 @@ from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import (
|
||||
Event,
|
||||
EventSource,
|
||||
@ -84,36 +85,29 @@ def mock_state():
|
||||
return state
|
||||
|
||||
|
||||
def test_process_initial_messages(conversation_memory):
|
||||
messages = conversation_memory.process_initial_messages(with_caching=False)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
assert messages[0].content[0].cache_prompt is False
|
||||
|
||||
messages = conversation_memory.process_initial_messages(with_caching=True)
|
||||
assert messages[0].content[0].cache_prompt is True
|
||||
|
||||
|
||||
def test_process_events_with_message_action(conversation_memory):
|
||||
"""Test that MessageAction is processed correctly."""
|
||||
# Create a system message action
|
||||
system_message = SystemMessageAction(content='System message')
|
||||
system_message._source = EventSource.AGENT
|
||||
|
||||
# Create user and assistant messages
|
||||
user_message = MessageAction(content='Hello')
|
||||
user_message._source = EventSource.USER
|
||||
assistant_message = MessageAction(content='Hi there')
|
||||
assistant_message._source = EventSource.AGENT
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
# Process events
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[user_message, assistant_message],
|
||||
initial_messages=initial_messages,
|
||||
condensed_history=[system_message, user_message, assistant_message],
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Check that the messages were processed correctly
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text == 'Hello'
|
||||
assert messages[2].role == 'assistant'
|
||||
@ -131,19 +125,14 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
),
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -159,19 +148,14 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
|
||||
content='IPython output\n',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -188,19 +172,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
content='Content', outputs={'content': 'Delegated agent output'}
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -210,19 +189,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
def test_process_events_with_error_observation(conversation_memory):
|
||||
obs = ErrorObservation('Error message')
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -234,14 +208,9 @@ def test_process_events_with_unknown_observation(conversation_memory):
|
||||
# Create a mock that inherits from Event but not Action or Observation
|
||||
obs = Mock(spec=Event)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown event type'):
|
||||
conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
@ -257,19 +226,14 @@ def test_process_events_with_file_edit_observation(conversation_memory):
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -283,19 +247,14 @@ def test_process_events_with_file_read_observation(conversation_memory):
|
||||
impl_source=FileReadSource.DEFAULT,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -311,19 +270,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
error=False,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -333,19 +287,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
obs = UserRejectObservation('Action rejected')
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -368,20 +317,14 @@ def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[empty_obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Should only contain the initial system message
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
# Should only contain no messages
|
||||
assert len(messages) == 0
|
||||
|
||||
# Verify that build_workspace_context was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@ -405,20 +348,14 @@ def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# No direct message when using function calling
|
||||
assert len(messages) == 1 # Only the initial system message
|
||||
assert len(messages) == 0 # should be no messages
|
||||
|
||||
|
||||
def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
@ -428,19 +365,14 @@ def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
)
|
||||
action._source = EventSource.AGENT
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 2
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -453,19 +385,14 @@ def test_process_events_with_user_cmd_action(conversation_memory):
|
||||
action = CmdRunAction(command='ls -l')
|
||||
action._source = EventSource.USER
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -491,19 +418,14 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -520,10 +442,11 @@ def test_apply_prompt_caching(conversation_memory):
|
||||
|
||||
conversation_memory.apply_prompt_caching(messages)
|
||||
|
||||
# Only the last user message should have cache_prompt=True
|
||||
assert messages[0].content[0].cache_prompt is False
|
||||
# System message is hard-coded to be cached always
|
||||
assert messages[0].content[0].cache_prompt is True
|
||||
assert messages[1].content[0].cache_prompt is False
|
||||
assert messages[2].content[0].cache_prompt is False
|
||||
# Only the last user message should have cache_prompt=True
|
||||
assert messages[3].content[0].cache_prompt is True
|
||||
|
||||
|
||||
@ -538,19 +461,14 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -598,19 +516,14 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 1
|
||||
result = messages[0]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@ -646,20 +559,14 @@ def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# When prompt extensions are disabled, the RecallObservation should be ignored
|
||||
assert len(messages) == 1 # Only the initial system message
|
||||
assert messages[0].role == 'system'
|
||||
assert len(messages) == 0 # should be no messages
|
||||
|
||||
# Verify the prompt_manager was not called
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@ -674,20 +581,14 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# The implementation returns an empty string and it doesn't creates a message
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
assert len(messages) == 0 # should be no messages
|
||||
|
||||
# When there are no triggered agents, build_microagent_info is not called
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
@ -892,27 +793,19 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
|
||||
content='Third retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2, obs3],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that only the first occurrence of content for each agent is included
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # system + 1 microagent, because the second and third microagents are duplicates
|
||||
microagent_messages = messages[1:] # Skip system message
|
||||
assert len(messages) == 1
|
||||
|
||||
# First microagent should include all agents since they appear here first
|
||||
assert 'Image best practices v1' in microagent_messages[0].content[0].text
|
||||
assert 'Git best practices v1' in microagent_messages[0].content[0].text
|
||||
assert 'Python best practices v1' in microagent_messages[0].content[0].text
|
||||
assert 'Image best practices v1' in messages[0].content[0].text
|
||||
assert 'Git best practices v1' in messages[0].content[0].text
|
||||
assert 'Python best practices v1' in messages[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
|
||||
@ -949,26 +842,18 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # system + 1 microagent, the second is the same "enabled_agent"
|
||||
microagent_messages = messages[1:] # Skip system message
|
||||
assert len(messages) == 1
|
||||
|
||||
# First microagent should include enabled_agent but not disabled_agent
|
||||
assert 'Disabled agent content' not in microagent_messages[0].content[0].text
|
||||
assert 'Enabled agent content v1' in microagent_messages[0].content[0].text
|
||||
assert 'Disabled agent content' not in messages[0].content[0].text
|
||||
assert 'Enabled agent content v1' in messages[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
@ -981,21 +866,14 @@ def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
content='Empty retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that empty RecallObservations are handled gracefully
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # system message, because an empty microagent is not added to Messages
|
||||
assert len(messages) == 0 # an empty microagent is not added to Messages
|
||||
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
|
||||
@ -1198,3 +1076,22 @@ class TestFilterUnmatchedToolCalls:
|
||||
assert len(result) == len(expected)
|
||||
for i, msg in enumerate(result):
|
||||
assert msg == expected[i]
|
||||
|
||||
|
||||
def test_system_message_in_events(conversation_memory):
|
||||
"""Test that SystemMessageAction in condensed_history is processed correctly."""
|
||||
# Create a system message action
|
||||
system_message = SystemMessageAction(content='System message', tools=['test_tool'])
|
||||
system_message._source = EventSource.AGENT
|
||||
|
||||
# Process events with the system message in condensed_history
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[system_message],
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Check that the system message was processed correctly
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
|
||||
@ -25,6 +25,16 @@ class DummyAgent:
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_system_message(self):
    """Return a dummy agent-sourced SystemMessageAction for the refactored system message handling."""
    # Imported inside the method, as the original helper does.
    from openhands.events.action.message import SystemMessageAction
    from openhands.events.event import EventSource

    dummy_message = SystemMessageAction(content='This is a dummy system message')
    dummy_message._source = EventSource.AGENT
    # -1 is deliberately an invalid ID so that the ID check is avoided.
    dummy_message._id = -1
    return dummy_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iteration_limit_extends_on_user_message():
|
||||
|
||||
@ -58,15 +58,27 @@ def prompt_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
|
||||
@pytest.fixture
def mock_agent():
    """Provide a mocked Agent, with an LLM and a system message, for the controller.

    Returns:
        A ``MagicMock(spec=Agent)`` whose ``llm`` carries real ``Metrics`` and
        an LLM config, and whose ``get_system_message()`` returns an
        agent-sourced ``SystemMessageAction``.
    """
    # Create a dummy agent for the controller
    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = AppConfig().get_llm_config()

    # Add a proper system message mock
    from openhands.events.action.message import SystemMessageAction

    system_message = SystemMessageAction(content='Test system message')
    system_message._source = EventSource.AGENT
    system_message._id = -1  # Set invalid ID to avoid the ID check
    agent.get_system_message.return_value = system_message

    # A fixture injects whatever it returns; without this line every test
    # requesting `mock_agent` would receive None instead of the mock above.
    return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
|
||||
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
|
||||
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
@ -80,7 +92,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
@ -93,16 +105,10 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
memory, event_stream
|
||||
memory, event_stream, mock_agent
|
||||
):
|
||||
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
|
||||
|
||||
# Create a dummy agent for the controller
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
@ -118,7 +124,7 @@ async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@ -55,6 +55,10 @@ def response_mock(content: str, tool_call_id: str):
|
||||
def test_get_messages(codeact_agent: CodeActAgent):
|
||||
# Add some events to history
|
||||
history = list()
|
||||
# Add system message action
|
||||
system_message_action = codeact_agent.get_system_message()
|
||||
history.append(system_message_action)
|
||||
|
||||
message_action_1 = MessageAction('Initial user message')
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
@ -77,7 +81,8 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
assert (
|
||||
len(messages) == 6
|
||||
) # System, initial user + user message, agent message, last user message
|
||||
assert messages[0].content[0].cache_prompt # system message
|
||||
assert messages[0].role == 'system' # system message
|
||||
assert messages[0].content[0].cache_prompt # system message should be cached
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text.endswith('Initial user message')
|
||||
# we add cache breakpoint to only the last user message
|
||||
@ -96,6 +101,10 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
|
||||
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
history = list()
|
||||
# Add system message action
|
||||
system_message_action = codeact_agent.get_system_message()
|
||||
history.append(system_message_action)
|
||||
|
||||
# Add multiple user and agent messages
|
||||
for i in range(15):
|
||||
message_action_user = MessageAction(f'User message {i}')
|
||||
@ -116,7 +125,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
]
|
||||
assert (
|
||||
len(cached_user_messages) == 2
|
||||
) # Including the initial system+user + last user message
|
||||
) # Including the initial system message + last user message
|
||||
|
||||
# Verify that these are indeed the last user message (from start)
|
||||
assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')
|
||||
|
||||
@ -16,6 +16,16 @@ def agent_controller():
|
||||
agent.name = 'test_agent'
|
||||
agent.llm = llm
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore())
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user