Add RecallActions and observations for retrieval of prompt extensions (#6909)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Engel Nyst 2025-03-15 21:48:37 +01:00 committed by GitHub
parent e34a771e66
commit cc45f5d9c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 2317 additions and 735 deletions

View File

@ -1,8 +1,6 @@
import json
import os
from collections import deque
import openhands
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
@ -74,21 +72,14 @@ class CodeActAgent(Agent):
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
)
logger.debug(
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
f'TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}'
)
self.prompt_manager = PromptManager(
microagent_dir=os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
if self.config.enable_prompt_extensions
else None,
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
disabled_microagents=self.config.disabled_microagents,
)
# Create a ConversationMemory instance
self.conversation_memory = ConversationMemory(self.prompt_manager)
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {type(self.condenser)}')
@ -168,7 +159,7 @@ class CodeActAgent(Agent):
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
# Use conversation_memory to process events instead of calling events_to_messages directly
# Use ConversationMemory to process initial messages
messages = self.conversation_memory.process_initial_messages(
with_caching=self.llm.is_caching_prompt_active()
)
@ -180,12 +171,12 @@ class CodeActAgent(Agent):
f'Processing {len(events)} events from a total of {len(state.history)} events'
)
# Use ConversationMemory to process events
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_messages=messages,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
enable_som_visual_browsing=self.config.enable_som_visual_browsing,
)
messages = self._enhance_messages(messages)
@ -216,14 +207,7 @@ class CodeActAgent(Agent):
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)
# and/or repo/runtime info
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)
# enhance the user message with additional context based on keywords matched
if msg.role == 'user':
self.prompt_manager.enhance_message(msg)
elif msg.role == 'user':
# Add double newline between consecutive user messages
if prev_role == 'user' and len(msg.content) > 0:
# Find the first TextContent in the message to add newlines

View File

@ -20,6 +20,8 @@ When starting a web server, use the corresponding ports. You should also
set any options to allow iframes and CORS requests, and allow the server to
be accessed from any host (e.g. 0.0.0.0).
{% endif %}
{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
</RUNTIME_INFORMATION>
{% endif %}

View File

@ -1,8 +1,8 @@
{% for agent_info in triggered_agents %}
<EXTRA_INFO>
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
It may or may not be relevant to the user's request.
{{ agent_info.agent.content }}
{{ agent_info.content }}
</EXTRA_INFO>
{% endfor %}

View File

@ -29,7 +29,12 @@ from openhands.core.exceptions import (
from openhands.core.logger import LOG_ALL_EVENTS
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events import (
EventSource,
EventStream,
EventStreamSubscriber,
RecallType,
)
from openhands.events.action import (
Action,
ActionConfirmationStatus,
@ -42,6 +47,7 @@ from openhands.events.action import (
MessageAction,
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
@ -89,7 +95,7 @@ class AgentController:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
sid: str = 'default',
sid: str | None = None,
confirmation_mode: bool = False,
initial_state: State | None = None,
is_delegate: bool = False,
@ -116,7 +122,7 @@ class AgentController:
status_callback: Optional callback function to handle status updates.
replay_events: A list of logs to replay.
"""
self.id = sid
self.id = sid or event_stream.sid
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
@ -287,8 +293,14 @@ class AgentController:
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
event, AgentStateChangedObservation
if (
isinstance(event, NullObservation)
and event.cause is not None
and event.cause > 0
):
return True
if isinstance(event, AgentStateChangedObservation) or isinstance(
event, NullObservation
):
return False
return True
@ -388,6 +400,7 @@ class AgentController:
if observation.llm_metrics is not None:
self.agent.llm.metrics.merge(observation.llm_metrics)
# this happens for runnable actions and microagent actions
if self._pending_action and self._pending_action.id == observation.cause:
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
return
@ -431,6 +444,25 @@ class AgentController:
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
)
# try to retrieve microagents relevant to the user message
# set pending_action while we search for information
# if this is the first user message for this agent, matters for the microagent info type
first_user_message = self._first_user_message()
is_first_user_message = (
action.id == first_user_message.id if first_user_message else False
)
recall_type = (
RecallType.WORKSPACE_CONTEXT
if is_first_user_message
else RecallType.KNOWLEDGE
)
recall_action = RecallAction(query=action.content, recall_type=recall_type)
self._pending_action = recall_action
# this is source=USER because the user message is the trigger for the microagent retrieval
self.event_stream.add_event(recall_action, EventSource.USER)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
@ -438,6 +470,7 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
# Runnable actions need an Observation
# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
@ -459,6 +492,8 @@ class AgentController:
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()
@ -1146,3 +1181,26 @@ class AgentController:
result = event.agent_state == AgentState.RUNNING
return result
return False
def _first_user_message(self) -> MessageAction | None:
    """Return the first user message seen by this agent, if any.

    Events are read from the event stream starting at ``self.state.start_id``.
    For regular agents that is the beginning of the stream (start_id=0); for
    delegate agents it is the first user message after the delegate's
    start_id.

    Returns:
        MessageAction | None: The earliest ``MessageAction`` whose source is
        ``EventSource.USER``, or ``None`` if no such event is found.
    """
    # Iterate lazily: the search stops at the first match, so the event
    # history does not need to be copied into a list beforehand.
    events = self.event_stream.get_events(start_id=self.state.start_id)
    return next(
        (
            event
            for event in events
            if isinstance(event, MessageAction) and event.source == EventSource.USER
        ),
        None,
    )

View File

@ -135,7 +135,7 @@ class StuckDetector:
# it takes 3 actions and 3 observations to detect a loop
# check if the last three actions are the same and result in errors
if len(last_actions) < 4 or len(last_observations) < 4:
if len(last_actions) < 3 or len(last_observations) < 3:
return False
# are the last three actions the "same"?

View File

@ -17,6 +17,7 @@ from openhands.core.schema import AgentState
from openhands.core.setup import (
create_agent,
create_controller,
create_memory,
create_runtime,
initialize_repository_for_runtime,
)
@ -170,13 +171,22 @@ async def main(loop: asyncio.AbstractEventLoop):
await runtime.connect()
# Initialize repository if needed
repo_directory = None
if config.sandbox.selected_repo:
initialize_repository_for_runtime(
repo_directory = initialize_repository_for_runtime(
runtime,
agent=agent,
selected_repository=config.sandbox.selected_repo,
)
# when memory is created, it will load the microagents from the selected repository
memory = create_memory(
runtime=runtime,
event_stream=event_stream,
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
)
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
@ -185,7 +195,7 @@ async def main(loop: asyncio.AbstractEventLoop):
asyncio.create_task(prompt_for_next_task())
await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
)

View File

@ -3,12 +3,14 @@ import asyncio
from openhands.controller import AgentController
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
async def run_agent_until_done(
controller: AgentController,
runtime: Runtime,
memory: Memory,
end_states: list[AgentState],
):
"""
@ -37,6 +39,7 @@ async def run_agent_until_done(
runtime.status_callback = status_callback
controller.status_callback = status_callback
memory.status_callback = status_callback
while controller.state.agent_state not in end_states:
await asyncio.sleep(1)

View File

@ -18,6 +18,7 @@ from openhands.core.schema import AgentState
from openhands.core.setup import (
create_agent,
create_controller,
create_memory,
create_runtime,
generate_sid,
initialize_repository_for_runtime,
@ -29,6 +30,7 @@ from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.io import read_input, read_task
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
@ -51,6 +53,7 @@ async def run_controller(
exit_on_message: bool = False,
fake_user_response_fn: FakeUserResponseFunc | None = None,
headless_mode: bool = True,
memory: Memory | None = None,
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
@ -93,6 +96,8 @@ async def run_controller(
if agent is None:
agent = create_agent(config)
# when the runtime is created, it will be connected and clone the selected repository
repo_directory = None
if runtime is None:
runtime = create_runtime(
config,
@ -105,14 +110,23 @@ async def run_controller(
# Initialize repository if needed
if config.sandbox.selected_repo:
initialize_repository_for_runtime(
repo_directory = initialize_repository_for_runtime(
runtime,
agent=agent,
selected_repository=config.sandbox.selected_repo,
)
event_stream = runtime.event_stream
# when memory is created, it will load the microagents from the selected repository
if memory is None:
memory = create_memory(
runtime=runtime,
event_stream=event_stream,
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
)
replay_events: list[Event] | None = None
if config.replay_trajectory_path:
logger.info('Trajectory replay is enabled')
@ -172,7 +186,7 @@ async def run_controller(
]
try:
await run_agent_until_done(controller, runtime, end_states)
await run_agent_until_done(controller, runtime, memory, end_states)
except Exception as e:
logger.error(f'Exception in main loop: {e}')

View File

@ -82,5 +82,8 @@ class ActionTypeSchema(BaseModel):
SEND_PR: str = Field(default='send_pr')
"""Send a PR to github."""
RECALL: str = Field(default='recall')
"""Retrieves content from a user workspace, microagent, or other source."""
ActionType = ActionTypeSchema()

View File

@ -49,5 +49,8 @@ class ObservationTypeSchema(BaseModel):
CONDENSE: str = Field(default='condense')
"""Result of a condensation operation."""
MICROAGENT: str = Field(default='microagent')
"""Result of a microagent retrieval operation."""
ObservationType = ObservationTypeSchema()

View File

@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Tuple, Type
from typing import Callable, Tuple, Type
from pydantic import SecretStr
@ -16,6 +16,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
@ -83,7 +84,6 @@ def create_runtime(
def initialize_repository_for_runtime(
runtime: Runtime,
agent: Agent | None = None,
selected_repository: str | None = None,
github_token: SecretStr | None = None,
) -> str | None:
@ -91,7 +91,6 @@ def initialize_repository_for_runtime(
Args:
runtime: The runtime to initialize the repository for.
agent: (optional) The agent to load microagents for.
selected_repository: (optional) The GitHub repository to use.
github_token: (optional) The GitHub token to use.
@ -99,10 +98,10 @@ def initialize_repository_for_runtime(
The repository directory path if a repository was cloned, None otherwise.
"""
# clone selected repository if provided
repo_directory = None
github_token = (
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
)
repo_directory = None
if selected_repository and github_token:
logger.debug(f'Selected repository {selected_repository}.')
repo_directory = runtime.clone_repo(
@ -111,16 +110,47 @@ def initialize_repository_for_runtime(
None,
)
# load microagents from selected repository
if agent and agent.prompt_manager and selected_repository and repo_directory:
agent.prompt_manager.set_runtime_info(runtime)
return repo_directory
def create_memory(
    runtime: Runtime,
    event_stream: EventStream,
    sid: str,
    selected_repository: str | None = None,
    repo_directory: str | None = None,
    status_callback: Callable | None = None,
) -> Memory:
    """Create a memory for the agent to use.

    Args:
        runtime: The runtime to use.
        event_stream: The event stream it will subscribe to.
        sid: The session id.
        selected_repository: The repository to clone and start with, if any.
        repo_directory: The repository directory, if any.
        status_callback: Optional callback function to handle status updates.

    Returns:
        Memory: The newly created memory. When a runtime is given, it has the
        runtime info set and the selected repository's microagents loaded.
    """
    memory = Memory(
        event_stream=event_stream,
        sid=sid,
        status_callback=status_callback,
    )

    if runtime:
        # sets available hosts
        memory.set_runtime_info(runtime)

        # loads microagents from repo/.openhands/microagents
        microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
            selected_repository
        )
        memory.load_user_workspace_microagents(microagents)

        # repository info is only recorded when both the name and the
        # checkout directory are known
        if selected_repository and repo_directory:
            memory.set_repository_info(selected_repository, repo_directory)

    return memory
def create_agent(config: AppConfig) -> Agent:

View File

@ -1,4 +1,4 @@
from openhands.events.event import Event, EventSource
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
__all__ = [
@ -6,4 +6,5 @@ __all__ = [
'EventSource',
'EventStream',
'EventStreamSubscriber',
'RecallType',
]

View File

@ -6,6 +6,7 @@ from openhands.events.action.agent import (
AgentSummarizeAction,
AgentThinkAction,
ChangeAgentStateAction,
RecallAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
@ -35,4 +36,5 @@ __all__ = [
'MessageAction',
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',
]

View File

@ -4,6 +4,7 @@ from typing import Any
from openhands.core.schema import ActionType
from openhands.events.action.action import Action
from openhands.events.event import RecallType
@dataclass
@ -106,3 +107,22 @@ class AgentDelegateAction(Action):
@property
def message(self) -> str:
return f"I'm asking {self.agent} for help with this task."
@dataclass
class RecallAction(Action):
    """Action that asks for content to be retrieved, e.g., from the global directory or user workspace.

    Attributes:
        recall_type: Which kind of information is being requested.
        query: The text the retrieval is based on.
        thought: Optional free-form thought attached to the action.
        action: The action type identifier (``ActionType.RECALL``).
    """

    recall_type: RecallType
    query: str = ''
    thought: str = ''
    action: str = ActionType.RECALL

    @property
    def message(self) -> str:
        """Short human-readable description; the query is cut to 50 characters."""
        preview = self.query[:50]
        return f'Retrieving content for: {preview}'

    def __str__(self) -> str:
        """Two-line representation: a header followed by the truncated query."""
        preview = self.query[:50]
        return '**RecallAction**\n' + f'QUERY: {preview}'

View File

@ -22,6 +22,16 @@ class FileReadSource(str, Enum):
DEFAULT = 'default'
class RecallType(str, Enum):
    """Kinds of information that can be retrieved from microagents."""

    # Workspace context (repo instructions, runtime, etc.)
    WORKSPACE_CONTEXT = 'workspace_context'

    # A knowledge microagent.
    KNOWLEDGE = 'knowledge'
@dataclass
class Event:
INVALID_ID = -1

View File

@ -1,7 +1,9 @@
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
AgentThinkObservation,
MicroagentObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@ -40,4 +42,6 @@ __all__ = [
'SuccessObservation',
'UserRejectObservation',
'AgentCondensationObservation',
'MicroagentObservation',
'RecallType',
]

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from openhands.core.schema import ObservationType
from openhands.events.event import RecallType
from openhands.events.observation.observation import Observation
@ -40,3 +41,76 @@ class AgentThinkObservation(Observation):
@property
def message(self) -> str:
return self.content
@dataclass
class MicroagentKnowledge:
    """Knowledge contributed by a single triggered microagent.

    Attributes:
        name: Name of the microagent that was triggered.
        trigger: The word that caused this microagent to be triggered.
        content: The knowledge text supplied by the microagent.
    """

    name: str
    trigger: str
    content: str
@dataclass
class MicroagentObservation(Observation):
    """Observation carrying content retrieved from one or more microagents."""

    recall_type: RecallType
    observation: str = ObservationType.MICROAGENT

    # environment
    repo_name: str = ''
    repo_directory: str = ''
    repo_instructions: str = ''
    runtime_hosts: dict[str, int] = field(default_factory=dict)
    additional_agent_instructions: str = ''

    # knowledge
    # A list of MicroagentKnowledge objects, each containing information from
    # a triggered microagent. Example:
    #   [
    #       MicroagentKnowledge(
    #           name="python_best_practices",
    #           trigger="python",
    #           content="Always use virtual environments for Python projects."
    #       ),
    #       MicroagentKnowledge(
    #           name="git_workflow",
    #           trigger="git",
    #           content="Create a new branch for each feature or bugfix."
    #       )
    #   ]
    microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)

    @property
    def message(self) -> str:
        """Same text as the string representation of this observation."""
        return self.__str__()

    def __str__(self) -> str:
        """Summarise the observation's fields on one line under a header.

        Long instruction fields are cut to their first 20 characters, and the
        microagent names are listed only when there is at least one.
        """
        parts = [
            f'recall_type={self.recall_type}',
            f'repo_name={self.repo_name}',
            f'repo_instructions={self.repo_instructions[:20]}...',
            f'runtime_hosts={self.runtime_hosts}',
            f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
        ]

        if self.microagent_knowledge:
            names = ', '.join(item.name for item in self.microagent_knowledge)
            parts.append(f'microagent_knowledge={names}')

        summary = ', '.join(parts)
        return f'**MicroagentObservation**\n{summary}'

View File

@ -8,6 +8,7 @@ from openhands.events.action.agent import (
AgentRejectAction,
AgentThinkAction,
ChangeAgentStateAction,
RecallAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import (
@ -35,6 +36,7 @@ actions = (
AgentFinishAction,
AgentRejectAction,
AgentDelegateAction,
RecallAction,
ChangeAgentStateAction,
MessageAction,
)

View File

@ -1,5 +1,6 @@
from dataclasses import asdict
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
@ -102,6 +103,8 @@ def event_to_dict(event: 'Event') -> dict:
d['timestamp'] = d['timestamp'].isoformat()
if key == 'source' and 'source' in d:
d['source'] = d['source'].value
if key == 'recall_type' and 'recall_type' in d:
d['recall_type'] = d['recall_type'].value
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
if key == 'llm_metrics' and 'llm_metrics' in d:
@ -119,7 +122,11 @@ def event_to_dict(event: 'Event') -> dict:
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
# such as CmdOutputMetadata
# we serialize it along with the rest
d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()}
# we also handle the Enum conversion for MicroagentObservation
d['extras'] = {
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
for k, v in props.items()
}
# Include success field for CmdOutputObservation
if hasattr(event, 'success'):
d['success'] = event.success

View File

@ -1,9 +1,12 @@
import copy
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
AgentThinkObservation,
MicroagentKnowledge,
MicroagentObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@ -40,6 +43,7 @@ observations = (
UserRejectObservation,
AgentCondensationObservation,
AgentThinkObservation,
MicroagentObservation,
)
OBSERVATION_TYPE_TO_CLASS = {
@ -110,4 +114,18 @@ def observation_from_dict(observation: dict) -> Observation:
else:
extras['metadata'] = CmdOutputMetadata()
if observation_class is MicroagentObservation:
# handle the Enum conversion
if 'recall_type' in extras:
extras['recall_type'] = RecallType(extras['recall_type'])
# convert dicts in microagent_knowledge to MicroagentKnowledge objects
if 'microagent_knowledge' in extras and isinstance(
extras['microagent_knowledge'], list
):
extras['microagent_knowledge'] = [
MicroagentKnowledge(**item) if isinstance(item, dict) else item
for item in extras['microagent_knowledge']
]
return observation_class(content=content, **extras)

View File

@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum):
RESOLVER = 'openhands_resolver'
SERVER = 'server'
RUNTIME = 'runtime'
MEMORY = 'memory'
MAIN = 'main'
TEST = 'test'

View File

@ -249,7 +249,8 @@ class LLM(RetryMixin, DebugMixin):
# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
assert len(resp.choices) == 1
logger.debug(f'Response choices: {len(resp.choices)}')
assert len(resp.choices) >= 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(

View File

@ -1,5 +1,6 @@
from litellm import ModelResponse
from openhands.core.config.agent_config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.schema import ActionType
@ -16,7 +17,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.event import Event
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
@ -28,16 +29,21 @@ from openhands.events.observation import (
IPythonRunCellObservation,
UserRejectObservation,
)
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
class ConversationMemory:
"""Processes event history into a coherent conversation for the agent."""
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
    """Store the collaborators used when turning events into messages.

    Args:
        config: The agent configuration, kept as ``self.agent_config``.
        prompt_manager: The prompt manager, kept as ``self.prompt_manager``.
    """
    self.agent_config = config
    self.prompt_manager = prompt_manager
def process_events(
@ -53,14 +59,14 @@ class ConversationMemory:
Ensures that tool call actions are processed correctly in function calling mode.
Args:
state: The state containing the history of events to convert
condensed_history: The condensed list of events to process
initial_messages: The initial messages to include in the result
condensed_history: The condensed history of events to convert
initial_messages: The initial messages to include in the conversation
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
"""
events = condensed_history
# Process special events first (system prompts, etc.)
@ -70,7 +76,7 @@ class ConversationMemory:
pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
for event in events:
for i, event in enumerate(events):
# create a regular message from an event
if isinstance(event, Action):
messages_to_add = self._process_action(
@ -84,7 +90,9 @@ class ConversationMemory:
tool_call_id_to_message=tool_call_id_to_message,
max_message_chars=max_message_chars,
vision_is_active=vision_is_active,
enable_som_visual_browsing=enable_som_visual_browsing,
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
current_index=i,
events=events,
)
else:
raise ValueError(f'Unknown event type: {type(event)}')
@ -270,6 +278,8 @@ class ConversationMemory:
max_message_chars: int | None = None,
vision_is_active: bool = False,
enable_som_visual_browsing: bool = False,
current_index: int = 0,
events: list[Event] | None = None,
) -> list[Message]:
"""Converts an observation into a message format that can be sent to the LLM.
@ -291,6 +301,8 @@ class ConversationMemory:
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
current_index: The index of the current event in the events list (for deduplication)
events: The list of all events (for deduplication)
Returns:
list[Message]: A list containing the formatted message(s) for the observation.
@ -372,6 +384,92 @@ class ConversationMemory:
elif isinstance(obs, AgentCondensationObservation):
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif (
isinstance(obs, MicroagentObservation)
and self.agent_config.enable_prompt_extensions
):
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
# everything is optional, check if they are present
repo_info = (
RepositoryInfo(
repo_name=obs.repo_name or '',
repo_directory=obs.repo_directory or '',
)
if obs.repo_name or obs.repo_directory
else None
)
if obs.runtime_hosts or obs.additional_agent_instructions:
runtime_info = RuntimeInfo(
available_hosts=obs.runtime_hosts,
additional_agent_instructions=obs.additional_agent_instructions,
)
else:
runtime_info = None
repo_instructions = (
obs.repo_instructions if obs.repo_instructions else ''
)
# Have some meaningful content before calling the template
has_repo_info = repo_info is not None and (
repo_info.repo_name or repo_info.repo_directory
)
has_runtime_info = runtime_info is not None and (
runtime_info.available_hosts
or runtime_info.additional_agent_instructions
)
has_repo_instructions = bool(repo_instructions.strip())
# Build additional info if we have something to render
if has_repo_info or has_runtime_info or has_repo_instructions:
# ok, now we can build the additional info
formatted_text = self.prompt_manager.build_additional_info(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
)
message = Message(
role='user', content=[TextContent(text=formatted_text)]
)
else:
return []
elif obs.recall_type == RecallType.KNOWLEDGE:
# Use prompt manager to build the microagent info
# First, filter out agents that appear in earlier MicroagentObservations
filtered_agents = self._filter_agents_in_microagent_obs(
obs, current_index, events or []
)
# Create and return a message if there is microagent knowledge to include
if filtered_agents:
# Exclude disabled microagents
filtered_agents = [
agent
for agent in filtered_agents
if agent.name not in self.agent_config.disabled_microagents
]
# Only proceed if we still have agents after filtering out disabled ones
if filtered_agents:
formatted_text = self.prompt_manager.build_microagent_info(
triggered_agents=filtered_agents,
)
return [
Message(
role='user', content=[TextContent(text=formatted_text)]
)
]
# Return empty list if no microagents to include or all were disabled
return []
elif (
isinstance(obs, MicroagentObservation)
and not self.agent_config.enable_prompt_extensions
):
# If prompt extensions are disabled, we don't add any additional info
# TODO: test this
return []
else:
# If an observation message is not returned, it will cause an error
# when the LLM tries to return the next message
@ -404,3 +502,53 @@ class ConversationMemory:
-1
].cache_prompt = True # Last item inside the message content
break
def _filter_agents_in_microagent_obs(
    self, obs: MicroagentObservation, current_index: int, events: list[Event]
) -> list[MicroagentKnowledge]:
    """Drop microagent knowledge already delivered by an earlier observation.

    Args:
        obs: The current MicroagentObservation to filter
        current_index: The index of the current event in the events list
        events: The list of all events

    Returns:
        list[MicroagentKnowledge]: For a KNOWLEDGE observation, the entries
        whose microagent name does not occur in any earlier knowledge
        observation; for any other recall type, the observation's list as is.
    """
    knowledge = obs.microagent_knowledge

    # Only knowledge observations are deduplicated
    if obs.recall_type != RecallType.KNOWLEDGE:
        return knowledge

    # Keep an entry only when this is the first observation mentioning it
    return [
        entry
        for entry in knowledge
        if not self._has_agent_in_earlier_events(entry.name, current_index, events)
    ]
def _has_agent_in_earlier_events(
    self, agent_name: str, current_index: int, events: list[Event]
) -> bool:
    """Report whether a microagent was already recalled before the current event.

    Args:
        agent_name: The name of the agent to look for
        current_index: The index of the current event in the events list
        events: The list of all events

    Returns:
        bool: True if any MicroagentObservation of recall type KNOWLEDGE located
        before ``current_index`` lists a microagent with this name, False otherwise
    """
    # Only knowledge observations that precede the current event are relevant.
    earlier_knowledge_observations = (
        earlier
        for earlier in events[:current_index]
        if isinstance(earlier, MicroagentObservation)
        and earlier.recall_type == RecallType.KNOWLEDGE
    )
    return any(
        known.name == agent_name
        for earlier in earlier_knowledge_observations
        for known in earlier.microagent_knowledge
    )

270
openhands/memory/memory.py Normal file
View File

@ -0,0 +1,270 @@
import asyncio
import os
import uuid
from typing import Callable
import openhands
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.stream import EventStream, EventStreamSubscriber
from openhands.microagent import (
BaseMicroAgent,
KnowledgeMicroAgent,
RepoMicroAgent,
load_microagents_from_dir,
)
from openhands.runtime.base import Runtime
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
# Path of the global microagents directory: the `microagents` folder that sits
# next to the `openhands` package directory (i.e. in the parent directory of the
# directory containing `openhands/__init__.py`).
GLOBAL_MICROAGENTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
class Memory:
    """
    Memory is a component that listens to the EventStream for information retrieval actions
    (a RecallAction) and publishes observations with the content (such as MicroagentObservation).
    """

    sid: str
    event_stream: EventStream
    status_callback: Callable | None
    loop: asyncio.AbstractEventLoop | None

    def __init__(
        self,
        event_stream: EventStream,
        sid: str,
        status_callback: Callable | None = None,
    ):
        """Initialise the stores, load the global microagents, then subscribe.

        Args:
            event_stream: Stream to listen on and to publish observations to.
            sid: Subscriber id; a random UUID string is used when it is empty.
            status_callback: Optional callable, invoked as
                ``status_callback(msg_type, id, message)`` to report errors.
        """
        self.event_stream = event_stream
        self.sid = sid if sid else str(uuid.uuid4())
        self.status_callback = status_callback
        self.loop = None

        # Additional placeholders to store user workspace microagents
        self.repo_microagents: dict[str, RepoMicroAgent] = {}
        self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}

        # Store repository / runtime info to send them to the templating later
        self.repository_info: RepositoryInfo | None = None
        self.runtime_info: RuntimeInfo | None = None

        # Load global microagents (Knowledge + Repo)
        # from typically OpenHands/microagents (i.e., the PUBLIC microagents)
        self._load_global_microagents()

        # Subscribe only once every attribute read by the event handler exists,
        # so an event delivered right after subscription never sees a
        # half-initialised instance, and a failure above leaves no subscriber behind.
        self.event_stream.subscribe(
            EventStreamSubscriber.MEMORY,
            self.on_event,
            self.sid,
        )

    def on_event(self, event: Event):
        """Handle an event from the event stream."""
        # NOTE(review): relies on asyncio.get_event_loop() returning a usable,
        # not-already-running loop for the calling thread — confirm against how
        # EventStream invokes its subscribers.
        asyncio.get_event_loop().run_until_complete(self._on_event(event))

    async def _on_event(self, event: Event):
        """Handle an event from the event stream asynchronously.

        Only RecallAction events are processed. For those, an observation is
        always added to the stream (a NullObservation when nothing was
        retrieved), with its cause set to the id of the triggering event.
        Any exception is logged and reported through ``send_error_message``
        instead of being propagated.
        """
        try:
            observation: MicroagentObservation | NullObservation | None = None

            if isinstance(event, RecallAction):
                # if this is a workspace context recall (on first user message)
                # create and add a MicroagentObservation
                # with info about repo and runtime.
                if (
                    event.source == EventSource.USER
                    and event.recall_type == RecallType.WORKSPACE_CONTEXT
                ):
                    observation = self._on_first_microagent_action(event)

                # continue with the next handler, to include knowledge microagents if suitable for this query
                assert observation is None or isinstance(
                    observation, MicroagentObservation
                ), f'Expected a MicroagentObservation, but got {type(observation)}'
                observation = self._on_microagent_action(
                    event, prev_observation=observation
                )

                if observation is None:
                    observation = NullObservation(content='')

                # important: this will release the execution flow from waiting for the retrieval to complete
                observation._cause = event.id  # type: ignore[union-attr]

                self.event_stream.add_event(observation, EventSource.ENVIRONMENT)
        except Exception as e:
            error_str = f'Error: {str(e.__class__.__name__)}'
            # The status message carries only the exception class name; keep the
            # message and traceback in the log so the failure can be diagnosed.
            logger.error(error_str, exc_info=True)
            self.send_error_message('STATUS$ERROR_MEMORY', error_str)
            return

    def _on_first_microagent_action(
        self, event: RecallAction
    ) -> MicroagentObservation | None:
        """Add repository and runtime information to the stream as a MicroagentObservation.

        Returns:
            A WORKSPACE_CONTEXT MicroagentObservation built from the stored
            repository info, runtime info and repo microagent content, or None
            when none of the three is available.
        """
        # Create ENVIRONMENT info:
        # - repository_info
        # - runtime_info
        # - repository_instructions

        # Collect raw repository instructions
        repo_instructions = ''
        assert (
            len(self.repo_microagents) <= 1
        ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'

        # Retrieve the context of repo instructions
        for microagent in self.repo_microagents.values():
            if repo_instructions:
                repo_instructions += '\n\n'
            repo_instructions += microagent.content

        repository_info = self.repository_info
        runtime_info = self.runtime_info

        # Nothing to report: no repository, no runtime, no instructions
        if not (repository_info or runtime_info or repo_instructions):
            return None

        # Missing info objects or None fields fall back to empty values
        return MicroagentObservation(
            recall_type=RecallType.WORKSPACE_CONTEXT,
            repo_name=repository_info.repo_name
            if repository_info and repository_info.repo_name is not None
            else '',
            repo_directory=repository_info.repo_directory
            if repository_info and repository_info.repo_directory is not None
            else '',
            repo_instructions=repo_instructions,
            runtime_hosts=runtime_info.available_hosts
            if runtime_info and runtime_info.available_hosts is not None
            else {},
            additional_agent_instructions=runtime_info.additional_agent_instructions
            if runtime_info
            and runtime_info.additional_agent_instructions is not None
            else '',
            microagent_knowledge=[],
            content='Retrieved environment info',
        )

    def _on_microagent_action(
        self,
        event: RecallAction,
        prev_observation: MicroagentObservation | None = None,
    ) -> MicroagentObservation | None:
        """When a microagent action triggers microagents, create a MicroagentObservation with structured data.

        Args:
            event: The recall action whose ``query`` is matched against the
                knowledge microagents' triggers.
            prev_observation: Observation already built for this event, if any;
                triggered knowledge is appended to it in place.

        Returns:
            ``prev_observation`` (possibly extended) when it was given or when
            nothing was triggered; otherwise a new KNOWLEDGE
            MicroagentObservation holding the triggered microagents.
        """
        # If there's no query, do nothing
        query = event.query.strip()
        if not query:
            return prev_observation

        assert prev_observation is None or isinstance(
            prev_observation, MicroagentObservation
        ), f'Expected a MicroagentObservation, but got {type(prev_observation)}'

        # Process text to find suitable microagents and create a MicroagentObservation.
        recalled_content: list[MicroagentKnowledge] = []
        for name, microagent in self.knowledge_microagents.items():
            trigger = microagent.match_trigger(query)
            if trigger:
                logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
                recalled_content.append(
                    MicroagentKnowledge(
                        name=microagent.name,
                        trigger=trigger,
                        content=microagent.content,
                    )
                )

        if not recalled_content:
            return prev_observation

        if prev_observation is not None:
            # it may be on the first user message that already found some repo info etc
            prev_observation.microagent_knowledge.extend(recalled_content)
            return prev_observation

        # if it's not the first user message, we may not have found any information this step
        return MicroagentObservation(
            recall_type=RecallType.KNOWLEDGE,
            microagent_knowledge=recalled_content,
            content='Retrieved knowledge from microagents',
        )

    def load_user_workspace_microagents(
        self, user_microagents: list[BaseMicroAgent]
    ) -> None:
        """
        This method loads microagents from a user's cloned repo or workspace directory.

        This is typically called from agent_session or setup once the workspace is cloned.
        Knowledge and repo microagents are stored by name (replacing an existing
        entry with the same name); other microagent types are ignored.
        """
        logger.info(
            'Loading user workspace microagents: %s', [m.name for m in user_microagents]
        )
        for user_microagent in user_microagents:
            if isinstance(user_microagent, KnowledgeMicroAgent):
                self.knowledge_microagents[user_microagent.name] = user_microagent
            elif isinstance(user_microagent, RepoMicroAgent):
                self.repo_microagents[user_microagent.name] = user_microagent

    def _load_global_microagents(self) -> None:
        """
        Loads microagents from the global microagents_dir
        """
        repo_agents, knowledge_agents, _ = load_microagents_from_dir(
            GLOBAL_MICROAGENTS_DIR
        )
        for name, agent in knowledge_agents.items():
            if isinstance(agent, KnowledgeMicroAgent):
                self.knowledge_microagents[name] = agent
        for name, agent in repo_agents.items():
            if isinstance(agent, RepoMicroAgent):
                self.repo_microagents[name] = agent

    def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
        """Store repository info so we can reference it in an observation.

        The stored info is cleared when both arguments are empty.
        """
        if repo_name or repo_directory:
            self.repository_info = RepositoryInfo(repo_name, repo_directory)
        else:
            self.repository_info = None

    def set_runtime_info(self, runtime: Runtime) -> None:
        """Store runtime info (web hosts, ports, etc.).

        The stored info is cleared when the runtime has neither web hosts nor
        additional agent instructions.
        """
        # e.g. { '127.0.0.1': 8080 }
        if runtime.web_hosts or runtime.additional_agent_instructions:
            self.runtime_info = RuntimeInfo(
                available_hosts=runtime.web_hosts,
                additional_agent_instructions=runtime.additional_agent_instructions,
            )
        else:
            self.runtime_info = None

    def send_error_message(self, message_id: str, message: str):
        """Sends an error message if the callback function was provided.

        The running event loop is looked up on first use and cached in
        ``self.loop``; a RuntimeError raised while looking it up or scheduling
        the message is logged and not propagated.
        """
        if self.status_callback:
            try:
                if self.loop is None:
                    self.loop = asyncio.get_running_loop()
                asyncio.run_coroutine_threadsafe(
                    self._send_status_message('error', message_id, message), self.loop
                )
            except RuntimeError as e:
                logger.error(
                    f'Error sending status message: {e.__class__.__name__}',
                    stack_info=False,
                )

    async def _send_status_message(self, msg_type: str, id: str, message: str):
        """Sends a status message to the client."""
        if self.status_callback:
            self.status_callback(msg_type, id, message)

View File

@ -15,7 +15,8 @@ from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
@ -126,6 +127,15 @@ class AgentSession:
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)
repo_directory = None
if self.runtime and runtime_connected and selected_repository:
repo_directory = selected_repository.split('/')[-1]
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,
)
if github_token:
self.event_stream.set_secrets(
{
@ -260,26 +270,14 @@ class AgentSession:
)
return False
repo_directory = None
if selected_repository:
repo_directory = await call_sync_from_async(
await call_sync_from_async(
self.runtime.clone_repo,
github_token,
selected_repository,
selected_branch,
)
if agent.prompt_manager:
agent.prompt_manager.set_runtime_info(self.runtime)
microagents: list[BaseMicroAgent] = await call_sync_from_async(
self.runtime.get_microagents_from_selected_repo, selected_repository
)
agent.prompt_manager.load_microagents(microagents)
if selected_repository and repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)
self.logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
)
@ -342,6 +340,29 @@ class AgentSession:
return controller
async def _create_memory(
    self, selected_repository: str | None, repo_directory: str | None
) -> Memory:
    """Create the session's Memory and seed it from the runtime, when there is one.

    Args:
        selected_repository: Repository passed to the runtime's microagent lookup
            and, together with ``repo_directory``, stored as repository info.
        repo_directory: Directory of the repository, if known.

    Returns:
        Memory: A Memory bound to this session's event stream, sid and status callback.
    """
    session_memory = Memory(
        event_stream=self.event_stream,
        sid=self.sid,
        status_callback=self._status_callback,
    )

    # Without a runtime there is nothing to seed the memory with
    if not self.runtime:
        return session_memory

    # sets available hosts and other runtime info
    session_memory.set_runtime_info(self.runtime)

    # loads microagents from repo/.openhands/microagents
    workspace_microagents: list[BaseMicroAgent] = await call_sync_from_async(
        self.runtime.get_microagents_from_selected_repo, selected_repository
    )
    session_memory.load_user_workspace_microagents(workspace_microagents)

    if selected_repository and repo_directory:
        session_memory.set_repository_info(selected_repository, repo_directory)
    return session_memory
def _maybe_restore_state(self) -> State | None:
"""Helper method to handle state restore logic."""
restored_state = None

View File

@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore):
try:
conversations.append(await self.get_metadata(conversation_id))
except Exception:
logger.error(
f'Error loading conversation: {conversation_id}',
logger.warning(
f'Could not load conversation metadata: {conversation_id}',
)
conversations.sort(key=_sort_key, reverse=True)
conversations = conversations[start:end]

View File

@ -1,25 +1,18 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from itertools import islice
from jinja2 import Template
from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger
from openhands.core.message import Message, TextContent
from openhands.microagent import (
BaseMicroAgent,
KnowledgeMicroAgent,
RepoMicroAgent,
load_microagents_from_dir,
)
from openhands.runtime.base import Runtime
from openhands.events.observation.agent import MicroagentKnowledge
@dataclass
class RuntimeInfo:
available_hosts: dict[str, int]
additional_agent_instructions: str
available_hosts: dict[str, int] = field(default_factory=dict)
additional_agent_instructions: str = ''
@dataclass
@ -32,75 +25,23 @@ class RepositoryInfo:
class PromptManager:
"""
Manages prompt templates and micro-agents for AI interactions.
Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents.
This class handles loading and rendering of system and user prompt templates,
as well as loading micro-agent specifications. It provides methods to access
rendered system and initial user messages for AI interactions.
This class is dedicated to loading and rendering prompts (system prompt, user prompt).
Attributes:
prompt_dir (str): Directory containing prompt templates.
microagent_dir (str): Directory containing microagent specifications.
disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled.
prompt_dir: Directory containing prompt templates.
"""
def __init__(
self,
prompt_dir: str,
microagent_dir: str | None = None,
disabled_microagents: list[str] | None = None,
):
self.disabled_microagents: list[str] = disabled_microagents or []
self.prompt_dir: str = prompt_dir
self.repository_info: RepositoryInfo | None = None
self.system_template: Template = self._load_template('system_prompt')
self.user_template: Template = self._load_template('user_prompt')
self.additional_info_template: Template = self._load_template('additional_info')
self.microagent_info_template: Template = self._load_template('microagent_info')
self.runtime_info = RuntimeInfo(
available_hosts={}, additional_agent_instructions=''
)
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
self.repo_microagents: dict[str, RepoMicroAgent] = {}
if microagent_dir:
# This loads micro-agents from the microagent_dir
# which is typically the OpenHands/microagents (i.e., the PUBLIC microagents)
# Only load KnowledgeMicroAgents
repo_microagents, knowledge_microagents, _ = load_microagents_from_dir(
microagent_dir
)
assert all(
isinstance(microagent, KnowledgeMicroAgent)
for microagent in knowledge_microagents.values()
)
for name, microagent in knowledge_microagents.items():
if name not in self.disabled_microagents:
self.knowledge_microagents[name] = microagent
assert all(
isinstance(microagent, RepoMicroAgent)
for microagent in repo_microagents.values()
)
for name, microagent in repo_microagents.items():
if name not in self.disabled_microagents:
self.repo_microagents[name] = microagent
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None:
"""Load microagents from a list of BaseMicroAgents.
This is typically used when loading microagents from inside a repo.
"""
openhands_logger.info('Loading microagents: %s', [m.name for m in microagents])
# Only keep KnowledgeMicroAgents and RepoMicroAgents
for microagent in microagents:
if microagent.name in self.disabled_microagents:
continue
if isinstance(microagent, KnowledgeMicroAgent):
self.knowledge_microagents[microagent.name] = microagent
elif isinstance(microagent, RepoMicroAgent):
self.repo_microagents[microagent.name] = microagent
def _load_template(self, template_name: str) -> Template:
if self.prompt_dir is None:
@ -114,27 +55,6 @@ class PromptManager:
def get_system_message(self) -> str:
return self.system_template.render().strip()
def set_runtime_info(self, runtime: Runtime) -> None:
self.runtime_info.available_hosts = runtime.web_hosts
self.runtime_info.additional_agent_instructions = (
runtime.additional_agent_instructions
)
def set_repository_info(
self,
repo_name: str,
repo_directory: str,
) -> None:
"""Sets information about the GitHub repository that has been cloned.
Args:
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
repo_directory: The directory where the repository has been cloned
"""
self.repository_info = RepositoryInfo(
repo_name=repo_name, repo_directory=repo_directory
)
def get_example_user_message(self) -> str:
"""This is the initial user message provided to the agent
before *actual* user instructions are provided.
@ -148,45 +68,6 @@ class PromptManager:
return self.user_template.render().strip()
def enhance_message(self, message: Message) -> None:
"""Enhance the user message with additional context.
This method is used to enhance the user message with additional context
about the user's task. The additional context will convert the current
generic agent into a more specialized agent that is tailored to the user's task.
"""
if not message.content:
return
# if there were other texts included, they were before the user message
# so the last TextContent is the user message
# content can be a list of TextContent or ImageContent
message_content = ''
for content in reversed(message.content):
if isinstance(content, TextContent):
message_content = content.text
break
if not message_content:
return
triggered_agents = []
for name, microagent in self.knowledge_microagents.items():
trigger = microagent.match_trigger(message_content)
if trigger:
openhands_logger.info(
"Microagent '%s' triggered by keyword '%s'",
name,
trigger,
)
# Create a dictionary with the agent and trigger word
triggered_agents.append({'agent': microagent, 'trigger_word': trigger})
if triggered_agents:
formatted_text = self.build_microagent_info(triggered_agents)
# Insert the new content at the start of the TextContent list
message.content.insert(0, TextContent(text=formatted_text))
def add_examples_to_initial_message(self, message: Message) -> None:
"""Add example_message to the first user message."""
example_message = self.get_example_user_message() or None
@ -195,44 +76,28 @@ class PromptManager:
if example_message:
message.content.insert(0, TextContent(text=example_message))
def add_info_to_initial_message(
def build_additional_info(
self,
message: Message,
) -> None:
"""Adds information about the repository and runtime to the initial user message.
Args:
message: The initial user message to add information to.
"""
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
for microagent in self.repo_microagents.values():
# We assume these are the repo instructions
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content
additional_info = self.additional_info_template.render(
repository_info: RepositoryInfo | None,
runtime_info: RuntimeInfo | None,
repo_instructions: str = '',
) -> str:
"""Renders the additional info template with the stored repository/runtime info."""
return self.additional_info_template.render(
repository_info=repository_info,
repository_instructions=repo_instructions,
repository_info=self.repository_info,
runtime_info=self.runtime_info,
runtime_info=runtime_info,
).strip()
# Insert the new content at the start of the TextContent list
if additional_info:
message.content.insert(0, TextContent(text=additional_info))
def build_microagent_info(
self,
triggered_agents: list[dict],
triggered_agents: list[MicroagentKnowledge],
) -> str:
"""Renders the microagent info template with the triggered agents.
Args:
triggered_agents: A list of dictionaries, each containing an "agent"
(KnowledgeMicroAgent) and a "trigger_word" (str).
triggered_agents: A list of MicroagentKnowledge objects containing information
about triggered microagents.
"""
return self.microagent_info_template.render(
triggered_agents=triggered_agents

View File

@ -9,6 +9,7 @@ from openhands.events.action import (
FileReadAction,
FileWriteAction,
MessageAction,
RecallAction,
)
from openhands.events.action.action import ActionConfirmationStatus
from openhands.events.action.files import FileEditSource, FileReadSource
@ -356,6 +357,18 @@ def test_file_ohaci_edit_action_legacy_serialization():
assert event_dict['args']['end'] == -1
def test_agent_microagent_action_serialization_deserialization():
    """Run the shared serialization_deserialization check on a 'recall' action dict."""
    recall_args = {
        'query': 'What is the capital of France?',
        'thought': 'I need to find information about France',
        'recall_type': 'knowledge',
    }
    serialization_deserialization(
        {'action': 'recall', 'args': recall_args},
        RecallAction,
    )
def test_file_read_action_legacy_serialization():
original_action_dict = {
'action': 'read',

View File

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock
from unittest.mock import ANY, AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
@ -14,12 +14,16 @@ from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.action.agent import RecallAction
from openhands.events.event import RecallType
from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.observation.agent import MicroagentObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
@ -47,17 +51,36 @@ def mock_agent():
@pytest.fixture
def mock_event_stream():
mock = MagicMock(spec=EventStream)
mock = MagicMock(
spec=EventStream,
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
)
mock.get_latest_event_id.return_value = 0
return mock
@pytest.fixture
def test_event_stream():
    """Provide an EventStream with sid 'test' backed by an empty InMemoryFileStore."""
    return EventStream(sid='test', file_store=InMemoryFileStore({}))
@pytest.fixture
def mock_runtime() -> Runtime:
return MagicMock(
runtime = MagicMock(
spec=Runtime,
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
event_stream=test_event_stream,
)
return runtime
@pytest.fixture
def mock_memory(test_event_stream) -> Memory:
    """Provide a MagicMock specced as Memory, wired to the test's EventStream.

    ``test_event_stream`` is requested as a fixture parameter so that the mock's
    ``event_stream`` attribute is the EventStream instance pytest created for
    this test, rather than the module-level fixture function object.
    """
    memory = MagicMock(
        spec=Memory,
        event_stream=test_event_stream,
    )
    return memory
@pytest.fixture
@ -68,6 +91,7 @@ def mock_status_callback():
async def send_event_to_controller(controller, event):
await controller._on_event(event)
await asyncio.sleep(0.1)
controller._pending_action = None
@pytest.mark.asyncio
@ -140,10 +164,8 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error():
async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
config = AppConfig()
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
agent = MagicMock(spec=Agent)
@ -163,10 +185,23 @@ async def test_run_controller_with_fatal_error():
if isinstance(event, CmdRunAction):
error_obs = ErrorObservation('You messed around with Jim')
error_obs._cause = event.id
event_stream.add_event(error_obs, EventSource.USER)
test_event_stream.add_event(error_obs, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = test_event_stream
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
test_event_stream.subscribe(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
state = await run_controller(
config=config,
@ -175,22 +210,20 @@ async def test_run_controller_with_fatal_error():
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
print(f'state: {state}')
events = list(event_stream.get_events())
events = list(test_event_stream.get_events())
print(f'event_stream: {events}')
assert state.iteration == 4
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
assert len(events) == 11
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
config = AppConfig()
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
@ -209,10 +242,23 @@ async def test_run_controller_stop_with_stuck():
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
test_event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = test_event_stream
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
test_event_stream.subscribe(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
state = await run_controller(
config=config,
@ -221,16 +267,17 @@ async def test_run_controller_stop_with_stuck():
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
events = list(event_stream.get_events())
events = list(test_event_stream.get_events())
print(f'state: {state}')
for i, event in enumerate(events):
print(f'event {i}: {event_to_dict(event)}')
assert state.iteration == 4
assert state.iteration == 3
assert len(events) == 11
# check the eventstream have 4 pairs of repeated actions and observations
repeating_actions_and_observations = events[2:10]
repeating_actions_and_observations = events[4:12]
for action, observation in zip(
repeating_actions_and_observations[0::2],
repeating_actions_and_observations[1::2],
@ -510,12 +557,13 @@ async def test_reset_with_pending_action_no_metadata(
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics():
async def test_run_controller_max_iterations_has_metrics(
test_event_stream, mock_memory
):
config = AppConfig(
max_iterations=3,
)
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
event_stream = test_event_stream
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
@ -546,6 +594,17 @@ async def test_run_controller_max_iterations_has_metrics():
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
@ -553,6 +612,7 @@ async def test_run_controller_max_iterations_has_metrics():
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
)
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
@ -630,7 +690,7 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_agent, mock_runtime
mock_agent, mock_runtime, mock_memory, test_event_stream
):
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON"""
@ -656,6 +716,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_agent.step = step_state.step
mock_agent.config = AgentConfig()
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
test_event_stream.subscribe(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
mock_runtime.event_stream = test_event_stream
try:
state = await asyncio.wait_for(
run_controller(
@ -665,6 +739,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
@ -691,7 +766,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_without_truncation(
mock_agent, mock_runtime
mock_agent, mock_runtime, mock_memory, test_event_stream
):
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
@ -702,7 +777,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 1 and not self.has_errored:
if len(state.history) > 3 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
@ -718,6 +793,19 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
mock_agent.config = AgentConfig()
mock_agent.config.enable_history_truncation = False
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
test_event_stream.subscribe(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
mock_runtime.event_stream = test_event_stream
try:
state = await asyncio.wait_for(
run_controller(
@ -727,6 +815,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
@ -751,6 +840,44 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
assert step_state.has_errored
@pytest.mark.asyncio
async def test_run_controller_with_memory_error(test_event_stream):
    """An exception inside Memory's recall handling must put the run in ERROR state."""
    config = AppConfig()

    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    runtime = MagicMock(spec=Runtime)
    runtime.event_stream = test_event_stream

    # Use a real Memory instance so its own error handling is exercised
    memory = Memory(event_stream=test_event_stream, sid='test-memory')

    # Make _on_microagent_action raise our test exception whenever it is called
    with patch.object(
        memory,
        '_on_microagent_action',
        side_effect=RuntimeError('Test memory error'),
    ):
        state = await run_controller(
            config=config,
            initial_user_action=MessageAction(content='Test message'),
            runtime=runtime,
            sid='test',
            agent=agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    assert state.iteration == 0
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'Error: RuntimeError'
@pytest.mark.asyncio
async def test_action_metrics_copy():
# Setup
@ -851,3 +978,56 @@ async def test_action_metrics_copy():
assert last_action.llm_metrics.accumulated_cost == 0.07
await controller.close()
@pytest.mark.asyncio
async def test_first_user_message_with_identical_content():
    """Check that `_first_user_message` returns the earliest of two identical user messages.

    Two MessageActions with the same content are added to an EventStream backed
    by an in-memory file store. The two actions compare equal, so only their
    IDs distinguish them; the test pins that the controller returns the one
    that was added first.
    """
    event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))

    # Mocked agent carrying real metrics and a real LLM config.
    mock_agent = MagicMock(spec=Agent)
    mock_agent.llm = MagicMock(spec=LLM)
    mock_agent.llm.metrics = Metrics()
    mock_agent.llm.config = AppConfig().get_llm_config()

    controller = AgentController(
        agent=mock_agent,
        event_stream=event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Close the controller even when an assertion below fails, so a failing
    # run does not leave it open.
    try:
        # Create and add the first user message
        first_message = MessageAction(content='Hello, this is a test message')
        first_message._source = EventSource.USER
        event_stream.add_event(first_message, EventSource.USER)

        # Create and add a second user message with identical content
        second_message = MessageAction(content='Hello, this is a test message')
        second_message._source = EventSource.USER
        event_stream.add_event(second_message, EventSource.USER)

        # The controller must hand back the first message, identified by ID.
        first_user_message = controller._first_user_message()
        assert first_user_message is not None
        assert first_user_message.id == first_message.id  # Check IDs match
        assert first_user_message.id != second_message.id  # Different IDs

        # Equality does not distinguish the two messages.
        assert first_user_message == first_message == second_message

        # Test the comparison used in the actual code
        assert first_message == first_user_message
        # The second message must never be mistaken for the first one: its ID
        # has to differ from the ID of the message the controller returned.
        assert second_message.id != first_user_message.id
    finally:
        await controller.close()

View File

@ -17,8 +17,13 @@ from openhands.events.action import (
AgentFinishAction,
MessageAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import MicroagentObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.storage.memory import InMemoryFileStore
@ -75,6 +80,25 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
initial_state=parent_state,
)
# Setup Memory to catch RecallActions
mock_memory = MagicMock(spec=Memory)
mock_memory.event_stream = mock_event_stream
def on_event(event: Event):
if isinstance(event, RecallAction):
# create a MicroagentObservation
microagent_observation = MicroagentObservation(
recall_type=RecallType.KNOWLEDGE,
content='microagent',
)
microagent_observation._cause = event.id # ignore attr-defined warning
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
mock_memory.on_event = on_event
mock_event_stream.subscribe(
EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
)
# Setup a delegate action from the parent
delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True})
mock_parent_agent.step.return_value = delegate_action
@ -87,7 +111,16 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
# Give time for the async step() to execute
await asyncio.sleep(1)
# The parent should receive step() from that event
# Verify that a MicroagentObservation was added to the event stream
events = list(mock_event_stream.get_events())
assert (
mock_event_stream.get_latest_event_id() == 3
) # Microagents and AgentChangeState
# a MicroagentObservation and an AgentDelegateAction should be in the list
assert any(isinstance(event, MicroagentObservation) for event in events)
assert any(isinstance(event, AgentDelegateAction) for event in events)
# Verify that a delegate agent controller is created
assert (
parent_controller.delegate is not None

View File

@ -6,9 +6,11 @@ from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig, LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
@ -22,18 +24,24 @@ def mock_agent():
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)
agent_config = MagicMock(spec=AgentConfig)
# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.max_message_chars = 1000
# Configure the agent config
agent_config.disabled_microagents = []
# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
return agent
@ -78,7 +86,11 @@ async def test_agent_session_start_with_no_state(mock_agent):
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Patch AgentController and State.restore_from_session to fail
# Create a real Memory instance with the mock event stream
memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir'
# Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
@ -87,7 +99,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
), patch(
'openhands.controller.state.state.State.restore_from_session',
side_effect=Exception('No state found'),
):
), patch('openhands.server.session.agent_session.Memory', return_value=memory):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
@ -96,12 +108,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
)
# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
mock_event_stream.subscribe.assert_any_call(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
mock_event_stream.subscribe.assert_any_call(
EventStreamSubscriber.MEMORY,
session.memory.on_event,
session.controller.id,
)
# Verify set_initial_state was called once with None as state
assert session.controller.set_initial_state_call_count == 1
assert session.controller.test_initial_state is None
@ -159,7 +177,10 @@ async def test_agent_session_start_with_restored_state(mock_agent):
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Patch AgentController and State.restore_from_session to succeed
# create a mock Memory
mock_memory = MagicMock(spec=Memory)
# Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
@ -168,7 +189,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
), patch(
'openhands.controller.state.state.State.restore_from_session',
return_value=mock_restored_state,
):
), patch('openhands.server.session.agent_session.Memory', mock_memory):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),

View File

@ -1,7 +1,7 @@
import asyncio
from argparse import Namespace
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
@ -17,13 +17,18 @@ def mock_runtime():
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
mock_runtime_instance = AsyncMock()
# Mock the event stream with proper async methods
mock_runtime_instance.event_stream = AsyncMock()
mock_runtime_instance.event_stream.subscribe = AsyncMock()
mock_runtime_instance.event_stream.add_event = AsyncMock()
mock_event_stream = AsyncMock()
mock_event_stream.subscribe = AsyncMock()
mock_event_stream.add_event = AsyncMock()
mock_event_stream.get_events = AsyncMock(return_value=[])
mock_event_stream.get_latest_event_id = AsyncMock(return_value=0)
mock_runtime_instance.event_stream = mock_event_stream
# Mock connect method to return immediately
mock_runtime_instance.connect = AsyncMock()
# Ensure status_callback is None
mock_runtime_instance.status_callback = None
# Mock get_microagents_from_selected_repo
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
mock_create_runtime.return_value = mock_runtime_instance
yield mock_runtime_instance
@ -32,6 +37,16 @@ def mock_runtime():
def mock_agent():
    """Patch `openhands.core.cli.create_agent` and yield the mocked agent it returns.

    The yielded AsyncMock has a name, an `llm` with a configured `config`,
    an agent `config` with no disabled microagents, no sandbox plugins and a
    mocked prompt manager.
    """
    with patch('openhands.core.cli.create_agent') as mock_create_agent:
        # LLM configuration values set on the mocked agent's llm.
        llm_config = AsyncMock()
        llm_config.model = 'test-model'
        llm_config.base_url = 'http://test'
        llm_config.max_message_chars = 1000

        llm = AsyncMock()
        llm.config = llm_config

        # Agent-level configuration.
        agent_config = AsyncMock()
        agent_config.disabled_microagents = []

        agent = AsyncMock()
        agent.name = 'test-agent'
        agent.llm = llm
        agent.config = agent_config
        agent.sandbox_plugins = []
        agent.prompt_manager = AsyncMock()

        mock_create_agent.return_value = agent
        yield agent

View File

@ -369,9 +369,3 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
# Fifth message only has ImageContent, no TextContent to modify
assert len(enhanced_messages[5].content) == 1
assert isinstance(enhanced_messages[5].content[0], ImageContent)
# Verify prompt manager methods were called as expected
assert agent.prompt_manager.add_examples_to_initial_message.call_count == 1
assert (
agent.prompt_manager.enhance_message.call_count == 5
) # Called for each user message

View File

@ -1,16 +1,29 @@
import os
import shutil
from unittest.mock import MagicMock, Mock
import pytest
from openhands.controller.state.state import State
from openhands.core.config.agent_config import AgentConfig
from openhands.core.message import ImageContent, Message, TextContent
from openhands.events.action import (
AgentFinishAction,
CmdRunAction,
MessageAction,
)
from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
from openhands.events.event import (
Event,
EventSource,
FileEditSource,
FileReadSource,
RecallType,
)
from openhands.events.observation import CmdOutputObservation
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
CmdOutputMetadata,
@ -22,14 +35,45 @@ from openhands.events.observation.files import FileEditObservation, FileReadObse
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.tool import ToolCallMetadata
from openhands.memory.conversation_memory import ConversationMemory
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
@pytest.fixture
def agent_config():
    """Return an AgentConfig with prompt extensions and SoM visual browsing enabled.

    `disabled_microagents` contains 'disabled_agent'.
    """
    return AgentConfig(
        enable_prompt_extensions=True,
        enable_som_visual_browsing=True,
        disabled_microagents=['disabled_agent'],
    )
@pytest.fixture
def conversation_memory(agent_config):
    """Return a ConversationMemory built from `agent_config` and a mocked PromptManager.

    The mock returns fixed strings for the system message and the additional
    info block, and renders microagent info by joining the `content` of the
    agents it is given.
    """
    prompt_manager = MagicMock(spec=PromptManager)
    prompt_manager.get_system_message.return_value = 'System message'
    prompt_manager.build_additional_info.return_value = (
        'Formatted repository and runtime info'
    )

    # Make build_microagent_info return the actual content from the triggered agents
    def build_microagent_info(triggered_agents):
        if not triggered_agents:
            return ''
        return '\n'.join(agent.content for agent in triggered_agents)

    prompt_manager.build_microagent_info.side_effect = build_microagent_info

    # Single return, after the mock is fully configured, using the
    # (config, prompt_manager) constructor.
    return ConversationMemory(agent_config, prompt_manager)
@pytest.fixture
def prompt_dir(tmp_path):
    """Copy the CodeAct agent prompt templates into `tmp_path` and return that path.

    The source directory is given relative to the current working directory.
    """
    source_prompts = 'openhands/agenthub/codeact_agent/prompts'
    # dirs_exist_ok is needed because tmp_path already exists.
    shutil.copytree(source_prompts, tmp_path, dirs_exist_ok=True)
    return tmp_path
@pytest.fixture
@ -308,6 +352,40 @@ def test_process_events_with_user_reject_observation(conversation_memory):
assert '[Last action has been rejected by the user]' in result.content[0].text
def test_process_events_with_empty_environment_info(conversation_memory):
    """A workspace-context observation whose fields are all empty must add no message.

    Only the initial system message is expected back, and
    `build_additional_info` must not be called.
    """
    blank_observation = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='',
        repo_directory='',
        repo_instructions='',
        runtime_hosts={},
        additional_agent_instructions='',
        microagent_knowledge=[],
        content='Retrieved environment info',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[blank_observation],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # Nothing was appended after the system message.
    assert len(processed) == 1
    assert processed[0].role == 'system'

    # No rendering was attempted for the empty observation.
    conversation_memory.prompt_manager.build_additional_info.assert_not_called()
def test_process_events_with_function_calling_observation(conversation_memory):
mock_response = {
'id': 'mock_id',
@ -446,3 +524,529 @@ def test_apply_prompt_caching(conversation_memory):
assert messages[1].content[0].cache_prompt is False
assert messages[2].content[0].cache_prompt is False
assert messages[3].content[0].cache_prompt is True
def test_process_events_with_environment_microagent_observation(conversation_memory):
    """Process a WORKSPACE_CONTEXT MicroagentObservation carrying repo and runtime info.

    One user message is expected after the system message, with the text the
    mocked `build_additional_info` returns, and the mock must have received
    the repository, runtime and instruction values from the observation.
    """
    workspace_obs = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='test-repo',
        repo_directory='/path/to/repo',
        repo_instructions='# Test Repository\nThis is a test repository.',
        runtime_hosts={'localhost': 8080},
        content='Retrieved environment info',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[workspace_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    assert len(processed) == 2
    rendered = processed[1]
    assert rendered.role == 'user'
    assert len(rendered.content) == 1
    text_part = rendered.content[0]
    assert isinstance(text_part, TextContent)
    assert text_part.text == 'Formatted repository and runtime info'

    # Inspect the keyword arguments of the single build_additional_info call.
    build_additional_info = conversation_memory.prompt_manager.build_additional_info
    build_additional_info.assert_called_once()
    passed_kwargs = build_additional_info.call_args[1]

    repository_info = passed_kwargs['repository_info']
    assert isinstance(repository_info, RepositoryInfo)
    assert repository_info.repo_name == 'test-repo'
    assert repository_info.repo_directory == '/path/to/repo'

    runtime_info = passed_kwargs['runtime_info']
    assert isinstance(runtime_info, RuntimeInfo)
    assert runtime_info.available_hosts == {'localhost': 8080}

    assert (
        passed_kwargs['repo_instructions']
        == '# Test Repository\nThis is a test repository.'
    )
def test_process_events_with_knowledge_microagent_microagent_observation(
    conversation_memory,
):
    """Process a KNOWLEDGE MicroagentObservation holding three microagents.

    The entry named 'disabled_agent' is expected to be dropped before
    `build_microagent_info` is called; the other two must reach the message.
    """
    knowledge_entries = [
        MicroagentKnowledge(name=name, trigger=trigger, content=content)
        for name, trigger, content in (
            ('test_agent', 'test', 'This is test agent content'),
            ('another_agent', 'another', 'This is another agent content'),
            ('disabled_agent', 'disabled', 'This is disabled agent content'),
        )
    ]
    knowledge_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=knowledge_entries,
        content='Retrieved knowledge from microagents',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[knowledge_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    assert len(processed) == 2
    rendered = processed[1]
    assert rendered.role == 'user'
    assert len(rendered.content) == 1
    assert isinstance(rendered.content[0], TextContent)

    # Enabled agents appear in the text, the disabled one does not.
    rendered_text = rendered.content[0].text
    for expected in ('This is test agent content', 'This is another agent content'):
        assert expected in rendered_text
    assert 'This is disabled agent content' not in rendered_text

    # Inspect the keyword arguments of the single build_microagent_info call.
    build_microagent_info = conversation_memory.prompt_manager.build_microagent_info
    build_microagent_info.assert_called_once()
    passed_agents = build_microagent_info.call_args[1]['triggered_agents']

    assert len(passed_agents) == 2
    passed_names = [agent.name for agent in passed_agents]
    assert 'test_agent' in passed_names
    assert 'another_agent' in passed_names
    assert 'disabled_agent' not in passed_names
def test_process_events_with_microagent_observation_extensions_disabled(
    agent_config, conversation_memory
):
    """With `enable_prompt_extensions` off, a MicroagentObservation must be ignored.

    Only the system message is expected back and neither prompt-manager
    builder may be called.
    """
    # Turn prompt extensions off on the config object passed in as agent_config.
    agent_config.enable_prompt_extensions = False

    workspace_obs = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='test-repo',
        repo_directory='/path/to/repo',
        content='Retrieved environment info',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[workspace_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # Only the initial system message remains.
    assert len(processed) == 1
    assert processed[0].role == 'system'

    # Neither builder was invoked.
    prompt_manager = conversation_memory.prompt_manager
    prompt_manager.build_additional_info.assert_not_called()
    prompt_manager.build_microagent_info.assert_not_called()
def test_process_events_with_empty_microagent_knowledge(conversation_memory):
    """A KNOWLEDGE observation with no microagent knowledge must add no message."""
    empty_knowledge_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[],
        content='Retrieved knowledge from microagents',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[empty_knowledge_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # No message is created for the empty observation.
    assert len(processed) == 1
    assert processed[0].role == 'system'

    # With no triggered agents, build_microagent_info is never called.
    conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
def test_conversation_memory_processes_microagent_observation(prompt_dir):
    """Render a KNOWLEDGE MicroagentObservation through a real PromptManager.

    A fallback `microagent_info.j2` template is written only when the copied
    prompt directory does not already contain one. The observation is then
    processed and the resulting user message is compared with the expected
    <EXTRA_INFO> block.
    """
    template_path = os.path.join(prompt_dir, 'microagent_info.j2')
    if not os.path.exists(template_path):
        # The fallback reads `agent_info.trigger`, the field MicroagentKnowledge
        # is constructed with below, and contains template text only, so its
        # rendering can match `expected_text`.
        with open(template_path, 'w', encoding='utf-8') as f:
            f.write("""{% for agent_info in triggered_agents %}
<EXTRA_INFO>
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
It may or may not be relevant to the user's request.
{{ agent_info.content }}
</EXTRA_INFO>
{% endfor %}
""")

    # Create a mock agent config
    agent_config = MagicMock(spec=AgentConfig)
    agent_config.enable_prompt_extensions = True
    agent_config.disabled_microagents = []

    # Create a PromptManager that reads templates from prompt_dir
    prompt_manager = PromptManager(prompt_dir=prompt_dir)

    # Initialize ConversationMemory
    conversation_memory = ConversationMemory(
        config=agent_config, prompt_manager=prompt_manager
    )

    # Create a MicroagentObservation with microagent knowledge
    microagent_observation = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(
                name='test_agent',
                trigger='test_trigger',
                content='This is triggered content for testing.',
            )
        ],
        content='Retrieved knowledge from microagents',
    )

    # Process the observation
    messages = conversation_memory._process_observation(
        obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None
    )

    # Verify the message was created correctly
    assert len(messages) == 1
    message = messages[0]
    assert message.role == 'user'
    assert len(message.content) == 1
    assert isinstance(message.content[0], TextContent)

    expected_text = """<EXTRA_INFO>
The following information has been included based on a keyword match for "test_trigger".
It may or may not be relevant to the user's request.
This is triggered content for testing.
</EXTRA_INFO>"""

    def non_blank_lines(text):
        # Compare content lines only, so the check does not depend on how many
        # blank lines the template leaves in its output.
        return [line.strip() for line in text.splitlines() if line.strip()]

    # Verify the template was correctly rendered
    assert non_blank_lines(message.content[0].text) == non_blank_lines(expected_text)

    # Clean up
    os.remove(template_path)
def test_conversation_memory_processes_environment_microagent_observation(prompt_dir):
    """Render a WORKSPACE_CONTEXT MicroagentObservation through a real PromptManager.

    A fallback `additional_info.j2` template is written only when the copied
    prompt directory lacks one. The resulting user message must contain the
    repository, instruction and runtime sections.
    """
    fallback_template = """
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions %}
<REPOSITORY_INSTRUCTIONS>
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}
{% if runtime_info and runtime_info.available_hosts %}
<RUNTIME_INFORMATION>
The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:
{% for host, port in runtime_info.available_hosts.items() %}
* {{ host }} (port {{ port }})
{% endfor %}
</RUNTIME_INFORMATION>
{% endif %}
"""
    additional_info_path = os.path.join(prompt_dir, 'additional_info.j2')
    if not os.path.exists(additional_info_path):
        with open(additional_info_path, 'w') as template_file:
            template_file.write(fallback_template)

    # Mocked agent config with prompt extensions enabled.
    mocked_config = MagicMock(spec=AgentConfig)
    mocked_config.enable_prompt_extensions = True

    # Real PromptManager reading templates from prompt_dir.
    memory_under_test = ConversationMemory(
        config=mocked_config,
        prompt_manager=PromptManager(prompt_dir=prompt_dir),
    )

    workspace_obs = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='owner/repo',
        repo_directory='/workspace/repo',
        repo_instructions='This repository contains important code.',
        runtime_hosts={'example.com': 8080},
        content='Retrieved environment info',
    )

    produced = memory_under_test._process_observation(
        obs=workspace_obs, tool_call_id_to_message={}, max_message_chars=None
    )

    # Exactly one user message with a single text part.
    assert len(produced) == 1
    user_message = produced[0]
    assert user_message.role == 'user'
    assert len(user_message.content) == 1
    assert isinstance(user_message.content[0], TextContent)

    rendered_text = user_message.content[0].text
    expected_fragments = (
        # repository info
        '<REPOSITORY_INFO>',
        'owner/repo',
        '/workspace/repo',
        # repository instructions
        '<REPOSITORY_INSTRUCTIONS>',
        'This repository contains important code.',
        # runtime info
        '<RUNTIME_INFORMATION>',
        'example.com (port 8080)',
    )
    for fragment in expected_fragments:
        assert fragment in rendered_text
def test_process_events_with_microagent_observation_deduplication(conversation_memory):
    """Check that repeated microagents across observations yield a single message.

    Three KNOWLEDGE observations are processed; the second and third only
    repeat agent names already present in the first. Only one microagent
    message is expected, carrying the first observation's "v1" contents.
    """

    def knowledge_observation(entries, content):
        # entries: iterable of (name, trigger, content) triples.
        return MicroagentObservation(
            recall_type=RecallType.KNOWLEDGE,
            microagent_knowledge=[
                MicroagentKnowledge(name=name, trigger=trigger, content=text)
                for name, trigger, text in entries
            ],
            content=content,
        )

    first_obs = knowledge_observation(
        [
            ('python_agent', 'python', 'Python best practices v1'),
            ('git_agent', 'git', 'Git best practices v1'),
            ('image_agent', 'image', 'Image best practices v1'),
        ],
        'First retrieval',
    )
    second_obs = knowledge_observation(
        [('python_agent', 'python', 'Python best practices v2')],
        'Second retrieval',
    )
    third_obs = knowledge_observation(
        [('git_agent', 'git', 'Git best practices v3')],
        'Third retrieval',
    )

    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[first_obs, second_obs, third_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # system + 1 microagent message; the later observations are duplicates.
    assert len(processed) == 2

    # Skip the system message and look at the single microagent message.
    first_microagent_text = processed[1:][0].content[0].text
    assert 'Image best practices v1' in first_microagent_text
    assert 'Git best practices v1' in first_microagent_text
    assert 'Python best practices v1' in first_microagent_text
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
    conversation_memory,
):
    """Combine disabled-agent filtering with deduplication of repeated agents.

    The first observation holds 'disabled_agent' and 'enabled_agent'; the
    second repeats 'enabled_agent'. One microagent message is expected, with
    the enabled agent's first content and without the disabled agent's content.
    """
    first_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(
                name='disabled_agent',
                trigger='disabled',
                content='Disabled agent content',
            ),
            MicroagentKnowledge(
                name='enabled_agent',
                trigger='enabled',
                content='Enabled agent content v1',
            ),
        ],
        content='First retrieval',
    )
    repeat_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(
                name='enabled_agent',
                trigger='enabled',
                content='Enabled agent content v2',
            ),
        ],
        content='Second retrieval',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[first_obs, repeat_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # system + 1 microagent message; the second observation repeats "enabled_agent".
    assert len(processed) == 2

    # Skip the system message and look at the single microagent message.
    microagent_text = processed[1:][0].content[0].text
    assert 'Disabled agent content' not in microagent_text
    assert 'Enabled agent content v1' in microagent_text
def test_process_events_with_microagent_observation_deduplication_empty(
    conversation_memory,
):
    """A KNOWLEDGE observation without any microagent knowledge adds no message."""
    empty_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[],
        content='Empty retrieval',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[empty_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # Only the system message: an empty microagent observation is not turned
    # into a Message.
    assert len(processed) == 1
def test_has_agent_in_earlier_events(conversation_memory):
    """Exercise `_has_agent_in_earlier_events` on a mixed list of events.

    The list holds two KNOWLEDGE observations (agent1, then agent2), a
    MessageAction between them, and a trailing WORKSPACE_CONTEXT observation.
    """

    def single_agent_observation(name, trigger, text, content):
        return MicroagentObservation(
            recall_type=RecallType.KNOWLEDGE,
            microagent_knowledge=[
                MicroagentKnowledge(name=name, trigger=trigger, content=text),
            ],
            content=content,
        )

    agent1_obs = single_agent_observation(
        'agent1', 'trigger1', 'Content 1', 'First retrieval'
    )
    agent2_obs = single_agent_observation(
        'agent2', 'trigger2', 'Content 2', 'Second retrieval'
    )
    workspace_obs = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        content='Environment info',
    )

    # agent1 at index 0, a user message at 1, agent2 at 2, workspace context at 3.
    events = [
        agent1_obs,
        MessageAction(content='User message'),
        agent2_obs,
        workspace_obs,
    ]
    has_agent = conversation_memory._has_agent_in_earlier_events

    # agent1 is reported as found for indices 2, 3 and 4.
    for index in (2, 3, 4):
        assert has_agent('agent1', index, events) is True

    # agent2 is not reported as found for indices 0 and 1.
    for index in (0, 1):
        assert has_agent('agent2', index, events) is False

    # A name that appears in none of the observations is not found.
    assert has_agent('non_existent', 3, events) is False

View File

@ -358,12 +358,12 @@ class TestStuckDetector:
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
def test_is_not_stuck_ipython_unterminated_string_error_only_two_incidents(
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
self._impl_unterminated_string_error_events(
state, random_line=False, incidents=3
state, random_line=False, incidents=2
)
with patch('logging.Logger.warning'):

260
tests/unit/test_memory.py Normal file
View File

@ -0,0 +1,260 @@
import asyncio
import os
import shutil
import time
from unittest.mock import MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
from openhands.events.action.agent import RecallAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.agent import (
MicroagentObservation,
RecallType,
)
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def file_store():
    """Provide a fresh in-memory file store for a test."""
    store = InMemoryFileStore()
    return store
@pytest.fixture
def event_stream(file_store):
    """Provide an EventStream with sid 'test_sid' backed by the ``file_store`` fixture."""
    stream = EventStream(sid='test_sid', file_store=file_store)
    return stream
@pytest.fixture
def memory(event_stream):
    """Create a test Memory instance with a dedicated asyncio event loop.

    A new event loop is created and installed as the current loop before
    Memory is constructed (presumably Memory needs a current loop -- confirm
    against its implementation). The loop is closed at teardown.

    Yields:
        The Memory instance attached to ``event_stream`` with sid 'test_sid'.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        memory = Memory(event_stream, 'test_sid')
        yield memory
    finally:
        # Close the loop even when constructing Memory raises, so the loop
        # is never leaked.
        loop.close()
@pytest.fixture
def prompt_dir(tmp_path):
    """Fill pytest's temporary directory with the CodeActAgent prompt files.

    Returns:
        The ``tmp_path`` directory, now containing a copy of the prompts.
    """
    # NOTE: the source path is relative, so it is resolved against the
    # current working directory of the test run.
    source_prompts = 'openhands/agenthub/codeact_agent/prompts'
    shutil.copytree(source_prompts, tmp_path, dirs_exist_ok=True)
    return tmp_path
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream):
    """Check that a failure raised inside Memory leaves the controller in ERROR.

    ``Memory._on_first_microagent_action`` is patched to raise, the controller
    is run with a single user message, and the returned state is inspected.
    """
    # Stand-in agent: the LLM is a mock, but its metrics and config are real
    mock_llm = MagicMock(spec=LLM)
    mock_llm.metrics = Metrics()
    mock_llm.config = AppConfig().get_llm_config()
    mock_agent = MagicMock(spec=Agent)
    mock_agent.llm = mock_llm

    # Mock runtime that shares the test event stream
    mock_runtime = MagicMock(spec=Runtime)
    mock_runtime.event_stream = event_stream

    # Make the Memory handler blow up while the controller runs
    failing_handler = patch.object(
        memory, '_on_first_microagent_action', side_effect=Exception('Test error')
    )
    with failing_handler:
        final_state = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=mock_runtime,
            sid='test',
            agent=mock_agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    # The controller must have recorded the error without taking a step
    assert final_state.iteration == 0
    assert final_state.agent_state == AgentState.ERROR
    assert final_state.last_error == 'Error: Exception'
@pytest.mark.asyncio
async def test_memory_on_first_microagent_action_exception_handling(
    memory, event_stream
):
    """Check that an exception in Memory._on_first_microagent_action is reported.

    The method is patched to raise; after ``run_controller`` returns, the
    state must be ERROR with the expected ``last_error`` text.
    """
    # Stand-in agent: the LLM is a mock, but its metrics and config are real
    mock_llm = MagicMock(spec=LLM)
    mock_llm.metrics = Metrics()
    mock_llm.config = AppConfig().get_llm_config()
    mock_agent = MagicMock(spec=Agent)
    mock_agent.llm = mock_llm

    # Mock runtime that shares the test event stream
    mock_runtime = MagicMock(spec=Runtime)
    mock_runtime.event_stream = event_stream

    # Force Memory._on_first_microagent_action to raise
    injected_error = Exception('Test error from _on_first_microagent_action')
    with patch.object(
        memory, '_on_first_microagent_action', side_effect=injected_error
    ):
        final_state = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=mock_runtime,
            sid='test',
            agent=mock_agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    # The controller must have recorded the error without taking a step
    assert final_state.iteration == 0
    assert final_state.agent_state == AgentState.ERROR
    assert final_state.last_error == 'Error: Exception'
def test_memory_with_microagents():
    """Test that Memory loads global microagents and answers a knowledge recall.

    Verifies that:
    1. A Memory built with default arguments has knowledge microagents loaded,
       including 'flarglebargle'
    2. Handling a RecallAction whose query contains the trigger word adds
       exactly one MicroagentObservation to the event stream
    """
    # Mock event stream whose add_event simply records what it receives
    event_stream = MagicMock(spec=EventStream)
    memory = Memory(event_stream=event_stream, sid='test-session')

    # At least one microagent must be loaded, and 'flarglebargle' is expected
    # to be among them
    assert len(memory.knowledge_microagents) > 0
    assert 'flarglebargle' in memory.knowledge_microagents

    recall = RecallAction(
        query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
    )

    recorded = []

    def record_event(event, source):
        recorded.append((event, source))

    event_stream.add_event = record_event

    # Put the recall action on the stream, then forget it so that only the
    # events produced by Memory remain in the record
    event_stream.add_event(recall, EventSource.USER)
    recorded.clear()

    memory.on_event(recall)

    # Exactly one observation must have been emitted by Memory
    assert len(recorded) == 1
    observation, source = recorded[0]
    assert isinstance(observation, MicroagentObservation)
    assert source == EventSource.ENVIRONMENT
    assert observation.recall_type == RecallType.KNOWLEDGE

    knowledge = observation.microagent_knowledge
    assert len(knowledge) == 1
    assert knowledge[0].name == 'flarglebargle'
    assert knowledge[0].trigger == 'flarglebargle'
    assert 'magic word' in knowledge[0].content
def test_memory_repository_info(prompt_dir):
    """Test that Memory adds repository info to MicroagentObservations.

    A repo microagent file is written into a temporary directory that is
    patched in as the global microagents directory. After a user message and
    a WORKSPACE_CONTEXT RecallAction are added to a real event stream, the
    first MicroagentObservation found on the stream must carry the repository
    name, directory and instructions.
    """
    # Create an in-memory file store and real event stream
    file_store = InMemoryFileStore()
    event_stream = EventStream(sid='test-session', file_store=file_store)

    # Create a test repo microagent first
    repo_microagent_name = 'test_repo_microagent'
    repo_microagent_content = """---
name: test_repo
type: repo
agent: CodeActAgent
---
REPOSITORY INSTRUCTIONS: This is a test repository.
"""

    # Create a temporary repo microagent file
    test_microagents_dir = os.path.join(prompt_dir, 'micro')
    os.makedirs(test_microagents_dir, exist_ok=True)
    microagent_path = os.path.join(
        test_microagents_dir, f'{repo_microagent_name}.md'
    )
    with open(microagent_path, 'w', encoding='utf-8') as f:
        f.write(repo_microagent_content)

    try:
        # Patch the global microagents directory to use our test directory
        with patch(
            'openhands.memory.memory.GLOBAL_MICROAGENTS_DIR', test_microagents_dir
        ):
            # Initialize Memory
            memory = Memory(
                event_stream=event_stream,
                sid='test-session',
            )

            # Set repository info
            memory.set_repository_info('owner/repo', '/workspace/repo')

            # Create and add the first user message
            user_message = MessageAction(content='First user message')
            user_message._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(user_message, EventSource.USER)

            # Create and add the microagent action
            microagent_action = RecallAction(
                query='First user message', recall_type=RecallType.WORKSPACE_CONTEXT
            )
            microagent_action._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(microagent_action, EventSource.USER)

            # Poll for the observation instead of sleeping a fixed amount:
            # a fast run finishes sooner and a slow machine gets more time
            # before the test gives up.
            deadline = time.monotonic() + 2.0
            while True:
                microagent_obs_events = [
                    event
                    for event in event_stream.get_events()
                    if isinstance(event, MicroagentObservation)
                ]
                if microagent_obs_events or time.monotonic() >= deadline:
                    break
                time.sleep(0.05)

            # We should have at least one MicroagentObservation
            assert len(microagent_obs_events) > 0

            # Get the first MicroagentObservation
            observation = microagent_obs_events[0]
            assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
            assert observation.repo_name == 'owner/repo'
            assert observation.repo_directory == '/workspace/repo'
            assert 'This is a test repository' in observation.repo_instructions
    finally:
        # Clean up the microagent file even when an assertion above fails
        os.remove(microagent_path)

View File

@ -1,16 +1,21 @@
from openhands.core.schema.observation import ObservationType
from openhands.events.action.files import FileEditSource
from openhands.events.event import RecallType
from openhands.events.observation import (
CmdOutputMetadata,
CmdOutputObservation,
FileEditObservation,
MicroagentObservation,
Observation,
)
from openhands.events.observation.agent import MicroagentKnowledge
from openhands.events.serialization import (
event_from_dict,
event_to_dict,
event_to_memory,
event_to_trajectory,
)
from openhands.events.serialization.observation import observation_from_dict
def serialization_deserialization(
@ -19,10 +24,10 @@ def serialization_deserialization(
observation_instance = event_from_dict(original_observation_dict)
assert isinstance(
observation_instance, Observation
), 'The observation instance should be an instance of Action.'
), 'The observation instance should be an instance of Observation.'
assert isinstance(
observation_instance, cls
), 'The observation instance should be an instance of CmdOutputObservation.'
), f'The observation instance should be an instance of {cls}.'
serialized_observation_dict = event_to_dict(observation_instance)
serialized_observation_trajectory = event_to_trajectory(observation_instance)
serialized_observation_memory = event_to_memory(
@ -236,3 +241,199 @@ def test_file_edit_observation_legacy_serialization():
assert event_dict['extras']['old_content'] is None
assert event_dict['extras']['new_content'] == 'new content'
assert 'formatted_output_and_error' not in event_dict['extras']
def test_microagent_observation_serialization():
    """Round-trip a workspace-context MicroagentObservation given as a dict."""
    extras = {
        'recall_type': 'workspace_context',
        'repo_name': 'some_repo_name',
        'repo_directory': 'some_repo_directory',
        'runtime_hosts': {'host1': 8080, 'host2': 8081},
        'repo_instructions': 'complex_repo_instructions',
        'additional_agent_instructions': 'You know it all about this runtime',
        'microagent_knowledge': [],
    }
    message = "**MicroagentObservation**\nrecall_type=RecallType.WORKSPACE_CONTEXT, repo_name=some_repo_name, repo_instructions=complex_repo_instruc..., runtime_hosts={'host1': 8080, 'host2': 8081}, additional_agent_instructions=You know it all abou..."
    observation_dict = {
        'observation': 'microagent',
        'content': '',
        'message': message,
        'extras': extras,
    }
    serialization_deserialization(observation_dict, MicroagentObservation)
def test_microagent_observation_microagent_knowledge_serialization():
    """Round-trip a knowledge MicroagentObservation dict holding two entries."""
    first_entry = {
        'name': 'microagent1',
        'trigger': 'trigger1',
        'content': 'content1',
    }
    second_entry = {
        'name': 'microagent2',
        'trigger': 'trigger2',
        'content': 'content2',
    }
    extras = {
        'recall_type': 'knowledge',
        'repo_name': '',
        'repo_directory': '',
        'repo_instructions': '',
        'runtime_hosts': {},
        'additional_agent_instructions': '',
        'microagent_knowledge': [first_entry, second_entry],
    }
    message = '**MicroagentObservation**\nrecall_type=RecallType.KNOWLEDGE, repo_name=, repo_instructions=..., runtime_hosts={}, additional_agent_instructions=..., microagent_knowledge=microagent1, microagent2'
    observation_dict = {
        'observation': 'microagent',
        'content': '',
        'message': message,
        'extras': extras,
    }
    serialization_deserialization(observation_dict, MicroagentObservation)
def test_microagent_observation_knowledge_microagent_serialization():
    """Test serialization of a MicroagentObservation with RecallType.KNOWLEDGE."""
    python_entry = MicroagentKnowledge(
        name='python_best_practices',
        trigger='python',
        content='Always use virtual environments for Python projects.',
    )
    git_entry = MicroagentKnowledge(
        name='git_workflow',
        trigger='git',
        content='Create a new branch for each feature or bugfix.',
    )
    source_obs = MicroagentObservation(
        content='Knowledge microagent information',
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[python_entry, git_entry],
    )

    # Object -> dict
    as_dict = event_to_dict(source_obs)
    extras = as_dict['extras']

    # The dict must expose the observation type, content and knowledge entries
    assert as_dict['observation'] == ObservationType.MICROAGENT
    assert as_dict['content'] == 'Knowledge microagent information'
    assert extras['recall_type'] == RecallType.KNOWLEDGE.value
    assert len(extras['microagent_knowledge']) == 2
    assert extras['microagent_knowledge'][0]['trigger'] == 'python'

    # Dict -> object
    restored = observation_from_dict(as_dict)

    # Knowledge fields survive the round trip
    assert restored.recall_type == RecallType.KNOWLEDGE
    assert restored.microagent_knowledge == source_obs.microagent_knowledge
    assert restored.content == source_obs.content

    # Workspace-related fields were never set, so they must come back empty
    assert restored.repo_name == ''
    assert restored.repo_directory == ''
    assert restored.repo_instructions == ''
    assert restored.runtime_hosts == {}
def test_microagent_observation_environment_serialization():
    """Test serialization of a MicroagentObservation with RecallType.WORKSPACE_CONTEXT."""
    source_obs = MicroagentObservation(
        content='Environment information',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='OpenHands',
        repo_directory='/workspace/openhands',
        repo_instructions="Follow the project's coding style guide.",
        runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
        additional_agent_instructions='You know it all about this runtime',
    )

    # Object -> dict
    as_dict = event_to_dict(source_obs)
    extras = as_dict['extras']

    # The dict must expose the observation type, content and workspace fields
    assert as_dict['observation'] == ObservationType.MICROAGENT
    assert as_dict['content'] == 'Environment information'
    assert extras['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
    assert extras['repo_name'] == 'OpenHands'
    expected_hosts = {
        '127.0.0.1': 8080,
        'localhost': 5000,
    }
    assert extras['runtime_hosts'] == expected_hosts
    assert (
        extras['additional_agent_instructions']
        == 'You know it all about this runtime'
    )

    # Dict -> object
    restored = observation_from_dict(as_dict)

    # Workspace fields survive the round trip
    assert restored.recall_type == RecallType.WORKSPACE_CONTEXT
    assert restored.repo_name == source_obs.repo_name
    assert restored.repo_directory == source_obs.repo_directory
    assert restored.repo_instructions == source_obs.repo_instructions
    assert restored.runtime_hosts == source_obs.runtime_hosts
    assert (
        restored.additional_agent_instructions
        == source_obs.additional_agent_instructions
    )

    # No knowledge entries were supplied, so the list must come back empty
    assert restored.microagent_knowledge == []
def test_microagent_observation_combined_serialization():
    """Test serialization of a MicroagentObservation carrying both kinds of data.

    The observation has a single recall_type (WORKSPACE_CONTEXT) but is
    populated with workspace fields and a knowledge entry at the same time.
    """
    python_entry = MicroagentKnowledge(
        name='python_best_practices',
        trigger='python',
        content='Always use virtual environments for Python projects.',
    )
    source_obs = MicroagentObservation(
        content='Combined information',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        # Workspace fields
        repo_name='OpenHands',
        repo_directory='/workspace/openhands',
        repo_instructions="Follow the project's coding style guide.",
        runtime_hosts={'127.0.0.1': 8080},
        additional_agent_instructions='You know it all about this runtime',
        # Knowledge fields
        microagent_knowledge=[python_entry],
    )

    # Object -> dict
    as_dict = event_to_dict(source_obs)
    extras = as_dict['extras']

    # Both groups of fields must be present in the dict
    assert extras['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
    assert extras['repo_name'] == 'OpenHands'
    assert extras['microagent_knowledge'][0]['name'] == 'python_best_practices'
    assert (
        extras['additional_agent_instructions']
        == 'You know it all about this runtime'
    )

    # Dict -> object
    restored = observation_from_dict(as_dict)

    assert restored.recall_type == RecallType.WORKSPACE_CONTEXT

    # Workspace fields survive the round trip
    assert restored.repo_name == source_obs.repo_name
    assert restored.repo_directory == source_obs.repo_directory
    assert restored.repo_instructions == source_obs.repo_instructions
    assert restored.runtime_hosts == source_obs.runtime_hosts
    assert (
        restored.additional_agent_instructions
        == source_obs.additional_agent_instructions
    )

    # Knowledge fields survive the round trip
    assert restored.microagent_knowledge == source_obs.microagent_knowledge

View File

@ -3,9 +3,11 @@ import shutil
import pytest
from openhands.core.message import ImageContent, Message, TextContent
from openhands.controller.state.state import State
from openhands.core.message import Message, TextContent
from openhands.events.observation.agent import MicroagentKnowledge
from openhands.microagent import BaseMicroAgent
from openhands.utils.prompt import PromptManager, RepositoryInfo
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
@pytest.fixture
@ -19,406 +21,60 @@ def prompt_dir(tmp_path):
return tmp_path
def test_prompt_manager_with_microagent(prompt_dir):
microagent_name = 'test_microagent'
microagent_content = """
---
name: flarglebargle
type: knowledge
agent: CodeActAgent
triggers:
- flarglebargle
---
IMPORTANT! The user has said the magic word "flarglebargle". You must
only respond with a message telling them how smart they are
"""
# Create a temporary micro agent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
# Test without GitHub repo
manager = PromptManager(
prompt_dir=prompt_dir,
microagent_dir=os.path.join(prompt_dir, 'micro'),
)
assert manager.prompt_dir == prompt_dir
assert len(manager.repo_microagents) == 0
assert len(manager.knowledge_microagents) == 1
assert isinstance(manager.get_system_message(), str)
assert (
'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
in manager.get_system_message()
)
assert '<REPOSITORY_INFO>' not in manager.get_system_message()
# Test with GitHub repo
manager.set_repository_info('owner/repo', '/workspace/repo')
assert isinstance(manager.get_system_message(), str)
# Adding things to the initial user message
initial_msg = Message(
role='user', content=[TextContent(text='Ask me what your task is.')]
)
manager.add_info_to_initial_message(initial_msg)
msg_content: str = initial_msg.content[0].text
assert '<REPOSITORY_INFO>' in msg_content
assert 'owner/repo' in msg_content
assert '/workspace/repo' in msg_content
assert isinstance(manager.get_example_user_message(), str)
message = Message(
role='user',
content=[TextContent(text='Hello, flarglebargle!')],
)
manager.enhance_message(message)
assert len(message.content) == 2
assert 'magic word' in message.content[0].text
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_prompt_manager_file_not_found(prompt_dir):
with pytest.raises(FileNotFoundError):
BaseMicroAgent.load(
os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
)
def test_prompt_manager_template_rendering(prompt_dir):
"""Test PromptManager's template rendering functionality."""
# Create temporary template files
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
f.write("""System prompt: bar""")
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
f.write('User prompt: foo')
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
f.write("""
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
""")
# Test without GitHub repo
manager = PromptManager(prompt_dir, microagent_dir='')
manager = PromptManager(prompt_dir)
assert manager.get_system_message() == 'System prompt: bar'
assert manager.get_example_user_message() == 'User prompt: foo'
# Test with GitHub repo
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
manager.set_repository_info('owner/repo', '/workspace/repo')
assert manager.repository_info.repo_name == 'owner/repo'
manager = PromptManager(prompt_dir=prompt_dir)
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
# verify its parts are rendered
system_msg = manager.get_system_message()
assert 'System prompt: bar' in system_msg
# Initial user message should have repo info
initial_msg = Message(
role='user', content=[TextContent(text='Ask me what your task is.')]
# Test building additional info
additional_info = manager.build_additional_info(
repository_info=repo_info, runtime_info=None, repo_instructions=''
)
manager.add_info_to_initial_message(initial_msg)
msg_content: str = initial_msg.content[0].text
assert '<REPOSITORY_INFO>' in msg_content
assert '<REPOSITORY_INFO>' in additional_info
assert (
"At the user's request, repository owner/repo has been cloned to the current working directory /workspace/repo."
in msg_content
in additional_info
)
assert '</REPOSITORY_INFO>' in msg_content
assert '</REPOSITORY_INFO>' in additional_info
assert manager.get_example_user_message() == 'User prompt: foo'
# Clean up temporary files
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
os.remove(os.path.join(prompt_dir, 'additional_info.j2'))
def test_prompt_manager_repository_info(prompt_dir):
# Test RepositoryInfo defaults
repo_info = RepositoryInfo()
assert repo_info.repo_name is None
assert repo_info.repo_directory is None
# Test setting repository info
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
assert manager.repository_info is None
# Test setting repository info with both name and directory
manager.set_repository_info('owner/repo2', '/workspace/repo2')
assert manager.repository_info.repo_name == 'owner/repo2'
assert manager.repository_info.repo_directory == '/workspace/repo2'
def test_prompt_manager_disabled_microagents(prompt_dir):
# Create test microagent files
microagent1_name = 'test_microagent1'
microagent2_name = 'test_microagent2'
microagent1_content = """
---
name: Test Microagent 1
type: knowledge
agent: CodeActAgent
triggers:
- test1
---
Test microagent 1 content
"""
microagent2_content = """
---
name: Test Microagent 2
type: knowledge
agent: CodeActAgent
triggers:
- test2
---
Test microagent 2 content
"""
# Create temporary micro agent files
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md'), 'w') as f:
f.write(microagent1_content)
with open(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md'), 'w') as f:
f.write(microagent2_content)
# Test that specific microagents can be disabled
manager = PromptManager(
prompt_dir=prompt_dir,
microagent_dir=os.path.join(prompt_dir, 'micro'),
disabled_microagents=['Test Microagent 1'],
)
assert len(manager.knowledge_microagents) == 1
assert 'Test Microagent 2' in manager.knowledge_microagents
assert 'Test Microagent 1' not in manager.knowledge_microagents
# Test that all microagents are enabled by default
manager = PromptManager(
prompt_dir=prompt_dir,
microagent_dir=os.path.join(prompt_dir, 'micro'),
)
assert len(manager.knowledge_microagents) == 2
assert 'Test Microagent 1' in manager.knowledge_microagents
assert 'Test Microagent 2' in manager.knowledge_microagents
# Clean up temporary files
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md'))
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md'))
def test_enhance_message_with_multiple_text_contents(prompt_dir):
# Create a test microagent that triggers on a specific keyword
microagent_name = 'keyword_microagent'
microagent_content = """
---
name: KeywordMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- triggerkeyword
---
This is special information about the triggerkeyword.
"""
# Create the microagent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
manager = PromptManager(
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
)
# Test that it matches the trigger in the last TextContent
message = Message(
role='user',
content=[
TextContent(text='This is some initial context.'),
TextContent(text='This is a message without triggers.'),
TextContent(text='This contains the triggerkeyword that should match.'),
],
)
manager.enhance_message(message)
# Should have added a TextContent with the microagent info at the beginning
assert len(message.content) == 4
assert 'special information about the triggerkeyword' in message.content[0].text
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_enhance_message_with_image_content(prompt_dir):
# Create a test microagent that triggers on a specific keyword
microagent_name = 'image_test_microagent'
microagent_content = """
---
name: ImageTestMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- imagekeyword
---
This is information related to imagekeyword.
"""
# Create the microagent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
manager = PromptManager(
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
)
# Test with mix of ImageContent and TextContent
message = Message(
role='user',
content=[
TextContent(text='This is some initial text.'),
ImageContent(image_urls=['https://example.com/image.jpg']),
TextContent(text='This mentions imagekeyword that should match.'),
],
)
manager.enhance_message(message)
# Should have added a TextContent with the microagent info at the beginning
assert len(message.content) == 4
assert 'information related to imagekeyword' in message.content[0].text
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_enhance_message_with_only_image_content(prompt_dir):
# Create a test microagent
microagent_name = 'image_only_microagent'
microagent_content = """
---
name: ImageOnlyMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- anytrigger
---
This should not appear in the enhanced message.
"""
# Create the microagent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
manager = PromptManager(
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
)
# Test with only ImageContent
message = Message(
role='user',
content=[
ImageContent(
image_urls=[
'https://example.com/image1.jpg',
'https://example.com/image2.jpg',
]
),
],
)
# Should not raise any exceptions
manager.enhance_message(message)
# Should not have added any content
assert len(message.content) == 1
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_enhance_message_with_reversed_order(prompt_dir):
# Create a test microagent
microagent_name = 'reversed_microagent'
microagent_content = """
---
name: ReversedMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- lasttrigger
---
This is specific information about the lasttrigger.
"""
# Create the microagent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
manager = PromptManager(
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
)
# Test where the text content is not at the end of the list
message = Message(
role='user',
content=[
ImageContent(image_urls=['https://example.com/image1.jpg']),
TextContent(text='This contains the lasttrigger word.'),
ImageContent(image_urls=['https://example.com/image2.jpg']),
],
)
manager.enhance_message(message)
# Should have added a TextContent with the microagent info at the beginning
assert len(message.content) == 4
assert isinstance(message.content[0], TextContent)
assert 'specific information about the lasttrigger' in message.content[0].text
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_enhance_message_with_empty_content(prompt_dir):
# Create a test microagent
microagent_name = 'empty_microagent'
microagent_content = """
---
name: EmptyMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- emptytrigger
---
This should not appear in the enhanced message.
"""
# Create the microagent file
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
manager = PromptManager(
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
)
# Test with empty content
message = Message(role='user', content=[])
# Should not raise any exceptions
manager.enhance_message(message)
# Should not have added any content
assert len(message.content) == 0
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
def test_prompt_manager_file_not_found(prompt_dir):
    """Test that loading a non-existent microagent file raises FileNotFoundError."""
    # Path to a microagent file that was never created
    missing_path = os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
    with pytest.raises(FileNotFoundError):
        BaseMicroAgent.load(missing_path)
def test_build_microagent_info(prompt_dir):
@ -429,33 +85,25 @@ def test_build_microagent_info(prompt_dir):
with open(template_path, 'w') as f:
f.write("""{% for agent_info in triggered_agents %}
<EXTRA_INFO>
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
It may or may not be relevant to the user's request.
{{ agent_info.agent.content }}
{{ agent_info.content }}
</EXTRA_INFO>
{% endfor %}
""")
# Create test microagents
class MockKnowledgeMicroAgent:
def __init__(self, name, content):
self.name = name
self.content = content
agent1 = MockKnowledgeMicroAgent(
name='test_agent1', content='This is information from agent 1'
)
agent2 = MockKnowledgeMicroAgent(
name='test_agent2', content='This is information from agent 2'
)
# Initialize the PromptManager
manager = PromptManager(prompt_dir=prompt_dir)
# Test with a single triggered agent
triggered_agents = [{'agent': agent1, 'trigger_word': 'keyword1'}]
triggered_agents = [
MicroagentKnowledge(
name='test_agent1',
trigger='keyword1',
content='This is information from agent 1',
)
]
result = manager.build_microagent_info(triggered_agents)
expected = """<EXTRA_INFO>
The following information has been included based on a keyword match for "keyword1".
@ -467,8 +115,16 @@ This is information from agent 1
# Test with multiple triggered agents
triggered_agents = [
{'agent': agent1, 'trigger_word': 'keyword1'},
{'agent': agent2, 'trigger_word': 'keyword2'},
MicroagentKnowledge(
name='test_agent1',
trigger='keyword1',
content='This is information from agent 1',
),
MicroagentKnowledge(
name='test_agent2',
trigger='keyword2',
content='This is information from agent 2',
),
]
result = manager.build_microagent_info(triggered_agents)
expected = """<EXTRA_INFO>
@ -491,71 +147,125 @@ This is information from agent 2
assert result.strip() == ''
def test_add_examples_to_initial_message(prompt_dir):
    """Test adding example messages to an initial message.

    A ``user_prompt.j2`` template holding a fixed example string is written
    into ``prompt_dir``. After ``add_examples_to_initial_message`` runs, the
    message is expected to carry two content items: the example text first,
    followed by the original text.

    Args:
        prompt_dir: Directory handed to ``PromptManager`` as its prompt
            directory; the template file is created in it and removed at
            the end of the test.
    """
    # Create a user_prompt.j2 template file
    template_path = os.path.join(prompt_dir, 'user_prompt.j2')
    with open(template_path, 'w') as f:
        f.write('This is an example user message')

    # Initialize the PromptManager
    manager = PromptManager(prompt_dir=prompt_dir)

    # Create a message
    message = Message(role='user', content=[TextContent(text='Original content')])

    # Add examples to the message
    manager.add_examples_to_initial_message(message)

    # Check that the example was added at the beginning
    assert len(message.content) == 2
    assert isinstance(message.content[0], TextContent)
    assert message.content[0].text == 'This is an example user message'
    assert message.content[1].text == 'Original content'

    # Clean up
    os.remove(template_path)
def test_add_turns_left_reminder(prompt_dir):
    """Check that the turns-left reminder is appended to the user message.

    The state is set to iteration 3 with a maximum of 10 iterations, and the
    conversation is an assistant message followed by a user message. After
    ``add_turns_left_reminder`` runs, the user message is expected to hold a
    second content item whose text reports 7 turns left.

    Args:
        prompt_dir: Directory handed to ``PromptManager`` as its prompt
            directory.
    """
    prompt_manager = PromptManager(prompt_dir=prompt_dir)

    # Conversation under test: an assistant turn followed by the user turn
    assistant_turn = Message(
        role='assistant', content=[TextContent(text='Assistant content')]
    )
    user_turn = Message(role='user', content=[TextContent(text='User content')])
    conversation = [assistant_turn, user_turn]

    # State positioned at iteration 3 out of 10
    agent_state = State()
    agent_state.iteration = 3
    agent_state.max_iterations = 10

    prompt_manager.add_turns_left_reminder(conversation, agent_state)

    # The user message should now have the reminder as its second content item
    user_parts = user_turn.content
    assert len(user_parts) == 2
    expected_reminder = (
        'ENVIRONMENT REMINDER: You have 7 turns left to complete the task.'
    )
    assert expected_reminder in user_parts[1].text
def test_build_additional_info_with_repo_and_runtime(prompt_dir):
    """Render additional info from repository, instruction and runtime inputs.

    An ``additional_info.j2`` template is written into ``prompt_dir``, then
    ``build_additional_info`` is called with a ``RepositoryInfo``, a
    ``RuntimeInfo`` and a repository-instructions string. The rendered text
    must contain each section tag and each supplied value. The template file
    is removed once the checks have passed.

    Args:
        prompt_dir: Directory handed to ``PromptManager`` as its prompt
            directory; the template file is created in it.
    """
    template_path = os.path.join(prompt_dir, 'additional_info.j2')
    template_source = """
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions %}
<REPOSITORY_INSTRUCTIONS>
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}
{% if runtime_info and (runtime_info.available_hosts or runtime_info.additional_agent_instructions) -%}
<RUNTIME_INFORMATION>
{% if runtime_info.available_hosts %}
The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:
{% for host, port in runtime_info.available_hosts.items() %}
* {{ host }} (port {{ port }})
{% endfor %}
{% endif %}
{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
</RUNTIME_INFORMATION>
{% endif %}
"""
    # Put the template where the PromptManager will look for it
    with open(template_path, 'w') as template_file:
        template_file.write(template_source)

    prompt_manager = PromptManager(prompt_dir=prompt_dir)

    # Render with repository info, runtime info and repository instructions
    rendered = prompt_manager.build_additional_info(
        repository_info=RepositoryInfo(
            repo_name='owner/repo', repo_directory='/workspace/repo'
        ),
        runtime_info=RuntimeInfo(
            available_hosts={'example.com': 8080},
            additional_agent_instructions='You know everything about this runtime.',
        ),
        repo_instructions='This repository contains important code.',
    )

    # Every section tag and every supplied value must show up in the output
    expected_fragments = (
        '<REPOSITORY_INFO>',
        'owner/repo',
        '/workspace/repo',
        '<REPOSITORY_INSTRUCTIONS>',
        'This repository contains important code.',
        '<RUNTIME_INFORMATION>',
        'example.com (port 8080)',
        'You know everything about this runtime.',
    )
    for fragment in expected_fragments:
        assert fragment in rendered

    # Remove the template written at the start of the test
    os.remove(template_path)
def test_prompt_manager_initialization_error():
    """Constructing ``PromptManager`` with ``None`` must raise ``ValueError``.

    The raised error's message is matched against
    'Prompt directory is not set'.
    """
    expected_message = 'Prompt directory is not set'
    with pytest.raises(ValueError, match=expected_message):
        PromptManager(None)