mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add RecallActions and observations for retrieval of prompt extensions (#6909)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
parent
e34a771e66
commit
cc45f5d9c3
@ -1,8 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import openhands
|
||||
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
@ -74,21 +72,14 @@ class CodeActAgent(Agent):
|
||||
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
||||
)
|
||||
logger.debug(
|
||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
|
||||
f'TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}'
|
||||
)
|
||||
self.prompt_manager = PromptManager(
|
||||
microagent_dir=os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'microagents',
|
||||
)
|
||||
if self.config.enable_prompt_extensions
|
||||
else None,
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
disabled_microagents=self.config.disabled_microagents,
|
||||
)
|
||||
|
||||
# Create a ConversationMemory instance
|
||||
self.conversation_memory = ConversationMemory(self.prompt_manager)
|
||||
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
|
||||
|
||||
self.condenser = Condenser.from_config(self.config.condenser)
|
||||
logger.debug(f'Using condenser: {type(self.condenser)}')
|
||||
@ -168,7 +159,7 @@ class CodeActAgent(Agent):
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use conversation_memory to process events instead of calling events_to_messages directly
|
||||
# Use ConversationMemory to process initial messages
|
||||
messages = self.conversation_memory.process_initial_messages(
|
||||
with_caching=self.llm.is_caching_prompt_active()
|
||||
)
|
||||
@ -180,12 +171,12 @@ class CodeActAgent(Agent):
|
||||
f'Processing {len(events)} events from a total of {len(state.history)} events'
|
||||
)
|
||||
|
||||
# Use ConversationMemory to process events
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_messages=messages,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
enable_som_visual_browsing=self.config.enable_som_visual_browsing,
|
||||
)
|
||||
|
||||
messages = self._enhance_messages(messages)
|
||||
@ -216,14 +207,7 @@ class CodeActAgent(Agent):
|
||||
# compose the first user message with examples
|
||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||
|
||||
# and/or repo/runtime info
|
||||
if self.config.enable_prompt_extensions:
|
||||
self.prompt_manager.add_info_to_initial_message(msg)
|
||||
|
||||
# enhance the user message with additional context based on keywords matched
|
||||
if msg.role == 'user':
|
||||
self.prompt_manager.enhance_message(msg)
|
||||
|
||||
elif msg.role == 'user':
|
||||
# Add double newline between consecutive user messages
|
||||
if prev_role == 'user' and len(msg.content) > 0:
|
||||
# Find the first TextContent in the message to add newlines
|
||||
|
||||
@ -20,6 +20,8 @@ When starting a web server, use the corresponding ports. You should also
|
||||
set any options to allow iframes and CORS requests, and allow the server to
|
||||
be accessed from any host (e.g. 0.0.0.0).
|
||||
{% endif %}
|
||||
{% if runtime_info.additional_agent_instructions %}
|
||||
{{ runtime_info.additional_agent_instructions }}
|
||||
{% endif %}
|
||||
</RUNTIME_INFORMATION>
|
||||
{% endif %}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
{{ agent_info.agent.content }}
|
||||
{{ agent_info.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
|
||||
@ -29,7 +29,12 @@ from openhands.core.exceptions import (
|
||||
from openhands.core.logger import LOG_ALL_EVENTS
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events import (
|
||||
EventSource,
|
||||
EventStream,
|
||||
EventStreamSubscriber,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
@ -42,6 +47,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@ -89,7 +95,7 @@ class AgentController:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
sid: str = 'default',
|
||||
sid: str | None = None,
|
||||
confirmation_mode: bool = False,
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
@ -116,7 +122,7 @@ class AgentController:
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
replay_events: A list of logs to replay.
|
||||
"""
|
||||
self.id = sid
|
||||
self.id = sid or event_stream.sid
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
self.is_delegate = is_delegate
|
||||
@ -287,8 +293,14 @@ class AgentController:
|
||||
return True
|
||||
return False
|
||||
if isinstance(event, Observation):
|
||||
if isinstance(event, NullObservation) or isinstance(
|
||||
event, AgentStateChangedObservation
|
||||
if (
|
||||
isinstance(event, NullObservation)
|
||||
and event.cause is not None
|
||||
and event.cause > 0
|
||||
):
|
||||
return True
|
||||
if isinstance(event, AgentStateChangedObservation) or isinstance(
|
||||
event, NullObservation
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@ -388,6 +400,7 @@ class AgentController:
|
||||
if observation.llm_metrics is not None:
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
# this happens for runnable actions and microagent actions
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
return
|
||||
@ -431,6 +444,25 @@ class AgentController:
|
||||
'debug',
|
||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||
)
|
||||
# try to retrieve microagents relevant to the user message
|
||||
# set pending_action while we search for information
|
||||
|
||||
# if this is the first user message for this agent, matters for the microagent info type
|
||||
first_user_message = self._first_user_message()
|
||||
is_first_user_message = (
|
||||
action.id == first_user_message.id if first_user_message else False
|
||||
)
|
||||
recall_type = (
|
||||
RecallType.WORKSPACE_CONTEXT
|
||||
if is_first_user_message
|
||||
else RecallType.KNOWLEDGE
|
||||
)
|
||||
|
||||
recall_action = RecallAction(query=action.content, recall_type=recall_type)
|
||||
self._pending_action = recall_action
|
||||
# this is source=USER because the user message is the trigger for the microagent retrieval
|
||||
self.event_stream.add_event(recall_action, EventSource.USER)
|
||||
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||
@ -438,6 +470,7 @@ class AgentController:
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Resets the agent controller"""
|
||||
# Runnable actions need an Observation
|
||||
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
||||
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
||||
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
||||
@ -459,6 +492,8 @@ class AgentController:
|
||||
obs._cause = self._pending_action.id # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
|
||||
|
||||
# reset the pending action, this will be called when the agent is STOPPED or ERROR
|
||||
self._pending_action = None
|
||||
self.agent.reset()
|
||||
@ -1146,3 +1181,26 @@ class AgentController:
|
||||
result = event.agent_state == AgentState.RUNNING
|
||||
return result
|
||||
return False
|
||||
|
||||
def _first_user_message(self) -> MessageAction | None:
|
||||
"""
|
||||
Get the first user message for this agent.
|
||||
|
||||
For regular agents, this is the first user message from the beginning (start_id=0).
|
||||
For delegate agents, this is the first user message after the delegate's start_id.
|
||||
|
||||
Returns:
|
||||
MessageAction | None: The first user message, or None if no user message found
|
||||
"""
|
||||
# Find the first user message from the appropriate starting point
|
||||
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
|
||||
|
||||
# Get and return the first user message
|
||||
return next(
|
||||
(
|
||||
e
|
||||
for e in user_messages
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@ -135,7 +135,7 @@ class StuckDetector:
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
# check if the last three actions are the same and result in errors
|
||||
|
||||
if len(last_actions) < 4 or len(last_observations) < 4:
|
||||
if len(last_actions) < 3 or len(last_observations) < 3:
|
||||
return False
|
||||
|
||||
# are the last three actions the "same"?
|
||||
|
||||
@ -17,6 +17,7 @@ from openhands.core.schema import AgentState
|
||||
from openhands.core.setup import (
|
||||
create_agent,
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
initialize_repository_for_runtime,
|
||||
)
|
||||
@ -170,13 +171,22 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
await runtime.connect()
|
||||
|
||||
# Initialize repository if needed
|
||||
repo_directory = None
|
||||
if config.sandbox.selected_repo:
|
||||
initialize_repository_for_runtime(
|
||||
repo_directory = initialize_repository_for_runtime(
|
||||
runtime,
|
||||
agent=agent,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
memory = create_memory(
|
||||
runtime=runtime,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
if initial_user_action:
|
||||
# If there's an initial user action, enqueue it and do not prompt again
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
@ -185,7 +195,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
|
||||
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -3,12 +3,14 @@ import asyncio
|
||||
from openhands.controller import AgentController
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
async def run_agent_until_done(
|
||||
controller: AgentController,
|
||||
runtime: Runtime,
|
||||
memory: Memory,
|
||||
end_states: list[AgentState],
|
||||
):
|
||||
"""
|
||||
@ -37,6 +39,7 @@ async def run_agent_until_done(
|
||||
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
memory.status_callback = status_callback
|
||||
|
||||
while controller.state.agent_state not in end_states:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@ -18,6 +18,7 @@ from openhands.core.schema import AgentState
|
||||
from openhands.core.setup import (
|
||||
create_agent,
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
generate_sid,
|
||||
initialize_repository_for_runtime,
|
||||
@ -29,6 +30,7 @@ from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
@ -51,6 +53,7 @@ async def run_controller(
|
||||
exit_on_message: bool = False,
|
||||
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
||||
headless_mode: bool = True,
|
||||
memory: Memory | None = None,
|
||||
) -> State | None:
|
||||
"""Main coroutine to run the agent controller with task input flexibility.
|
||||
|
||||
@ -93,6 +96,8 @@ async def run_controller(
|
||||
if agent is None:
|
||||
agent = create_agent(config)
|
||||
|
||||
# when the runtime is created, it will be connected and clone the selected repository
|
||||
repo_directory = None
|
||||
if runtime is None:
|
||||
runtime = create_runtime(
|
||||
config,
|
||||
@ -105,14 +110,23 @@ async def run_controller(
|
||||
|
||||
# Initialize repository if needed
|
||||
if config.sandbox.selected_repo:
|
||||
initialize_repository_for_runtime(
|
||||
repo_directory = initialize_repository_for_runtime(
|
||||
runtime,
|
||||
agent=agent,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
if memory is None:
|
||||
memory = create_memory(
|
||||
runtime=runtime,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
replay_events: list[Event] | None = None
|
||||
if config.replay_trajectory_path:
|
||||
logger.info('Trajectory replay is enabled')
|
||||
@ -172,7 +186,7 @@ async def run_controller(
|
||||
]
|
||||
|
||||
try:
|
||||
await run_agent_until_done(controller, runtime, end_states)
|
||||
await run_agent_until_done(controller, runtime, memory, end_states)
|
||||
except Exception as e:
|
||||
logger.error(f'Exception in main loop: {e}')
|
||||
|
||||
|
||||
@ -82,5 +82,8 @@ class ActionTypeSchema(BaseModel):
|
||||
SEND_PR: str = Field(default='send_pr')
|
||||
"""Send a PR to github."""
|
||||
|
||||
RECALL: str = Field(default='recall')
|
||||
"""Retrieves content from a user workspace, microagent, or other source."""
|
||||
|
||||
|
||||
ActionType = ActionTypeSchema()
|
||||
|
||||
@ -49,5 +49,8 @@ class ObservationTypeSchema(BaseModel):
|
||||
CONDENSE: str = Field(default='condense')
|
||||
"""Result of a condensation operation."""
|
||||
|
||||
MICROAGENT: str = Field(default='microagent')
|
||||
"""Result of a microagent retrieval operation."""
|
||||
|
||||
|
||||
ObservationType = ObservationTypeSchema()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Tuple, Type
|
||||
from typing import Callable, Tuple, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@ -16,6 +16,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
@ -83,7 +84,6 @@ def create_runtime(
|
||||
|
||||
def initialize_repository_for_runtime(
|
||||
runtime: Runtime,
|
||||
agent: Agent | None = None,
|
||||
selected_repository: str | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
) -> str | None:
|
||||
@ -91,7 +91,6 @@ def initialize_repository_for_runtime(
|
||||
|
||||
Args:
|
||||
runtime: The runtime to initialize the repository for.
|
||||
agent: (optional) The agent to load microagents for.
|
||||
selected_repository: (optional) The GitHub repository to use.
|
||||
github_token: (optional) The GitHub token to use.
|
||||
|
||||
@ -99,10 +98,10 @@ def initialize_repository_for_runtime(
|
||||
The repository directory path if a repository was cloned, None otherwise.
|
||||
"""
|
||||
# clone selected repository if provided
|
||||
repo_directory = None
|
||||
github_token = (
|
||||
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
|
||||
)
|
||||
repo_directory = None
|
||||
if selected_repository and github_token:
|
||||
logger.debug(f'Selected repository {selected_repository}.')
|
||||
repo_directory = runtime.clone_repo(
|
||||
@ -111,16 +110,47 @@ def initialize_repository_for_runtime(
|
||||
None,
|
||||
)
|
||||
|
||||
# load microagents from selected repository
|
||||
if agent and agent.prompt_manager and selected_repository and repo_directory:
|
||||
agent.prompt_manager.set_runtime_info(runtime)
|
||||
return repo_directory
|
||||
|
||||
|
||||
def create_memory(
|
||||
runtime: Runtime,
|
||||
event_stream: EventStream,
|
||||
sid: str,
|
||||
selected_repository: str | None = None,
|
||||
repo_directory: str | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
) -> Memory:
|
||||
"""Create a memory for the agent to use.
|
||||
|
||||
Args:
|
||||
runtime: The runtime to use.
|
||||
event_stream: The event stream it will subscribe to.
|
||||
sid: The session id.
|
||||
selected_repository: The repository to clone and start with, if any.
|
||||
repo_directory: The repository directory, if any.
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
"""
|
||||
memory = Memory(
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
|
||||
if runtime:
|
||||
# sets available hosts
|
||||
memory.set_runtime_info(runtime)
|
||||
|
||||
# loads microagents from repo/.openhands/microagents
|
||||
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
|
||||
selected_repository
|
||||
)
|
||||
agent.prompt_manager.load_microagents(microagents)
|
||||
agent.prompt_manager.set_repository_info(selected_repository, repo_directory)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
|
||||
return repo_directory
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(selected_repository, repo_directory)
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
def create_agent(config: AppConfig) -> Agent:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
@ -6,4 +6,5 @@ __all__ = [
|
||||
'EventSource',
|
||||
'EventStream',
|
||||
'EventStreamSubscriber',
|
||||
'RecallType',
|
||||
]
|
||||
|
||||
@ -6,6 +6,7 @@ from openhands.events.action.agent import (
|
||||
AgentSummarizeAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||
@ -35,4 +36,5 @@ __all__ = [
|
||||
'MessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
]
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -106,3 +107,22 @@ class AgentDelegateAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"I'm asking {self.agent} for help with this task."
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallAction(Action):
|
||||
"""This action is used for retrieving content, e.g., from the global directory or user workspace."""
|
||||
|
||||
recall_type: RecallType
|
||||
query: str = ''
|
||||
thought: str = ''
|
||||
action: str = ActionType.RECALL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Retrieving content for: {self.query[:50]}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**RecallAction**\n'
|
||||
ret += f'QUERY: {self.query[:50]}'
|
||||
return ret
|
||||
|
||||
@ -22,6 +22,16 @@ class FileReadSource(str, Enum):
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@ -40,4 +42,6 @@ __all__ = [
|
||||
'SuccessObservation',
|
||||
'UserRejectObservation',
|
||||
'AgentCondensationObservation',
|
||||
'MicroagentObservation',
|
||||
'RecallType',
|
||||
]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@ -40,3 +41,76 @@ class AgentThinkObservation(Observation):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroagentKnowledge:
|
||||
"""
|
||||
Represents knowledge from a triggered microagent.
|
||||
|
||||
Attributes:
|
||||
name: The name of the microagent that was triggered
|
||||
trigger: The word that triggered this microagent
|
||||
content: The actual content/knowledge from the microagent
|
||||
"""
|
||||
|
||||
name: str
|
||||
trigger: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroagentObservation(Observation):
|
||||
"""The retrieval of content from a microagent or more microagents."""
|
||||
|
||||
recall_type: RecallType
|
||||
observation: str = ObservationType.MICROAGENT
|
||||
|
||||
# environment
|
||||
repo_name: str = ''
|
||||
repo_directory: str = ''
|
||||
repo_instructions: str = ''
|
||||
runtime_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
|
||||
# knowledge
|
||||
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
|
||||
"""
|
||||
A list of MicroagentKnowledge objects, each containing information from a triggered microagent.
|
||||
|
||||
Example:
|
||||
[
|
||||
MicroagentKnowledge(
|
||||
name="python_best_practices",
|
||||
trigger="python",
|
||||
content="Always use virtual environments for Python projects."
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name="git_workflow",
|
||||
trigger="git",
|
||||
content="Create a new branch for each feature or bugfix."
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Build a string representation of all fields
|
||||
fields = [
|
||||
f'recall_type={self.recall_type}',
|
||||
f'repo_name={self.repo_name}',
|
||||
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||
f'runtime_hosts={self.runtime_hosts}',
|
||||
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||
]
|
||||
|
||||
# Only include microagent_knowledge if it's not empty
|
||||
if self.microagent_knowledge:
|
||||
fields.append(
|
||||
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}'
|
||||
)
|
||||
|
||||
return f'**MicroagentObservation**\n{", ".join(fields)}'
|
||||
|
||||
@ -8,6 +8,7 @@ from openhands.events.action.agent import (
|
||||
AgentRejectAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import (
|
||||
@ -35,6 +36,7 @@ actions = (
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentDelegateAction,
|
||||
RecallAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -102,6 +103,8 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
d['timestamp'] = d['timestamp'].isoformat()
|
||||
if key == 'source' and 'source' in d:
|
||||
d['source'] = d['source'].value
|
||||
if key == 'recall_type' and 'recall_type' in d:
|
||||
d['recall_type'] = d['recall_type'].value
|
||||
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
|
||||
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
|
||||
if key == 'llm_metrics' and 'llm_metrics' in d:
|
||||
@ -119,7 +122,11 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
|
||||
# such as CmdOutputMetadata
|
||||
# we serialize it along with the rest
|
||||
d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()}
|
||||
# we also handle the Enum conversion for MicroagentObservation
|
||||
d['extras'] = {
|
||||
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
|
||||
for k, v in props.items()
|
||||
}
|
||||
# Include success field for CmdOutputObservation
|
||||
if hasattr(event, 'success'):
|
||||
d['success'] = event.success
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import copy
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@ -40,6 +43,7 @@ observations = (
|
||||
UserRejectObservation,
|
||||
AgentCondensationObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
@ -110,4 +114,18 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
else:
|
||||
extras['metadata'] = CmdOutputMetadata()
|
||||
|
||||
if observation_class is MicroagentObservation:
|
||||
# handle the Enum conversion
|
||||
if 'recall_type' in extras:
|
||||
extras['recall_type'] = RecallType(extras['recall_type'])
|
||||
|
||||
# convert dicts in microagent_knowledge to MicroagentKnowledge objects
|
||||
if 'microagent_knowledge' in extras and isinstance(
|
||||
extras['microagent_knowledge'], list
|
||||
):
|
||||
extras['microagent_knowledge'] = [
|
||||
MicroagentKnowledge(**item) if isinstance(item, dict) else item
|
||||
for item in extras['microagent_knowledge']
|
||||
]
|
||||
|
||||
return observation_class(content=content, **extras)
|
||||
|
||||
@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum):
|
||||
RESOLVER = 'openhands_resolver'
|
||||
SERVER = 'server'
|
||||
RUNTIME = 'runtime'
|
||||
MEMORY = 'memory'
|
||||
MAIN = 'main'
|
||||
TEST = 'test'
|
||||
|
||||
|
||||
@ -249,7 +249,8 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# if we mocked function calling, and we have tools, convert the response back to function calling format
|
||||
if mock_function_calling and mock_fncall_tools is not None:
|
||||
assert len(resp.choices) == 1
|
||||
logger.debug(f'Response choices: {len(resp.choices)}')
|
||||
assert len(resp.choices) >= 1
|
||||
non_fncall_response_message = resp.choices[0].message
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from litellm import ModelResponse
|
||||
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.schema import ActionType
|
||||
@ -16,7 +17,7 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
@ -28,16 +29,21 @@ from openhands.events.observation import (
|
||||
IPythonRunCellObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import PromptManager
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
"""Processes event history into a coherent conversation for the agent."""
|
||||
|
||||
def __init__(self, prompt_manager: PromptManager):
|
||||
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
|
||||
self.agent_config = config
|
||||
self.prompt_manager = prompt_manager
|
||||
|
||||
def process_events(
|
||||
@ -53,14 +59,14 @@ class ConversationMemory:
|
||||
Ensures that tool call actions are processed correctly in function calling mode.
|
||||
|
||||
Args:
|
||||
state: The state containing the history of events to convert
|
||||
condensed_history: The condensed list of events to process
|
||||
initial_messages: The initial messages to include in the result
|
||||
condensed_history: The condensed history of events to convert
|
||||
initial_messages: The initial messages to include in the conversation
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
|
||||
"""
|
||||
|
||||
events = condensed_history
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
@ -70,7 +76,7 @@ class ConversationMemory:
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
|
||||
for event in events:
|
||||
for i, event in enumerate(events):
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
messages_to_add = self._process_action(
|
||||
@ -84,7 +90,9 @@ class ConversationMemory:
|
||||
tool_call_id_to_message=tool_call_id_to_message,
|
||||
max_message_chars=max_message_chars,
|
||||
vision_is_active=vision_is_active,
|
||||
enable_som_visual_browsing=enable_som_visual_browsing,
|
||||
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
|
||||
current_index=i,
|
||||
events=events,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
@ -270,6 +278,8 @@ class ConversationMemory:
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
current_index: int = 0,
|
||||
events: list[Event] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Converts an observation into a message format that can be sent to the LLM.
|
||||
|
||||
@ -291,6 +301,8 @@ class ConversationMemory:
|
||||
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
||||
current_index: The index of the current event in the events list (for deduplication)
|
||||
events: The list of all events (for deduplication)
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the observation.
|
||||
@ -372,6 +384,92 @@ class ConversationMemory:
|
||||
elif isinstance(obs, AgentCondensationObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif (
|
||||
isinstance(obs, MicroagentObservation)
|
||||
and self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
# everything is optional, check if they are present
|
||||
repo_info = (
|
||||
RepositoryInfo(
|
||||
repo_name=obs.repo_name or '',
|
||||
repo_directory=obs.repo_directory or '',
|
||||
)
|
||||
if obs.repo_name or obs.repo_directory
|
||||
else None
|
||||
)
|
||||
if obs.runtime_hosts or obs.additional_agent_instructions:
|
||||
runtime_info = RuntimeInfo(
|
||||
available_hosts=obs.runtime_hosts,
|
||||
additional_agent_instructions=obs.additional_agent_instructions,
|
||||
)
|
||||
else:
|
||||
runtime_info = None
|
||||
|
||||
repo_instructions = (
|
||||
obs.repo_instructions if obs.repo_instructions else ''
|
||||
)
|
||||
|
||||
# Have some meaningful content before calling the template
|
||||
has_repo_info = repo_info is not None and (
|
||||
repo_info.repo_name or repo_info.repo_directory
|
||||
)
|
||||
has_runtime_info = runtime_info is not None and (
|
||||
runtime_info.available_hosts
|
||||
or runtime_info.additional_agent_instructions
|
||||
)
|
||||
has_repo_instructions = bool(repo_instructions.strip())
|
||||
|
||||
# Build additional info if we have something to render
|
||||
if has_repo_info or has_runtime_info or has_repo_instructions:
|
||||
# ok, now we can build the additional info
|
||||
formatted_text = self.prompt_manager.build_additional_info(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
)
|
||||
message = Message(
|
||||
role='user', content=[TextContent(text=formatted_text)]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
elif obs.recall_type == RecallType.KNOWLEDGE:
|
||||
# Use prompt manager to build the microagent info
|
||||
# First, filter out agents that appear in earlier MicroagentObservations
|
||||
filtered_agents = self._filter_agents_in_microagent_obs(
|
||||
obs, current_index, events or []
|
||||
)
|
||||
|
||||
# Create and return a message if there is microagent knowledge to include
|
||||
if filtered_agents:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in filtered_agents
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
# Only proceed if we still have agents after filtering out disabled ones
|
||||
if filtered_agents:
|
||||
formatted_text = self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
|
||||
return [
|
||||
Message(
|
||||
role='user', content=[TextContent(text=formatted_text)]
|
||||
)
|
||||
]
|
||||
|
||||
# Return empty list if no microagents to include or all were disabled
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, MicroagentObservation)
|
||||
and not self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
# If prompt extensions are disabled, we don't add any additional info
|
||||
# TODO: test this
|
||||
return []
|
||||
else:
|
||||
# If an observation message is not returned, it will cause an error
|
||||
# when the LLM tries to return the next message
|
||||
@ -404,3 +502,53 @@ class ConversationMemory:
|
||||
-1
|
||||
].cache_prompt = True # Last item inside the message content
|
||||
break
|
||||
|
||||
def _filter_agents_in_microagent_obs(
    self, obs: MicroagentObservation, current_index: int, events: list[Event]
) -> list[MicroagentKnowledge]:
    """Drop microagents that an earlier knowledge observation already reported.

    Args:
        obs: The MicroagentObservation currently being processed
        current_index: Position of ``obs`` in ``events``; only events before it are inspected
        events: The full list of events being processed

    Returns:
        list[MicroagentKnowledge]: For a KNOWLEDGE observation, the entries whose name
        does not occur in any earlier KNOWLEDGE observation. For any other recall type,
        the observation's ``microagent_knowledge`` list is returned untouched.
    """
    # Deduplication only applies to knowledge recalls
    if obs.recall_type != RecallType.KNOWLEDGE:
        return obs.microagent_knowledge

    # An entry survives only when this is the first observation that mentions it
    return [
        knowledge
        for knowledge in obs.microagent_knowledge
        if not self._has_agent_in_earlier_events(
            knowledge.name, current_index, events
        )
    ]
|
||||
|
||||
def _has_agent_in_earlier_events(
    self, agent_name: str, current_index: int, events: list[Event]
) -> bool:
    """Tell whether a microagent name was already reported before the current event.

    Args:
        agent_name: Name of the microagent to search for
        current_index: Index of the current event; only ``events[:current_index]`` is scanned
        events: The full list of events being processed

    Returns:
        bool: True when some earlier MicroagentObservation of recall type KNOWLEDGE
        carries a knowledge entry with that name, False otherwise
    """
    # Single lazy scan: stops at the first matching entry, like the explicit loops would
    return any(
        known.name == agent_name
        for earlier in events[:current_index]
        if isinstance(earlier, MicroagentObservation)
        and earlier.recall_type == RecallType.KNOWLEDGE
        for known in earlier.microagent_knowledge
    )
|
||||
|
||||
270
openhands/memory/memory.py
Normal file
270
openhands/memory/memory.py
Normal file
@ -0,0 +1,270 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
import openhands
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
KnowledgeMicroAgent,
|
||||
RepoMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
|
||||
|
||||
# Directory of the global microagents: the `microagents` folder that sits next to
# the installed `openhands` package directory (two dirname() steps up from
# openhands/__init__.py, then joined with 'microagents').
GLOBAL_MICROAGENTS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(openhands.__file__)),
    'microagents',
)
|
||||
|
||||
|
||||
class Memory:
    """
    Memory is a component that listens to the EventStream for information retrieval actions
    (a RecallAction) and publishes observations with the content (such as MicroagentObservation).

    On construction it loads the global microagents from GLOBAL_MICROAGENTS_DIR and
    subscribes itself to the event stream. Repository info, runtime info and workspace
    microagents are supplied afterwards through the setter / loader methods.
    """

    sid: str
    event_stream: EventStream
    status_callback: Callable | None
    loop: asyncio.AbstractEventLoop | None

    def __init__(
        self,
        event_stream: EventStream,
        sid: str,
        status_callback: Callable | None = None,
    ):
        """Initialize the memory state, load global microagents and subscribe to the stream.

        Args:
            event_stream: The stream to listen on and to publish observations to
            sid: Session id used as the subscription id; a random UUID is generated if falsy
            status_callback: Optional callable invoked as ``(msg_type, id, message)`` to report errors
        """
        self.event_stream = event_stream
        self.sid = sid if sid else str(uuid.uuid4())
        self.status_callback = status_callback
        self.loop = None

        # Additional placeholders to store user workspace microagents
        self.repo_microagents: dict[str, RepoMicroAgent] = {}
        self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}

        # Store repository / runtime info to send them to the templating later
        self.repository_info: RepositoryInfo | None = None
        self.runtime_info: RuntimeInfo | None = None

        # Load global microagents (Knowledge + Repo)
        # from typically OpenHands/microagents (i.e., the PUBLIC microagents)
        self._load_global_microagents()

        # Subscribe last: once registered, on_event may be invoked at any time and it
        # relies on every attribute initialized above. This also avoids leaving a
        # half-constructed instance subscribed if loading the microagents fails.
        self.event_stream.subscribe(
            EventStreamSubscriber.MEMORY,
            self.on_event,
            self.sid,
        )

    def on_event(self, event: Event):
        """Handle an event from the event stream.

        Synchronous entry point: runs the async handler to completion on the loop
        returned by ``asyncio.get_event_loop()``.
        """
        asyncio.get_event_loop().run_until_complete(self._on_event(event))

    async def _on_event(self, event: Event):
        """Handle an event from the event stream asynchronously.

        Only RecallAction events are processed. For each of them exactly one
        observation is published (a MicroagentObservation, or a NullObservation when
        nothing was retrieved) with its cause set to the action's id. Any exception is
        logged and reported through the status callback instead of being raised.
        """
        try:
            observation: MicroagentObservation | NullObservation | None = None

            if isinstance(event, RecallAction):
                # if this is a workspace context recall (on first user message)
                # create and add a MicroagentObservation
                # with info about repo and runtime.
                if (
                    event.source == EventSource.USER
                    and event.recall_type == RecallType.WORKSPACE_CONTEXT
                ):
                    observation = self._on_first_microagent_action(event)

                # continue with the next handler, to include knowledge microagents if suitable for this query
                assert observation is None or isinstance(
                    observation, MicroagentObservation
                ), f'Expected a MicroagentObservation, but got {type(observation)}'
                observation = self._on_microagent_action(
                    event, prev_observation=observation
                )

                if observation is None:
                    observation = NullObservation(content='')

                # important: this will release the execution flow from waiting for the retrieval to complete
                observation._cause = event.id  # type: ignore[union-attr]

                self.event_stream.add_event(observation, EventSource.ENVIRONMENT)
        except Exception as e:
            error_str = f'Error: {str(e.__class__.__name__)}'
            logger.error(error_str)
            self.send_error_message('STATUS$ERROR_MEMORY', error_str)
            return

    def _on_first_microagent_action(
        self, event: RecallAction
    ) -> MicroagentObservation | None:
        """Build a WORKSPACE_CONTEXT MicroagentObservation from repository and runtime information.

        Args:
            event: The triggering RecallAction (not read by this method)

        Returns:
            A MicroagentObservation carrying the stored repository info, runtime info and
            repo microagent instructions, or None when none of the three is available.
        """

        # Create ENVIRONMENT info:
        # - repository_info
        # - runtime_info
        # - repository_instructions

        # Collect raw repository instructions
        repo_instructions = ''
        assert (
            len(self.repo_microagents) <= 1
        ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'

        # Retrieve the context of repo instructions
        for microagent in self.repo_microagents.values():
            if repo_instructions:
                repo_instructions += '\n\n'
            repo_instructions += microagent.content

        # Create observation if we have anything
        if self.repository_info or self.runtime_info or repo_instructions:
            # Every field falls back to an empty value when its source is missing or None
            obs = MicroagentObservation(
                recall_type=RecallType.WORKSPACE_CONTEXT,
                repo_name=self.repository_info.repo_name
                if self.repository_info and self.repository_info.repo_name is not None
                else '',
                repo_directory=self.repository_info.repo_directory
                if self.repository_info
                and self.repository_info.repo_directory is not None
                else '',
                repo_instructions=repo_instructions if repo_instructions else '',
                runtime_hosts=self.runtime_info.available_hosts
                if self.runtime_info and self.runtime_info.available_hosts is not None
                else {},
                additional_agent_instructions=self.runtime_info.additional_agent_instructions
                if self.runtime_info
                and self.runtime_info.additional_agent_instructions is not None
                else '',
                microagent_knowledge=[],
                content='Retrieved environment info',
            )
            return obs
        return None

    def _on_microagent_action(
        self,
        event: RecallAction,
        prev_observation: MicroagentObservation | None = None,
    ) -> MicroagentObservation | None:
        """When a microagent action triggers microagents, create a MicroagentObservation with structured data.

        Args:
            event: The RecallAction whose ``query`` is matched against the knowledge microagents
            prev_observation: An observation already built for this action, if any

        Returns:
            ``prev_observation`` (extended in place with the triggered knowledge) when one
            was given; otherwise a new KNOWLEDGE MicroagentObservation if something was
            triggered; otherwise None.
        """
        # If there's no query, do nothing
        query = event.query.strip()
        if not query:
            return prev_observation

        assert prev_observation is None or isinstance(
            prev_observation, MicroagentObservation
        ), f'Expected a MicroagentObservation, but got {type(prev_observation)}'

        # Process text to find suitable microagents and create a MicroagentObservation.
        recalled_content: list[MicroagentKnowledge] = []
        for name, microagent in self.knowledge_microagents.items():
            trigger = microagent.match_trigger(query)
            if trigger:
                logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
                recalled_content.append(
                    MicroagentKnowledge(
                        name=microagent.name,
                        trigger=trigger,
                        content=microagent.content,
                    )
                )

        if recalled_content:
            if prev_observation is not None:
                # it may be on the first user message that already found some repo info etc
                prev_observation.microagent_knowledge.extend(recalled_content)
            else:
                # if it's not the first user message, we may not have found any information this step
                obs = MicroagentObservation(
                    recall_type=RecallType.KNOWLEDGE,
                    microagent_knowledge=recalled_content,
                    content='Retrieved knowledge from microagents',
                )

                return obs

        return prev_observation

    def load_user_workspace_microagents(
        self, user_microagents: list[BaseMicroAgent]
    ) -> None:
        """
        This method loads microagents from a user's cloned repo or workspace directory.

        This is typically called from agent_session or setup once the workspace is cloned.

        Knowledge and repo microagents are stored by name, replacing any entry already
        registered under the same name; other microagent types are ignored.
        """
        logger.info(
            'Loading user workspace microagents: %s', [m.name for m in user_microagents]
        )
        for user_microagent in user_microagents:
            if isinstance(user_microagent, KnowledgeMicroAgent):
                self.knowledge_microagents[user_microagent.name] = user_microagent
            elif isinstance(user_microagent, RepoMicroAgent):
                self.repo_microagents[user_microagent.name] = user_microagent

    def _load_global_microagents(self) -> None:
        """
        Loads microagents from the global microagents_dir

        Only entries of the expected type are kept from each returned mapping.
        """
        repo_agents, knowledge_agents, _ = load_microagents_from_dir(
            GLOBAL_MICROAGENTS_DIR
        )
        for name, agent in knowledge_agents.items():
            if isinstance(agent, KnowledgeMicroAgent):
                self.knowledge_microagents[name] = agent
        for name, agent in repo_agents.items():
            if isinstance(agent, RepoMicroAgent):
                self.repo_microagents[name] = agent

    def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
        """Store repository info so we can reference it in an observation.

        The stored info is cleared (set to None) when both arguments are falsy.
        """
        if repo_name or repo_directory:
            self.repository_info = RepositoryInfo(repo_name, repo_directory)
        else:
            self.repository_info = None

    def set_runtime_info(self, runtime: Runtime) -> None:
        """Store runtime info (web hosts, ports, etc.).

        The stored info is cleared (set to None) when the runtime exposes neither web
        hosts nor additional agent instructions.
        """
        # e.g. { '127.0.0.1': 8080 }
        if runtime.web_hosts or runtime.additional_agent_instructions:
            self.runtime_info = RuntimeInfo(
                available_hosts=runtime.web_hosts,
                additional_agent_instructions=runtime.additional_agent_instructions,
            )
        else:
            self.runtime_info = None

    def send_error_message(self, message_id: str, message: str):
        """Sends an error message if the callback function was provided.

        The running event loop is looked up once and cached in ``self.loop``; the status
        message is then scheduled on it. A RuntimeError raised while doing so is logged
        and not propagated.
        """
        if self.status_callback:
            try:
                if self.loop is None:
                    self.loop = asyncio.get_running_loop()
                asyncio.run_coroutine_threadsafe(
                    self._send_status_message('error', message_id, message), self.loop
                )
            except RuntimeError as e:
                logger.error(
                    f'Error sending status message: {e.__class__.__name__}',
                    stack_info=False,
                )

    async def _send_status_message(self, msg_type: str, id: str, message: str):
        """Sends a status message to the client."""
        if self.status_callback:
            self.status_callback(msg_type, id, message)
|
||||
@ -15,7 +15,8 @@ from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
@ -126,6 +127,15 @@ class AgentSession:
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
)
|
||||
|
||||
repo_directory = None
|
||||
if self.runtime and runtime_connected and selected_repository:
|
||||
repo_directory = selected_repository.split('/')[-1]
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
if github_token:
|
||||
self.event_stream.set_secrets(
|
||||
{
|
||||
@ -260,26 +270,14 @@ class AgentSession:
|
||||
)
|
||||
return False
|
||||
|
||||
repo_directory = None
|
||||
if selected_repository:
|
||||
repo_directory = await call_sync_from_async(
|
||||
await call_sync_from_async(
|
||||
self.runtime.clone_repo,
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
)
|
||||
|
||||
if agent.prompt_manager:
|
||||
agent.prompt_manager.set_runtime_info(self.runtime)
|
||||
microagents: list[BaseMicroAgent] = await call_sync_from_async(
|
||||
self.runtime.get_microagents_from_selected_repo, selected_repository
|
||||
)
|
||||
agent.prompt_manager.load_microagents(microagents)
|
||||
if selected_repository and repo_directory:
|
||||
agent.prompt_manager.set_repository_info(
|
||||
selected_repository, repo_directory
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
)
|
||||
@ -342,6 +340,29 @@ class AgentSession:
|
||||
|
||||
return controller
|
||||
|
||||
async def _create_memory(
    self, selected_repository: str | None, repo_directory: str | None
) -> Memory:
    """Create the Memory for this session and seed it with runtime and repository data.

    Args:
        selected_repository: Name of the selected repository, if any; also passed to
            the runtime when fetching the repository's microagents
        repo_directory: Directory of the repository in the workspace, if any

    Returns:
        Memory: The new Memory instance bound to this session's event stream and sid
    """
    memory = Memory(
        event_stream=self.event_stream,
        sid=self.sid,
        status_callback=self._status_callback,
    )

    if self.runtime:
        # sets available hosts and other runtime info
        memory.set_runtime_info(self.runtime)

        # loads microagents from repo/.openhands/microagents
        microagents: list[BaseMicroAgent] = await call_sync_from_async(
            self.runtime.get_microagents_from_selected_repo, selected_repository
        )
        memory.load_user_workspace_microagents(microagents)

        # repository info is only recorded when both the name and the directory are known
        if selected_repository and repo_directory:
            memory.set_repository_info(selected_repository, repo_directory)
    return memory
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
"""Helper method to handle state restore logic."""
|
||||
restored_state = None
|
||||
|
||||
@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore):
|
||||
try:
|
||||
conversations.append(await self.get_metadata(conversation_id))
|
||||
except Exception:
|
||||
logger.error(
|
||||
f'Error loading conversation: {conversation_id}',
|
||||
logger.warning(
|
||||
f'Could not load conversation metadata: {conversation_id}',
|
||||
)
|
||||
conversations.sort(key=_sort_key, reverse=True)
|
||||
conversations = conversations[start:end]
|
||||
|
||||
@ -1,25 +1,18 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import islice
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
KnowledgeMicroAgent,
|
||||
RepoMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeInfo:
|
||||
available_hosts: dict[str, int]
|
||||
additional_agent_instructions: str
|
||||
available_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -32,75 +25,23 @@ class RepositoryInfo:
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
Manages prompt templates and micro-agents for AI interactions.
|
||||
Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents.
|
||||
|
||||
This class handles loading and rendering of system and user prompt templates,
|
||||
as well as loading micro-agent specifications. It provides methods to access
|
||||
rendered system and initial user messages for AI interactions.
|
||||
This class is dedicated to loading and rendering prompts (system prompt, user prompt).
|
||||
|
||||
Attributes:
|
||||
prompt_dir (str): Directory containing prompt templates.
|
||||
microagent_dir (str): Directory containing microagent specifications.
|
||||
disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled.
|
||||
prompt_dir: Directory containing prompt templates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_dir: str,
|
||||
microagent_dir: str | None = None,
|
||||
disabled_microagents: list[str] | None = None,
|
||||
):
|
||||
self.disabled_microagents: list[str] = disabled_microagents or []
|
||||
self.prompt_dir: str = prompt_dir
|
||||
self.repository_info: RepositoryInfo | None = None
|
||||
self.system_template: Template = self._load_template('system_prompt')
|
||||
self.user_template: Template = self._load_template('user_prompt')
|
||||
self.additional_info_template: Template = self._load_template('additional_info')
|
||||
self.microagent_info_template: Template = self._load_template('microagent_info')
|
||||
self.runtime_info = RuntimeInfo(
|
||||
available_hosts={}, additional_agent_instructions=''
|
||||
)
|
||||
|
||||
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
|
||||
self.repo_microagents: dict[str, RepoMicroAgent] = {}
|
||||
|
||||
if microagent_dir:
|
||||
# This loads micro-agents from the microagent_dir
|
||||
# which is typically the OpenHands/microagents (i.e., the PUBLIC microagents)
|
||||
|
||||
# Only load KnowledgeMicroAgents
|
||||
repo_microagents, knowledge_microagents, _ = load_microagents_from_dir(
|
||||
microagent_dir
|
||||
)
|
||||
assert all(
|
||||
isinstance(microagent, KnowledgeMicroAgent)
|
||||
for microagent in knowledge_microagents.values()
|
||||
)
|
||||
for name, microagent in knowledge_microagents.items():
|
||||
if name not in self.disabled_microagents:
|
||||
self.knowledge_microagents[name] = microagent
|
||||
assert all(
|
||||
isinstance(microagent, RepoMicroAgent)
|
||||
for microagent in repo_microagents.values()
|
||||
)
|
||||
for name, microagent in repo_microagents.items():
|
||||
if name not in self.disabled_microagents:
|
||||
self.repo_microagents[name] = microagent
|
||||
|
||||
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None:
|
||||
"""Load microagents from a list of BaseMicroAgents.
|
||||
|
||||
This is typically used when loading microagents from inside a repo.
|
||||
"""
|
||||
openhands_logger.info('Loading microagents: %s', [m.name for m in microagents])
|
||||
# Only keep KnowledgeMicroAgents and RepoMicroAgents
|
||||
for microagent in microagents:
|
||||
if microagent.name in self.disabled_microagents:
|
||||
continue
|
||||
if isinstance(microagent, KnowledgeMicroAgent):
|
||||
self.knowledge_microagents[microagent.name] = microagent
|
||||
elif isinstance(microagent, RepoMicroAgent):
|
||||
self.repo_microagents[microagent.name] = microagent
|
||||
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
if self.prompt_dir is None:
|
||||
@ -114,27 +55,6 @@ class PromptManager:
|
||||
def get_system_message(self) -> str:
|
||||
return self.system_template.render().strip()
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime) -> None:
|
||||
self.runtime_info.available_hosts = runtime.web_hosts
|
||||
self.runtime_info.additional_agent_instructions = (
|
||||
runtime.additional_agent_instructions
|
||||
)
|
||||
|
||||
def set_repository_info(
|
||||
self,
|
||||
repo_name: str,
|
||||
repo_directory: str,
|
||||
) -> None:
|
||||
"""Sets information about the GitHub repository that has been cloned.
|
||||
|
||||
Args:
|
||||
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
|
||||
repo_directory: The directory where the repository has been cloned
|
||||
"""
|
||||
self.repository_info = RepositoryInfo(
|
||||
repo_name=repo_name, repo_directory=repo_directory
|
||||
)
|
||||
|
||||
def get_example_user_message(self) -> str:
|
||||
"""This is the initial user message provided to the agent
|
||||
before *actual* user instructions are provided.
|
||||
@ -148,45 +68,6 @@ class PromptManager:
|
||||
|
||||
return self.user_template.render().strip()
|
||||
|
||||
def enhance_message(self, message: Message) -> None:
|
||||
"""Enhance the user message with additional context.
|
||||
|
||||
This method is used to enhance the user message with additional context
|
||||
about the user's task. The additional context will convert the current
|
||||
generic agent into a more specialized agent that is tailored to the user's task.
|
||||
"""
|
||||
if not message.content:
|
||||
return
|
||||
|
||||
# if there were other texts included, they were before the user message
|
||||
# so the last TextContent is the user message
|
||||
# content can be a list of TextContent or ImageContent
|
||||
message_content = ''
|
||||
for content in reversed(message.content):
|
||||
if isinstance(content, TextContent):
|
||||
message_content = content.text
|
||||
break
|
||||
|
||||
if not message_content:
|
||||
return
|
||||
|
||||
triggered_agents = []
|
||||
for name, microagent in self.knowledge_microagents.items():
|
||||
trigger = microagent.match_trigger(message_content)
|
||||
if trigger:
|
||||
openhands_logger.info(
|
||||
"Microagent '%s' triggered by keyword '%s'",
|
||||
name,
|
||||
trigger,
|
||||
)
|
||||
# Create a dictionary with the agent and trigger word
|
||||
triggered_agents.append({'agent': microagent, 'trigger_word': trigger})
|
||||
|
||||
if triggered_agents:
|
||||
formatted_text = self.build_microagent_info(triggered_agents)
|
||||
# Insert the new content at the start of the TextContent list
|
||||
message.content.insert(0, TextContent(text=formatted_text))
|
||||
|
||||
def add_examples_to_initial_message(self, message: Message) -> None:
|
||||
"""Add example_message to the first user message."""
|
||||
example_message = self.get_example_user_message() or None
|
||||
@ -195,44 +76,28 @@ class PromptManager:
|
||||
if example_message:
|
||||
message.content.insert(0, TextContent(text=example_message))
|
||||
|
||||
def add_info_to_initial_message(
|
||||
def build_additional_info(
|
||||
self,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""Adds information about the repository and runtime to the initial user message.
|
||||
|
||||
Args:
|
||||
message: The initial user message to add information to.
|
||||
"""
|
||||
repo_instructions = ''
|
||||
assert (
|
||||
len(self.repo_microagents) <= 1
|
||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
||||
for microagent in self.repo_microagents.values():
|
||||
# We assume these are the repo instructions
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
additional_info = self.additional_info_template.render(
|
||||
repository_info: RepositoryInfo | None,
|
||||
runtime_info: RuntimeInfo | None,
|
||||
repo_instructions: str = '',
|
||||
) -> str:
|
||||
"""Renders the additional info template with the stored repository/runtime info."""
|
||||
return self.additional_info_template.render(
|
||||
repository_info=repository_info,
|
||||
repository_instructions=repo_instructions,
|
||||
repository_info=self.repository_info,
|
||||
runtime_info=self.runtime_info,
|
||||
runtime_info=runtime_info,
|
||||
).strip()
|
||||
|
||||
# Insert the new content at the start of the TextContent list
|
||||
if additional_info:
|
||||
message.content.insert(0, TextContent(text=additional_info))
|
||||
|
||||
def build_microagent_info(
|
||||
self,
|
||||
triggered_agents: list[dict],
|
||||
triggered_agents: list[MicroagentKnowledge],
|
||||
) -> str:
|
||||
"""Renders the microagent info template with the triggered agents.
|
||||
|
||||
Args:
|
||||
triggered_agents: A list of dictionaries, each containing an "agent"
|
||||
(KnowledgeMicroAgent) and a "trigger_word" (str).
|
||||
triggered_agents: A list of MicroagentKnowledge objects containing information
|
||||
about triggered microagents.
|
||||
"""
|
||||
return self.microagent_info_template.render(
|
||||
triggered_agents=triggered_agents
|
||||
|
||||
@ -9,6 +9,7 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
MessageAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.action import ActionConfirmationStatus
|
||||
from openhands.events.action.files import FileEditSource, FileReadSource
|
||||
@ -356,6 +357,18 @@ def test_file_ohaci_edit_action_legacy_serialization():
|
||||
assert event_dict['args']['end'] == -1
|
||||
|
||||
|
||||
def test_agent_microagent_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'recall',
|
||||
'args': {
|
||||
'query': 'What is the capital of France?',
|
||||
'thought': 'I need to find information about France',
|
||||
'recall_type': 'knowledge',
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, RecallAction)
|
||||
|
||||
|
||||
def test_file_read_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'read',
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -14,12 +14,16 @@ from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@ -47,17 +51,36 @@ def mock_agent():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
mock = MagicMock(spec=EventStream)
|
||||
mock = MagicMock(
|
||||
spec=EventStream,
|
||||
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
|
||||
)
|
||||
mock.get_latest_event_id.return_value = 0
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_event_stream():
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
|
||||
return event_stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime() -> Runtime:
|
||||
return MagicMock(
|
||||
runtime = MagicMock(
|
||||
spec=Runtime,
|
||||
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
|
||||
event_stream=test_event_stream,
|
||||
)
|
||||
return runtime
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory() -> Memory:
|
||||
memory = MagicMock(
|
||||
spec=Memory,
|
||||
event_stream=test_event_stream,
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -68,6 +91,7 @@ def mock_status_callback():
|
||||
async def send_event_to_controller(controller, event):
|
||||
await controller._on_event(event)
|
||||
await asyncio.sleep(0.1)
|
||||
controller._pending_action = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -140,10 +164,8 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error():
|
||||
async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
config = AppConfig()
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent = MagicMock(spec=Agent)
|
||||
@ -163,10 +185,23 @@ async def test_run_controller_with_fatal_error():
|
||||
if isinstance(event, CmdRunAction):
|
||||
error_obs = ErrorObservation('You messed around with Jim')
|
||||
error_obs._cause = event.id
|
||||
event_stream.add_event(error_obs, EventSource.USER)
|
||||
test_event_stream.add_event(error_obs, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
@ -175,22 +210,20 @@ async def test_run_controller_with_fatal_error():
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
print(f'state: {state}')
|
||||
events = list(event_stream.get_events())
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'event_stream: {events}')
|
||||
assert state.iteration == 4
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck():
|
||||
async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
config = AppConfig()
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
@ -209,10 +242,23 @@ async def test_run_controller_stop_with_stuck():
|
||||
'Non fatal error here to trigger loop'
|
||||
)
|
||||
non_fatal_error_obs._cause = event.id
|
||||
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
test_event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
@ -221,16 +267,17 @@ async def test_run_controller_stop_with_stuck():
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
events = list(event_stream.get_events())
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'state: {state}')
|
||||
for i, event in enumerate(events):
|
||||
print(f'event {i}: {event_to_dict(event)}')
|
||||
|
||||
assert state.iteration == 4
|
||||
assert state.iteration == 3
|
||||
assert len(events) == 11
|
||||
# check the eventstream have 4 pairs of repeated actions and observations
|
||||
repeating_actions_and_observations = events[2:10]
|
||||
repeating_actions_and_observations = events[4:12]
|
||||
for action, observation in zip(
|
||||
repeating_actions_and_observations[0::2],
|
||||
repeating_actions_and_observations[1::2],
|
||||
@ -510,12 +557,13 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_max_iterations_has_metrics():
|
||||
async def test_run_controller_max_iterations_has_metrics(
|
||||
test_event_stream, mock_memory
|
||||
):
|
||||
config = AppConfig(
|
||||
max_iterations=3,
|
||||
)
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
event_stream = test_event_stream
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
@ -546,6 +594,17 @@ async def test_run_controller_max_iterations_has_metrics():
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
@ -553,6 +612,7 @@ async def test_run_controller_max_iterations_has_metrics():
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
@ -630,7 +690,7 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent, mock_runtime
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
):
|
||||
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON"""
|
||||
|
||||
@ -656,6 +716,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent.step = step_state.step
|
||||
mock_agent.config = AgentConfig()
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
@ -665,6 +739,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
@ -691,7 +766,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent, mock_runtime
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
):
|
||||
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
|
||||
|
||||
@ -702,7 +777,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 1 and not self.has_errored:
|
||||
if len(state.history) > 3 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
@ -718,6 +793,19 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent.config = AgentConfig()
|
||||
mock_agent.config.enable_history_truncation = False
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
@ -727,6 +815,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
@ -751,6 +840,44 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
assert step_state.has_errored
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_memory_error(test_event_stream):
|
||||
config = AppConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Create a real Memory instance
|
||||
memory = Memory(event_stream=event_stream, sid='test-memory')
|
||||
|
||||
# Patch the _on_microagent_action method to raise our test exception
|
||||
def mock_on_microagent_action(*args, **kwargs):
|
||||
raise RuntimeError('Test memory error')
|
||||
|
||||
with patch.object(
|
||||
memory, '_on_microagent_action', side_effect=mock_on_microagent_action
|
||||
):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: RuntimeError'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_metrics_copy():
|
||||
# Setup
|
||||
@ -851,3 +978,56 @@ async def test_action_metrics_copy():
|
||||
assert last_action.llm_metrics.accumulated_cost == 0.07
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_user_message_with_identical_content():
|
||||
"""
|
||||
Test that _first_user_message correctly identifies the first user message
|
||||
even when multiple messages have identical content but different IDs.
|
||||
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
"""
|
||||
# Create a real event stream for this test
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
|
||||
|
||||
# Create an agent controller
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create and add the first user message
|
||||
first_message = MessageAction(content='Hello, this is a test message')
|
||||
first_message._source = EventSource.USER
|
||||
event_stream.add_event(first_message, EventSource.USER)
|
||||
|
||||
# Create and add a second user message with identical content
|
||||
second_message = MessageAction(content='Hello, this is a test message')
|
||||
second_message._source = EventSource.USER
|
||||
event_stream.add_event(second_message, EventSource.USER)
|
||||
|
||||
# Verify that _first_user_message returns the first message
|
||||
first_user_message = controller._first_user_message()
|
||||
assert first_user_message is not None
|
||||
assert first_user_message.id == first_message.id # Check IDs match
|
||||
assert first_user_message.id != second_message.id # Different IDs
|
||||
assert first_user_message == first_message == second_message # dataclass equality
|
||||
|
||||
# Test the comparison used in the actual code
|
||||
assert first_message == first_user_message # This should be True
|
||||
assert (
|
||||
second_message.id != first_user_message.id
|
||||
) # This should be False, but may be True if there's a bug
|
||||
|
||||
await controller.close()
|
||||
|
||||
@ -17,8 +17,13 @@ from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import MicroagentObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@ -75,6 +80,25 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
initial_state=parent_state,
|
||||
)
|
||||
|
||||
# Setup Memory to catch RecallActions
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_memory.event_stream = mock_event_stream
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
# create a MicroagentObservation
|
||||
microagent_observation = MicroagentObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
content='microagent',
|
||||
)
|
||||
microagent_observation._cause = event.id # ignore attr-defined warning
|
||||
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
mock_memory.on_event = on_event
|
||||
mock_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
|
||||
)
|
||||
|
||||
# Setup a delegate action from the parent
|
||||
delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True})
|
||||
mock_parent_agent.step.return_value = delegate_action
|
||||
@ -87,7 +111,16 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
# Give time for the async step() to execute
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# The parent should receive step() from that event
|
||||
# Verify that a MicroagentObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
assert (
|
||||
mock_event_stream.get_latest_event_id() == 3
|
||||
) # Microagents and AgentChangeState
|
||||
|
||||
# a MicroagentObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, MicroagentObservation) for event in events)
|
||||
assert any(isinstance(event, AgentDelegateAction) for event in events)
|
||||
|
||||
# Verify that a delegate agent controller is created
|
||||
assert (
|
||||
parent_controller.delegate is not None
|
||||
|
||||
@ -6,9 +6,11 @@ from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AppConfig, LLMConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
@ -22,18 +24,24 @@ def mock_agent():
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Configure the agent config
|
||||
agent_config.disabled_microagents = []
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
return agent
|
||||
|
||||
@ -78,7 +86,11 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to fail
|
||||
# Create a real Memory instance with the mock event stream
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
@ -87,7 +99,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
side_effect=Exception('No state found'),
|
||||
):
|
||||
), patch('openhands.server.session.agent_session.Memory', return_value=memory):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
@ -96,12 +108,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
)
|
||||
|
||||
# Verify EventStream.subscribe was called with correct parameters
|
||||
mock_event_stream.subscribe.assert_called_with(
|
||||
mock_event_stream.subscribe.assert_any_call(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER,
|
||||
session.controller.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
|
||||
mock_event_stream.subscribe.assert_any_call(
|
||||
EventStreamSubscriber.MEMORY,
|
||||
session.memory.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
|
||||
# Verify set_initial_state was called once with None as state
|
||||
assert session.controller.set_initial_state_call_count == 1
|
||||
assert session.controller.test_initial_state is None
|
||||
@ -159,7 +177,10 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to succeed
|
||||
# create a mock Memory
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
@ -168,7 +189,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
return_value=mock_restored_state,
|
||||
):
|
||||
), patch('openhands.server.session.agent_session.Memory', mock_memory):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -17,13 +17,18 @@ def mock_runtime():
|
||||
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
|
||||
mock_runtime_instance = AsyncMock()
|
||||
# Mock the event stream with proper async methods
|
||||
mock_runtime_instance.event_stream = AsyncMock()
|
||||
mock_runtime_instance.event_stream.subscribe = AsyncMock()
|
||||
mock_runtime_instance.event_stream.add_event = AsyncMock()
|
||||
mock_event_stream = AsyncMock()
|
||||
mock_event_stream.subscribe = AsyncMock()
|
||||
mock_event_stream.add_event = AsyncMock()
|
||||
mock_event_stream.get_events = AsyncMock(return_value=[])
|
||||
mock_event_stream.get_latest_event_id = AsyncMock(return_value=0)
|
||||
mock_runtime_instance.event_stream = mock_event_stream
|
||||
# Mock connect method to return immediately
|
||||
mock_runtime_instance.connect = AsyncMock()
|
||||
# Ensure status_callback is None
|
||||
mock_runtime_instance.status_callback = None
|
||||
# Mock get_microagents_from_selected_repo
|
||||
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
|
||||
mock_create_runtime.return_value = mock_runtime_instance
|
||||
yield mock_runtime_instance
|
||||
|
||||
@ -32,6 +37,16 @@ def mock_runtime():
|
||||
def mock_agent():
|
||||
with patch('openhands.core.cli.create_agent') as mock_create_agent:
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_instance.name = 'test-agent'
|
||||
mock_agent_instance.llm = AsyncMock()
|
||||
mock_agent_instance.llm.config = AsyncMock()
|
||||
mock_agent_instance.llm.config.model = 'test-model'
|
||||
mock_agent_instance.llm.config.base_url = 'http://test'
|
||||
mock_agent_instance.llm.config.max_message_chars = 1000
|
||||
mock_agent_instance.config = AsyncMock()
|
||||
mock_agent_instance.config.disabled_microagents = []
|
||||
mock_agent_instance.sandbox_plugins = []
|
||||
mock_agent_instance.prompt_manager = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent_instance
|
||||
yield mock_agent_instance
|
||||
|
||||
|
||||
@ -369,9 +369,3 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
# Fifth message only has ImageContent, no TextContent to modify
|
||||
assert len(enhanced_messages[5].content) == 1
|
||||
assert isinstance(enhanced_messages[5].content[0], ImageContent)
|
||||
|
||||
# Verify prompt manager methods were called as expected
|
||||
assert agent.prompt_manager.add_examples_to_initial_message.call_count == 1
|
||||
assert (
|
||||
agent.prompt_manager.enhance_message.call_count == 5
|
||||
) # Called for each user message
|
||||
|
||||
@ -1,16 +1,29 @@
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
|
||||
from openhands.events.event import (
|
||||
Event,
|
||||
EventSource,
|
||||
FileEditSource,
|
||||
FileReadSource,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
@ -22,14 +35,45 @@ from openhands.events.observation.files import FileEditObservation, FileReadObse
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.memory.conversation_memory import ConversationMemory
|
||||
from openhands.utils.prompt import PromptManager
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_memory():
|
||||
def agent_config():
|
||||
return AgentConfig(
|
||||
enable_prompt_extensions=True,
|
||||
enable_som_visual_browsing=True,
|
||||
disabled_microagents=['disabled_agent'],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_memory(agent_config):
|
||||
prompt_manager = MagicMock(spec=PromptManager)
|
||||
prompt_manager.get_system_message.return_value = 'System message'
|
||||
return ConversationMemory(prompt_manager)
|
||||
prompt_manager.build_additional_info.return_value = (
|
||||
'Formatted repository and runtime info'
|
||||
)
|
||||
|
||||
# Make build_microagent_info return the actual content from the triggered agents
|
||||
def build_microagent_info(triggered_agents):
|
||||
if not triggered_agents:
|
||||
return ''
|
||||
return '\n'.join(agent.content for agent in triggered_agents)
|
||||
|
||||
prompt_manager.build_microagent_info.side_effect = build_microagent_info
|
||||
return ConversationMemory(agent_config, prompt_manager)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_dir(tmp_path):
|
||||
# Copy contents from "openhands/agenthub/codeact_agent" to the temp directory
|
||||
shutil.copytree(
|
||||
'openhands/agenthub/codeact_agent/prompts', tmp_path, dirs_exist_ok=True
|
||||
)
|
||||
|
||||
# Return the temporary directory path
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -308,6 +352,40 @@ def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
assert '[Last action has been rejected by the user]' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
"""Test that empty environment info observations return an empty list of messages without calling build_additional_info."""
|
||||
# Create a MicroagentObservation with empty info
|
||||
|
||||
empty_obs = MicroagentObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='',
|
||||
repo_directory='',
|
||||
repo_instructions='',
|
||||
runtime_hosts={},
|
||||
additional_agent_instructions='',
|
||||
microagent_knowledge=[],
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[empty_obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Should only contain the initial system message
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# Verify that build_additional_info was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_additional_info.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
mock_response = {
|
||||
'id': 'mock_id',
|
||||
@ -446,3 +524,529 @@ def test_apply_prompt_caching(conversation_memory):
|
||||
assert messages[1].content[0].cache_prompt is False
|
||||
assert messages[2].content[0].cache_prompt is False
|
||||
assert messages[3].content[0].cache_prompt is True
|
||||
|
||||
|
||||
def test_process_events_with_environment_microagent_observation(conversation_memory):
    """A populated WORKSPACE_CONTEXT observation becomes a single user message.

    The repository and runtime fields carried by the observation are expected
    to reach ``prompt_manager.build_additional_info`` as keyword arguments.
    """
    repo_instructions = '# Test Repository\nThis is a test repository.'
    workspace_obs = MicroagentObservation(
        content='Retrieved environment info',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='test-repo',
        repo_directory='/path/to/repo',
        repo_instructions=repo_instructions,
        runtime_hosts={'localhost': 8080},
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[workspace_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # The system message is kept and exactly one message follows it.
    assert len(processed) == 2
    appended = processed[1]
    assert appended.role == 'user'
    assert len(appended.content) == 1
    only_part = appended.content[0]
    assert isinstance(only_part, TextContent)
    # NOTE(review): presumably the canned return value of the fixture's mocked
    # prompt manager -- confirm against the conversation_memory fixture.
    assert only_part.text == 'Formatted repository and runtime info'

    # Inspect the keyword arguments of the one build_additional_info call.
    build_info = conversation_memory.prompt_manager.build_additional_info
    build_info.assert_called_once()
    passed = build_info.call_args[1]

    repository_info = passed['repository_info']
    assert isinstance(repository_info, RepositoryInfo)
    assert repository_info.repo_name == 'test-repo'
    assert repository_info.repo_directory == '/path/to/repo'

    runtime_info = passed['runtime_info']
    assert isinstance(runtime_info, RuntimeInfo)
    assert runtime_info.available_hosts == {'localhost': 8080}

    assert passed['repo_instructions'] == repo_instructions
|
||||
|
||||
|
||||
def test_process_events_with_knowledge_microagent_microagent_observation(
    conversation_memory,
):
    """A KNOWLEDGE observation yields one user message without disabled agents.

    Three knowledge entries are supplied; the one named ``disabled_agent`` is
    expected to be dropped before ``prompt_manager.build_microagent_info`` is
    called.
    NOTE(review): this relies on the fixture's config listing 'disabled_agent'
    among its disabled microagents -- confirm against the fixture.
    """
    entries = (
        ('test_agent', 'test', 'This is test agent content'),
        ('another_agent', 'another', 'This is another agent content'),
        ('disabled_agent', 'disabled', 'This is disabled agent content'),
    )
    knowledge_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(name=name, trigger=trigger, content=body)
            for name, trigger, body in entries
        ],
        content='Retrieved knowledge from microagents',
    )
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[knowledge_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    assert len(processed) == 2
    appended = processed[1]
    assert appended.role == 'user'
    assert len(appended.content) == 1
    assert isinstance(appended.content[0], TextContent)

    # Enabled agents show up in the rendered text, the disabled one does not.
    rendered = appended.content[0].text
    assert 'This is test agent content' in rendered
    assert 'This is another agent content' in rendered
    assert 'This is disabled agent content' not in rendered

    # The prompt manager is called once, with only the enabled agents.
    build_info = conversation_memory.prompt_manager.build_microagent_info
    build_info.assert_called_once()
    forwarded = build_info.call_args[1]['triggered_agents']
    assert len(forwarded) == 2

    forwarded_names = [agent.name for agent in forwarded]
    assert 'test_agent' in forwarded_names
    assert 'another_agent' in forwarded_names
    assert 'disabled_agent' not in forwarded_names
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_extensions_disabled(
    agent_config, conversation_memory
):
    """With prompt extensions switched off, a MicroagentObservation is ignored.

    The flag is flipped on the ``agent_config`` fixture before processing; the
    result must hold only the system message and neither prompt-manager
    builder may have been called.
    """
    # Turn prompt extensions off on the config fixture.
    agent_config.enable_prompt_extensions = False

    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )
    workspace_obs = MicroagentObservation(
        content='Retrieved environment info',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='test-repo',
        repo_directory='/path/to/repo',
    )

    processed = conversation_memory.process_events(
        condensed_history=[workspace_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # The observation contributed nothing: only the system message is left.
    assert [message.role for message in processed] == ['system']

    # Neither formatting entry point of the prompt manager was used.
    prompt_manager = conversation_memory.prompt_manager
    prompt_manager.build_additional_info.assert_not_called()
    prompt_manager.build_microagent_info.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_empty_microagent_knowledge(conversation_memory):
    """A KNOWLEDGE observation without any knowledge entries adds no message."""
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )
    empty_knowledge_obs = MicroagentObservation(
        content='Retrieved knowledge from microagents',
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[],
    )

    processed = conversation_memory.process_events(
        condensed_history=[empty_knowledge_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # No message is created for the empty observation.
    assert [message.role for message in processed] == ['system']

    # Without triggered agents there is nothing to format.
    build_info = conversation_memory.prompt_manager.build_microagent_info
    build_info.assert_not_called()
|
||||
|
||||
|
||||
def test_conversation_memory_processes_microagent_observation(prompt_dir):
    """Test that ConversationMemory renders a KNOWLEDGE MicroagentObservation.

    A real ``PromptManager`` pointed at ``prompt_dir`` is used, so the message
    text comes from the ``microagent_info.j2`` template in that directory. If
    the template is missing, a fallback copy is written first.

    Args:
        prompt_dir: Fixture providing the directory that holds the prompt
            templates.
    """
    # Create a microagent_info.j2 template file if the directory lacks one.
    template_path = os.path.join(prompt_dir, 'microagent_info.j2')
    if not os.path.exists(template_path):
        with open(template_path, 'w') as f:
            # The fallback must agree with this test's data and expectation:
            # the knowledge object below is built with ``trigger=`` and
            # ``expected_text`` has only a blank line before the content, so
            # the template reads ``agent_info.trigger`` and carries no extra
            # lines.
            f.write("""{% for agent_info in triggered_agents %}
<EXTRA_INFO>
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
It may or may not be relevant to the user's request.

{{ agent_info.content }}
</EXTRA_INFO>
{% endfor %}
""")

    # Create a mock agent config
    agent_config = MagicMock(spec=AgentConfig)
    agent_config.enable_prompt_extensions = True
    agent_config.disabled_microagents = []

    # Create a PromptManager
    prompt_manager = PromptManager(prompt_dir=prompt_dir)

    # Initialize ConversationMemory
    conversation_memory = ConversationMemory(
        config=agent_config, prompt_manager=prompt_manager
    )

    # Create a MicroagentObservation with microagent knowledge
    microagent_observation = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(
                name='test_agent',
                trigger='test_trigger',
                content='This is triggered content for testing.',
            )
        ],
        content='Retrieved knowledge from microagents',
    )

    # Process the observation
    messages = conversation_memory._process_observation(
        obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None
    )

    # Verify the message was created correctly
    assert len(messages) == 1
    message = messages[0]
    assert message.role == 'user'
    assert len(message.content) == 1
    assert isinstance(message.content[0], TextContent)

    expected_text = """<EXTRA_INFO>
The following information has been included based on a keyword match for "test_trigger".
It may or may not be relevant to the user's request.

This is triggered content for testing.
</EXTRA_INFO>"""

    # Surrounding whitespace is ignored on both sides of the comparison.
    assert message.content[0].text.strip() == expected_text.strip()

    # Clean up
    os.remove(os.path.join(prompt_dir, 'microagent_info.j2'))
|
||||
|
||||
|
||||
def test_conversation_memory_processes_environment_microagent_observation(prompt_dir):
    """A WORKSPACE_CONTEXT observation is rendered through additional_info.j2.

    Uses a real ``PromptManager`` on ``prompt_dir`` (writing a fallback
    template when the directory has none) and checks that repository,
    instruction and runtime-host details all appear in the resulting message.
    """
    # Provide an additional_info.j2 template only when one is not already there.
    additional_info_template = os.path.join(prompt_dir, 'additional_info.j2')
    if not os.path.exists(additional_info_template):
        with open(additional_info_template, 'w') as f:
            f.write("""
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}

{% if repository_instructions %}
<REPOSITORY_INSTRUCTIONS>
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}

{% if runtime_info and runtime_info.available_hosts %}
<RUNTIME_INFORMATION>
The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:
{% for host, port in runtime_info.available_hosts.items() %}
* {{ host }} (port {{ port }})
{% endfor %}
</RUNTIME_INFORMATION>
{% endif %}
""")

    # Mocked agent config with prompt extensions enabled.
    mock_config = MagicMock(spec=AgentConfig)
    mock_config.enable_prompt_extensions = True

    # Real prompt manager reading templates from prompt_dir.
    memory_under_test = ConversationMemory(
        config=mock_config, prompt_manager=PromptManager(prompt_dir=prompt_dir)
    )

    # Observation carrying repository, instruction and runtime-host details.
    workspace_obs = MicroagentObservation(
        content='Retrieved environment info',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='owner/repo',
        repo_directory='/workspace/repo',
        repo_instructions='This repository contains important code.',
        runtime_hosts={'example.com': 8080},
    )

    produced = memory_under_test._process_observation(
        obs=workspace_obs, tool_call_id_to_message={}, max_message_chars=None
    )

    # Exactly one user message with a single text part is expected.
    assert len(produced) == 1
    only_message = produced[0]
    assert only_message.role == 'user'
    assert len(only_message.content) == 1
    assert isinstance(only_message.content[0], TextContent)

    rendered = only_message.content[0].text

    # Repository info section.
    assert '<REPOSITORY_INFO>' in rendered
    assert 'owner/repo' in rendered
    assert '/workspace/repo' in rendered

    # Repository instructions section.
    assert '<REPOSITORY_INSTRUCTIONS>' in rendered
    assert 'This repository contains important code.' in rendered

    # Runtime information section.
    assert '<RUNTIME_INFORMATION>' in rendered
    assert 'example.com (port 8080)' in rendered
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication(conversation_memory):
    """Repeated microagents across observations are reported only once.

    Three KNOWLEDGE observations are processed in order. The second and third
    only repeat agent names already present in the first, so just the first
    observation is expected to contribute a message, carrying the "v1" content
    of every agent.
    """

    def knowledge_observation(label, *entries):
        # Build a KNOWLEDGE observation from (name, trigger, content) triples.
        return MicroagentObservation(
            recall_type=RecallType.KNOWLEDGE,
            microagent_knowledge=[
                MicroagentKnowledge(name=name, trigger=trigger, content=body)
                for name, trigger, body in entries
            ],
            content=label,
        )

    first = knowledge_observation(
        'First retrieval',
        ('python_agent', 'python', 'Python best practices v1'),
        ('git_agent', 'git', 'Git best practices v1'),
        ('image_agent', 'image', 'Image best practices v1'),
    )
    second = knowledge_observation(
        'Second retrieval',
        ('python_agent', 'python', 'Python best practices v2'),
    )
    third = knowledge_observation(
        'Third retrieval',
        ('git_agent', 'git', 'Git best practices v3'),
    )

    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    processed = conversation_memory.process_events(
        condensed_history=[first, second, third],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # System message plus a single microagent message: the later observations
    # contain nothing but duplicates.
    assert len(processed) == 2
    after_system = processed[1:]

    # Every agent appears with the content from its first occurrence.
    assert 'Image best practices v1' in after_system[0].content[0].text
    assert 'Git best practices v1' in after_system[0].content[0].text
    assert 'Python best practices v1' in after_system[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
    conversation_memory,
):
    """Disabled agents are dropped and enabled ones are kept only once.

    The first observation mixes ``disabled_agent`` with ``enabled_agent``; the
    second repeats ``enabled_agent`` with newer content. Only one microagent
    message is expected, containing the first ("v1") enabled content.
    NOTE(review): assumes the fixture's config disables 'disabled_agent' --
    confirm against the fixture.
    """
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )

    disabled_entry = MicroagentKnowledge(
        name='disabled_agent',
        trigger='disabled',
        content='Disabled agent content',
    )
    enabled_first = MicroagentKnowledge(
        name='enabled_agent',
        trigger='enabled',
        content='Enabled agent content v1',
    )
    enabled_repeat = MicroagentKnowledge(
        name='enabled_agent',
        trigger='enabled',
        content='Enabled agent content v2',
    )

    first = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[disabled_entry, enabled_first],
        content='First retrieval',
    )
    second = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[enabled_repeat],
        content='Second retrieval',
    )

    processed = conversation_memory.process_events(
        condensed_history=[first, second],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # System message plus one microagent message; the second observation only
    # repeats "enabled_agent".
    assert len(processed) == 2
    after_system = processed[1:]

    # The surviving message has the enabled agent but not the disabled one.
    assert 'Disabled agent content' not in after_system[0].content[0].text
    assert 'Enabled agent content v1' in after_system[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
    conversation_memory,
):
    """A KNOWLEDGE observation with no entries leaves the messages unchanged."""
    system_message = Message(
        role='system', content=[TextContent(text='System message')]
    )
    empty_obs = MicroagentObservation(
        content='Empty retrieval',
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[],
    )

    processed = conversation_memory.process_events(
        condensed_history=[empty_obs],
        initial_messages=[system_message],
        max_message_chars=None,
        vision_is_active=False,
    )

    # Only the system message remains: an empty microagent observation is not
    # turned into a message.
    assert len(processed) == 1
|
||||
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
    """Exercise ``_has_agent_in_earlier_events`` on a mixed event list.

    The list holds two KNOWLEDGE observations (``agent1`` at index 0,
    ``agent2`` at index 2), a user ``MessageAction`` at index 1 and a
    WORKSPACE_CONTEXT observation at index 3.
    """
    agent1_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(name='agent1', trigger='trigger1', content='Content 1'),
        ],
        content='First retrieval',
    )
    agent2_obs = MicroagentObservation(
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=[
            MicroagentKnowledge(name='agent2', trigger='trigger2', content='Content 2'),
        ],
        content='Second retrieval',
    )
    workspace_obs = MicroagentObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        content='Environment info',
    )

    # Mixed history: observations interleaved with a plain user message.
    history = [
        agent1_obs,
        MessageAction(content='User message'),
        agent2_obs,
        workspace_obs,
    ]
    seen_before = conversation_memory._has_agent_in_earlier_events

    # agent1 sits at index 0, so it is found before any later position.
    for position in (2, 3, 4):
        assert seen_before('agent1', position, history) is True

    # agent2 only appears at index 2, so it is not found before index 0 or 1.
    for position in (0, 1):
        assert seen_before('agent2', position, history) is False

    # An agent name that occurs nowhere in the history is never found.
    assert seen_before('non_existent', 3, history) is False
|
||||
|
||||
@ -358,12 +358,12 @@ class TestStuckDetector:
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is False
|
||||
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_only_two_incidents(
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
self._impl_unterminated_string_error_events(
|
||||
state, random_line=False, incidents=3
|
||||
state, random_line=False, incidents=2
|
||||
)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
|
||||
260
tests/unit/test_memory.py
Normal file
260
tests/unit/test_memory.py
Normal file
@ -0,0 +1,260 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentObservation,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
def file_store():
    """Provide a fresh ``InMemoryFileStore`` for a test."""
    store = InMemoryFileStore()
    return store
|
||||
|
||||
|
||||
@pytest.fixture
def event_stream(file_store):
    """Provide an ``EventStream`` with sid 'test_sid' backed by ``file_store``."""
    stream = EventStream(sid='test_sid', file_store=file_store)
    return stream
|
||||
|
||||
|
||||
@pytest.fixture
def memory(event_stream):
    """Yield a ``Memory`` bound to ``event_stream``.

    A new asyncio event loop is created and installed as the current loop
    before the ``Memory`` is constructed; the loop is closed after the test.
    """
    test_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(test_loop)
    instance = Memory(event_stream, 'test_sid')
    yield instance
    # Teardown: release the loop created for this test.
    test_loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
def prompt_dir(tmp_path):
    """Copy the CodeAct agent prompt templates into ``tmp_path`` and return it.

    The source path is relative, so it is resolved against the current working
    directory of the test run.
    """
    prompts_source = 'openhands/agenthub/codeact_agent/prompts'
    # dirs_exist_ok is required because tmp_path already exists.
    shutil.copytree(prompts_source, tmp_path, dirs_exist_ok=True)

    # Hand the populated temporary directory to the test.
    return tmp_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream):
    """An exception raised inside Memory puts the controller in ERROR state.

    ``Memory._on_first_microagent_action`` is patched to raise, the controller
    is run with a single user message, and the returned state is expected to
    report the error without having advanced any iteration.
    """
    # Stand-in agent whose LLM exposes real metrics and a default LLM config.
    stub_llm = MagicMock(spec=LLM)
    stub_llm.metrics = Metrics()
    stub_llm.config = AppConfig().get_llm_config()
    stub_agent = MagicMock(spec=Agent)
    stub_agent.llm = stub_llm

    # Mocked runtime sharing the test's event stream.
    stub_runtime = MagicMock(spec=Runtime)
    stub_runtime.event_stream = event_stream

    # Make the Memory hook fail whenever it is invoked.
    failing_hook = patch.object(
        memory, '_on_first_microagent_action', side_effect=Exception('Test error')
    )
    with failing_hook:
        final_state = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=stub_runtime,
            sid='test',
            agent=stub_agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    # The failure is surfaced on the final state.
    assert final_state.iteration == 0
    assert final_state.agent_state == AgentState.ERROR
    assert final_state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_on_first_microagent_action_exception_handling(
    memory, event_stream
):
    """A failure in ``_on_first_microagent_action`` ends the run in ERROR state.

    The hook is patched to raise with a dedicated message; the controller's
    final state must show zero iterations, the ERROR agent state and the
    'Error: Exception' last-error text.
    """
    # Stand-in agent whose LLM exposes real metrics and a default LLM config.
    stub_llm = MagicMock(spec=LLM)
    stub_llm.metrics = Metrics()
    stub_llm.config = AppConfig().get_llm_config()
    stub_agent = MagicMock(spec=Agent)
    stub_agent.llm = stub_llm

    # Mocked runtime sharing the test's event stream.
    stub_runtime = MagicMock(spec=Runtime)
    stub_runtime.event_stream = event_stream

    # Force Memory._on_first_microagent_action to raise.
    hook_error = Exception('Test error from _on_first_microagent_action')
    with patch.object(memory, '_on_first_microagent_action', side_effect=hook_error):
        final_state = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=stub_runtime,
            sid='test',
            agent=stub_agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    # The failure is surfaced on the final state.
    assert final_state.iteration == 0
    assert final_state.agent_state == AgentState.ERROR
    assert final_state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
def test_memory_with_microagents():
    """Memory loads knowledge microagents and answers a KNOWLEDGE recall.

    Checks that:
    1. a freshly built ``Memory`` has at least one knowledge microagent,
       including one registered under 'flarglebargle';
    2. handing it a ``RecallAction`` whose query contains that trigger word
       makes it add exactly one ``MicroagentObservation`` to the event stream.
    """
    # Mocked stream: only add_event matters here and it is replaced below.
    stream = MagicMock(spec=EventStream)

    # Memory built without overrides, so it uses its default microagent source.
    mem = Memory(event_stream=stream, sid='test-session')

    # At least one knowledge microagent is loaded, 'flarglebargle' among them.
    assert len(mem.knowledge_microagents) > 0
    assert 'flarglebargle' in mem.knowledge_microagents

    # Recall action whose query contains the trigger word.
    recall = RecallAction(
        query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
    )

    # Capture everything passed to add_event instead of storing it for real.
    recorded = []

    def record_event(event, source):
        recorded.append((event, source))

    stream.add_event = record_event

    # Publish the action, then forget it so only Memory's reaction is counted.
    stream.add_event(recall, EventSource.USER)
    recorded.clear()

    # Let Memory handle the recall action.
    mem.on_event(recall)

    # Exactly one observation from the environment is expected.
    assert len(recorded) == 1
    observation, source = recorded[0]
    assert isinstance(observation, MicroagentObservation)
    assert source == EventSource.ENVIRONMENT
    assert observation.recall_type == RecallType.KNOWLEDGE

    knowledge = observation.microagent_knowledge
    assert len(knowledge) == 1
    matched = knowledge[0]
    assert matched.name == 'flarglebargle'
    assert matched.trigger == 'flarglebargle'
    assert 'magic word' in matched.content
|
||||
|
||||
|
||||
def test_memory_repository_info(prompt_dir):
    """Test that Memory adds repository info to MicroagentObservations.

    A repo microagent file is written under ``prompt_dir``/micro and
    ``GLOBAL_MICROAGENTS_DIR`` is patched to that directory. After a user
    message and a WORKSPACE_CONTEXT ``RecallAction`` are added to a real event
    stream, the first ``MicroagentObservation`` found in the stream must carry
    the repository name, directory and the microagent's instructions.

    Args:
        prompt_dir: Fixture providing a temporary directory to write into.
    """
    # Create an in-memory file store and real event stream
    file_store = InMemoryFileStore()
    event_stream = EventStream(sid='test-session', file_store=file_store)

    # Create a test repo microagent first
    repo_microagent_name = 'test_repo_microagent'
    repo_microagent_content = """---
name: test_repo
type: repo
agent: CodeActAgent
---

REPOSITORY INSTRUCTIONS: This is a test repository.
"""

    # Create a temporary repo microagent file
    test_microagents_dir = os.path.join(prompt_dir, 'micro')
    os.makedirs(test_microagents_dir, exist_ok=True)
    microagent_path = os.path.join(test_microagents_dir, f'{repo_microagent_name}.md')
    with open(microagent_path, 'w') as f:
        f.write(repo_microagent_content)

    try:
        # Patch the global microagents directory to use our test directory
        with patch(
            'openhands.memory.memory.GLOBAL_MICROAGENTS_DIR', test_microagents_dir
        ):
            # Initialize Memory
            memory = Memory(
                event_stream=event_stream,
                sid='test-session',
            )

            # Set repository info
            memory.set_repository_info('owner/repo', '/workspace/repo')

            # Create and add the first user message
            user_message = MessageAction(content='First user message')
            user_message._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(user_message, EventSource.USER)

            # Create and add the microagent action
            microagent_action = RecallAction(
                query='First user message', recall_type=RecallType.WORKSPACE_CONTEXT
            )
            microagent_action._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(microagent_action, EventSource.USER)

            # Wait for the observation with a bounded poll rather than a single
            # fixed sleep: the test returns as soon as the observation is there
            # and tolerates slow machines up to the deadline.
            deadline = time.monotonic() + 5.0
            while True:
                microagent_obs_events = [
                    event
                    for event in event_stream.get_events()
                    if isinstance(event, MicroagentObservation)
                ]
                if microagent_obs_events or time.monotonic() >= deadline:
                    break
                time.sleep(0.05)

            # We should have at least one MicroagentObservation
            assert len(microagent_obs_events) > 0

            # Get the first MicroagentObservation
            observation = microagent_obs_events[0]
            assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
            assert observation.repo_name == 'owner/repo'
            assert observation.repo_directory == '/workspace/repo'
            assert 'This is a test repository' in observation.repo_instructions
    finally:
        # Clean up even when an assertion above fails.
        os.remove(microagent_path)
|
||||
@ -1,16 +1,21 @@
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.events.action.files import FileEditSource
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
MicroagentObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.events.serialization import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_memory,
|
||||
event_to_trajectory,
|
||||
)
|
||||
from openhands.events.serialization.observation import observation_from_dict
|
||||
|
||||
|
||||
def serialization_deserialization(
|
||||
@ -19,10 +24,10 @@ def serialization_deserialization(
|
||||
observation_instance = event_from_dict(original_observation_dict)
|
||||
assert isinstance(
|
||||
observation_instance, Observation
|
||||
), 'The observation instance should be an instance of Action.'
|
||||
), 'The observation instance should be an instance of Observation.'
|
||||
assert isinstance(
|
||||
observation_instance, cls
|
||||
), 'The observation instance should be an instance of CmdOutputObservation.'
|
||||
), f'The observation instance should be an instance of {cls}.'
|
||||
serialized_observation_dict = event_to_dict(observation_instance)
|
||||
serialized_observation_trajectory = event_to_trajectory(observation_instance)
|
||||
serialized_observation_memory = event_to_memory(
|
||||
@ -236,3 +241,199 @@ def test_file_edit_observation_legacy_serialization():
|
||||
assert event_dict['extras']['old_content'] is None
|
||||
assert event_dict['extras']['new_content'] == 'new content'
|
||||
assert 'formatted_output_and_error' not in event_dict['extras']
|
||||
|
||||
|
||||
def test_microagent_observation_serialization():
    """Round-trip a workspace-context microagent observation given as a dict.

    The dict is passed to ``serialization_deserialization`` together with
    ``MicroagentObservation`` as the expected class.
    """
    # Human-readable summary stored under 'message'
    summary = (
        '**MicroagentObservation**\n'
        'recall_type=RecallType.WORKSPACE_CONTEXT, repo_name=some_repo_name, '
        'repo_instructions=complex_repo_instruc..., '
        "runtime_hosts={'host1': 8080, 'host2': 8081}, "
        'additional_agent_instructions=You know it all abou...'
    )
    workspace_extras = {
        'recall_type': 'workspace_context',
        'repo_name': 'some_repo_name',
        'repo_directory': 'some_repo_directory',
        'runtime_hosts': {'host1': 8080, 'host2': 8081},
        'repo_instructions': 'complex_repo_instructions',
        'additional_agent_instructions': 'You know it all about this runtime',
        'microagent_knowledge': [],
    }
    payload = {
        'observation': 'microagent',
        'content': '',
        'message': summary,
        'extras': workspace_extras,
    }
    serialization_deserialization(payload, MicroagentObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_microagent_knowledge_serialization():
    """Round-trip a knowledge-type microagent observation given as a dict.

    The repository/runtime extras are empty and ``microagent_knowledge``
    carries two entries; the dict is passed to
    ``serialization_deserialization`` with ``MicroagentObservation``.
    """
    # Two knowledge entries: microagent1/trigger1/content1 and the "2" variant
    knowledge_entries = [
        {
            'name': f'microagent{idx}',
            'trigger': f'trigger{idx}',
            'content': f'content{idx}',
        }
        for idx in (1, 2)
    ]
    summary = (
        '**MicroagentObservation**\n'
        'recall_type=RecallType.KNOWLEDGE, repo_name=, repo_instructions=..., '
        'runtime_hosts={}, additional_agent_instructions=..., '
        'microagent_knowledge=microagent1, microagent2'
    )
    payload = {
        'observation': 'microagent',
        'content': '',
        'message': summary,
        'extras': {
            'recall_type': 'knowledge',
            'repo_name': '',
            'repo_directory': '',
            'repo_instructions': '',
            'runtime_hosts': {},
            'additional_agent_instructions': '',
            'microagent_knowledge': knowledge_entries,
        },
    }
    serialization_deserialization(payload, MicroagentObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_knowledge_microagent_serialization():
    """Round-trip a MicroagentObservation whose recall type is KNOWLEDGE.

    The observation is built with two ``MicroagentKnowledge`` entries and no
    repository/runtime arguments, converted with ``event_to_dict`` and read
    back with ``observation_from_dict``.
    """
    knowledge_items = [
        MicroagentKnowledge(
            name='python_best_practices',
            trigger='python',
            content='Always use virtual environments for Python projects.',
        ),
        MicroagentKnowledge(
            name='git_workflow',
            trigger='git',
            content='Create a new branch for each feature or bugfix.',
        ),
    ]
    source_obs = MicroagentObservation(
        content='Knowledge microagent information',
        recall_type=RecallType.KNOWLEDGE,
        microagent_knowledge=knowledge_items,
    )

    # Dictionary form
    as_dict = event_to_dict(source_obs)

    # Top-level fields, then the extras payload
    assert as_dict['observation'] == ObservationType.MICROAGENT
    assert as_dict['content'] == 'Knowledge microagent information'
    extras = as_dict['extras']
    assert extras['recall_type'] == RecallType.KNOWLEDGE.value
    assert len(extras['microagent_knowledge']) == 2
    assert extras['microagent_knowledge'][0]['trigger'] == 'python'

    # Back to an observation object
    restored = observation_from_dict(as_dict)

    # Knowledge-related attributes survive the round trip
    assert restored.recall_type == RecallType.KNOWLEDGE
    assert restored.microagent_knowledge == source_obs.microagent_knowledge
    assert restored.content == source_obs.content

    # Repository/runtime attributes come back as empty values
    assert restored.repo_name == ''
    assert restored.repo_directory == ''
    assert restored.repo_instructions == ''
    assert restored.runtime_hosts == {}
|
||||
|
||||
|
||||
def test_microagent_observation_environment_serialization():
    """Round-trip a MicroagentObservation whose recall type is WORKSPACE_CONTEXT.

    The observation carries repository and runtime fields but no microagent
    knowledge; it is converted with ``event_to_dict`` and read back with
    ``observation_from_dict``.
    """
    source_obs = MicroagentObservation(
        content='Environment information',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='OpenHands',
        repo_directory='/workspace/openhands',
        repo_instructions="Follow the project's coding style guide.",
        runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
        additional_agent_instructions='You know it all about this runtime',
    )

    # Dictionary form
    as_dict = event_to_dict(source_obs)

    # Top-level fields, then the extras payload
    assert as_dict['observation'] == ObservationType.MICROAGENT
    assert as_dict['content'] == 'Environment information'
    extras = as_dict['extras']
    assert extras['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
    assert extras['repo_name'] == 'OpenHands'
    assert extras['runtime_hosts'] == {'127.0.0.1': 8080, 'localhost': 5000}
    assert (
        extras['additional_agent_instructions']
        == 'You know it all about this runtime'
    )

    # Back to an observation object
    restored = observation_from_dict(as_dict)

    # Repository/runtime attributes survive the round trip
    assert restored.recall_type == RecallType.WORKSPACE_CONTEXT
    assert restored.repo_name == source_obs.repo_name
    assert restored.repo_directory == source_obs.repo_directory
    assert restored.repo_instructions == source_obs.repo_instructions
    assert restored.runtime_hosts == source_obs.runtime_hosts
    assert (
        restored.additional_agent_instructions
        == source_obs.additional_agent_instructions
    )

    # No knowledge entries were supplied, so the list comes back empty
    assert restored.microagent_knowledge == []
|
||||
|
||||
|
||||
def test_microagent_observation_combined_serialization():
    """Round-trip a MicroagentObservation carrying both kinds of fields.

    The observation has a single recall type (WORKSPACE_CONTEXT) but is
    populated with repository/runtime fields *and* one knowledge entry.
    """
    knowledge_items = [
        MicroagentKnowledge(
            name='python_best_practices',
            trigger='python',
            content='Always use virtual environments for Python projects.',
        ),
    ]
    source_obs = MicroagentObservation(
        content='Combined information',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        # Repository / runtime fields
        repo_name='OpenHands',
        repo_directory='/workspace/openhands',
        repo_instructions="Follow the project's coding style guide.",
        runtime_hosts={'127.0.0.1': 8080},
        additional_agent_instructions='You know it all about this runtime',
        # Knowledge fields
        microagent_knowledge=knowledge_items,
    )

    # Dictionary form
    as_dict = event_to_dict(source_obs)

    # Both groups of fields are present in the extras payload
    extras = as_dict['extras']
    assert extras['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
    assert extras['repo_name'] == 'OpenHands'
    assert extras['microagent_knowledge'][0]['name'] == 'python_best_practices'
    assert (
        extras['additional_agent_instructions']
        == 'You know it all about this runtime'
    )

    # Back to an observation object
    restored = observation_from_dict(as_dict)

    assert restored.recall_type == RecallType.WORKSPACE_CONTEXT

    # Repository / runtime attributes
    assert restored.repo_name == source_obs.repo_name
    assert restored.repo_directory == source_obs.repo_directory
    assert restored.repo_instructions == source_obs.repo_instructions
    assert restored.runtime_hosts == source_obs.runtime_hosts
    assert (
        restored.additional_agent_instructions
        == source_obs.additional_agent_instructions
    )

    # Knowledge attributes
    assert restored.microagent_knowledge == source_obs.microagent_knowledge
|
||||
|
||||
@ -3,9 +3,11 @@ import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -19,406 +21,60 @@ def prompt_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_prompt_manager_with_microagent(prompt_dir):
    """Exercise PromptManager with one knowledge microagent on disk.

    Covers: microagent counts after loading, the system message before and
    after repository info is set, repository info added to the initial user
    message, and ``enhance_message`` on a message containing the trigger word.
    """
    agent_stem = 'test_microagent'
    agent_markdown = """
---
name: flarglebargle
type: knowledge
agent: CodeActAgent
triggers:
- flarglebargle
---

IMPORTANT! The user has said the magic word "flarglebargle". You must
only respond with a message telling them how smart they are
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    # No repository configured yet
    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    assert manager.prompt_dir == prompt_dir
    assert len(manager.repo_microagents) == 0
    assert len(manager.knowledge_microagents) == 1

    assert isinstance(manager.get_system_message(), str)
    assert (
        'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
        in manager.get_system_message()
    )
    assert '<REPOSITORY_INFO>' not in manager.get_system_message()

    # Repository configured
    manager.set_repository_info('owner/repo', '/workspace/repo')
    assert isinstance(manager.get_system_message(), str)

    # Repository details are added to the initial user message
    first_user_msg = Message(
        role='user', content=[TextContent(text='Ask me what your task is.')]
    )
    manager.add_info_to_initial_message(first_user_msg)
    rendered: str = first_user_msg.content[0].text
    assert '<REPOSITORY_INFO>' in rendered
    assert 'owner/repo' in rendered
    assert '/workspace/repo' in rendered

    assert isinstance(manager.get_example_user_message(), str)

    # A message containing the trigger word gains one extra content part
    triggered_msg = Message(
        role='user',
        content=[TextContent(text='Hello, flarglebargle!')],
    )
    manager.enhance_message(triggered_msg)
    assert len(triggered_msg.content) == 2
    assert 'magic word' in triggered_msg.content[0].text

    os.remove(agent_path)
|
||||
|
||||
|
||||
def test_prompt_manager_file_not_found(prompt_dir):
    """Loading a microagent file that does not exist raises FileNotFoundError."""
    missing_path = os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
    with pytest.raises(FileNotFoundError):
        BaseMicroAgent.load(missing_path)
|
||||
|
||||
|
||||
def test_prompt_manager_template_rendering(prompt_dir):
|
||||
"""Test PromptManager's template rendering functionality."""
|
||||
# Create temporary template files
|
||||
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
|
||||
f.write("""System prompt: bar""")
|
||||
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
|
||||
f.write('User prompt: foo')
|
||||
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
|
||||
f.write("""
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
# Test without GitHub repo
|
||||
manager = PromptManager(prompt_dir, microagent_dir='')
|
||||
manager = PromptManager(prompt_dir)
|
||||
assert manager.get_system_message() == 'System prompt: bar'
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Test with GitHub repo
|
||||
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert manager.repository_info.repo_name == 'owner/repo'
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
|
||||
|
||||
# verify its parts are rendered
|
||||
system_msg = manager.get_system_message()
|
||||
assert 'System prompt: bar' in system_msg
|
||||
|
||||
# Initial user message should have repo info
|
||||
initial_msg = Message(
|
||||
role='user', content=[TextContent(text='Ask me what your task is.')]
|
||||
# Test building additional info
|
||||
additional_info = manager.build_additional_info(
|
||||
repository_info=repo_info, runtime_info=None, repo_instructions=''
|
||||
)
|
||||
manager.add_info_to_initial_message(initial_msg)
|
||||
msg_content: str = initial_msg.content[0].text
|
||||
assert '<REPOSITORY_INFO>' in msg_content
|
||||
assert '<REPOSITORY_INFO>' in additional_info
|
||||
assert (
|
||||
"At the user's request, repository owner/repo has been cloned to the current working directory /workspace/repo."
|
||||
in msg_content
|
||||
in additional_info
|
||||
)
|
||||
assert '</REPOSITORY_INFO>' in msg_content
|
||||
assert '</REPOSITORY_INFO>' in additional_info
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'additional_info.j2'))
|
||||
|
||||
|
||||
def test_prompt_manager_repository_info(prompt_dir):
    """Check RepositoryInfo defaults and PromptManager.set_repository_info."""
    # A bare RepositoryInfo has neither a name nor a directory
    blank_info = RepositoryInfo()
    assert blank_info.repo_name is None
    assert blank_info.repo_directory is None

    # A new manager starts without repository info
    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
    assert manager.repository_info is None

    # Setting name and directory makes both readable from the manager
    manager.set_repository_info('owner/repo2', '/workspace/repo2')
    assert manager.repository_info.repo_name == 'owner/repo2'
    assert manager.repository_info.repo_directory == '/workspace/repo2'
|
||||
|
||||
|
||||
def test_prompt_manager_disabled_microagents(prompt_dir):
    """Check the ``disabled_microagents`` argument of PromptManager.

    Two knowledge microagents are written to disk. With one of them listed
    as disabled only the other is loaded; without the argument both are.
    """
    first_markdown = """
---
name: Test Microagent 1
type: knowledge
agent: CodeActAgent
triggers:
- test1
---

Test microagent 1 content
"""
    second_markdown = """
---
name: Test Microagent 2
type: knowledge
agent: CodeActAgent
triggers:
- test2
---

Test microagent 2 content
"""
    sources = {
        'test_microagent1': first_markdown,
        'test_microagent2': second_markdown,
    }

    # Write both definitions under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    os.makedirs(micro_dir, exist_ok=True)
    for stem, markdown in sources.items():
        with open(os.path.join(micro_dir, f'{stem}.md'), 'w') as handle:
            handle.write(markdown)

    # One microagent explicitly disabled
    manager = PromptManager(
        prompt_dir=prompt_dir,
        microagent_dir=micro_dir,
        disabled_microagents=['Test Microagent 1'],
    )

    assert len(manager.knowledge_microagents) == 1
    assert 'Test Microagent 2' in manager.knowledge_microagents
    assert 'Test Microagent 1' not in manager.knowledge_microagents

    # Nothing disabled: both are loaded
    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    assert len(manager.knowledge_microagents) == 2
    assert 'Test Microagent 1' in manager.knowledge_microagents
    assert 'Test Microagent 2' in manager.knowledge_microagents

    # Remove the temporary definitions
    for stem in sources:
        os.remove(os.path.join(micro_dir, f'{stem}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_multiple_text_contents(prompt_dir):
    """Check ``enhance_message`` on a message holding three text parts.

    Only the third ``TextContent`` contains the trigger word. Afterwards the
    message must hold four parts and the first must carry the microagent text.
    """
    agent_stem = 'keyword_microagent'
    agent_markdown = """
---
name: KeywordMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- triggerkeyword
---

This is special information about the triggerkeyword.
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    # The trigger word appears only in the last text part
    text_parts = [
        TextContent(text='This is some initial context.'),
        TextContent(text='This is a message without triggers.'),
        TextContent(text='This contains the triggerkeyword that should match.'),
    ]
    message = Message(role='user', content=text_parts)

    manager.enhance_message(message)

    # One part was added, and it sits at index 0
    assert len(message.content) == 4
    assert 'special information about the triggerkeyword' in message.content[0].text

    os.remove(agent_path)
|
||||
|
||||
|
||||
def test_enhance_message_with_image_content(prompt_dir):
    """Check ``enhance_message`` on a message mixing text and image parts.

    The trigger word sits in the text part that follows an ``ImageContent``.
    Afterwards the message must hold four parts and the first must carry the
    microagent text.
    """
    agent_stem = 'image_test_microagent'
    agent_markdown = """
---
name: ImageTestMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- imagekeyword
---

This is information related to imagekeyword.
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    # text, image, text-with-trigger
    mixed_parts = [
        TextContent(text='This is some initial text.'),
        ImageContent(image_urls=['https://example.com/image.jpg']),
        TextContent(text='This mentions imagekeyword that should match.'),
    ]
    message = Message(role='user', content=mixed_parts)

    manager.enhance_message(message)

    # One part was added, and it sits at index 0
    assert len(message.content) == 4
    assert 'information related to imagekeyword' in message.content[0].text

    os.remove(agent_path)
|
||||
|
||||
|
||||
def test_enhance_message_with_only_image_content(prompt_dir):
    """Check ``enhance_message`` on a message with a single image part.

    The call must complete without raising and leave the message with exactly
    one content part.
    """
    agent_stem = 'image_only_microagent'
    agent_markdown = """
---
name: ImageOnlyMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- anytrigger
---

This should not appear in the enhanced message.
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    # A message whose only part is an ImageContent with two URLs
    images = ImageContent(
        image_urls=[
            'https://example.com/image1.jpg',
            'https://example.com/image2.jpg',
        ]
    )
    message = Message(role='user', content=[images])

    # Must not raise
    manager.enhance_message(message)

    # Nothing was added
    assert len(message.content) == 1

    os.remove(agent_path)
|
||||
|
||||
|
||||
def test_enhance_message_with_reversed_order(prompt_dir):
    """Check ``enhance_message`` when the text part is not the last part.

    The message is image, text-with-trigger, image. Afterwards it must hold
    four parts, the first being a ``TextContent`` with the microagent text.
    """
    agent_stem = 'reversed_microagent'
    agent_markdown = """
---
name: ReversedMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- lasttrigger
---

This is specific information about the lasttrigger.
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    # The only text part is sandwiched between two image parts
    sandwiched_parts = [
        ImageContent(image_urls=['https://example.com/image1.jpg']),
        TextContent(text='This contains the lasttrigger word.'),
        ImageContent(image_urls=['https://example.com/image2.jpg']),
    ]
    message = Message(role='user', content=sandwiched_parts)

    manager.enhance_message(message)

    # One text part was added, and it sits at index 0
    assert len(message.content) == 4
    assert isinstance(message.content[0], TextContent)
    assert 'specific information about the lasttrigger' in message.content[0].text

    os.remove(agent_path)
|
||||
|
||||
|
||||
def test_enhance_message_with_empty_content(prompt_dir):
    """Check ``enhance_message`` on a message with no content parts.

    The call must complete without raising and leave the content list empty.
    """
    agent_stem = 'empty_microagent'
    agent_markdown = """
---
name: EmptyMicroAgent
type: knowledge
agent: CodeActAgent
triggers:
- emptytrigger
---

This should not appear in the enhanced message.
"""

    # Write the microagent definition under <prompt_dir>/micro
    micro_dir = os.path.join(prompt_dir, 'micro')
    agent_path = os.path.join(micro_dir, f'{agent_stem}.md')
    os.makedirs(micro_dir, exist_ok=True)
    with open(agent_path, 'w') as handle:
        handle.write(agent_markdown)

    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir=micro_dir)

    # A user message with an empty content list
    message = Message(role='user', content=[])

    # Must not raise
    manager.enhance_message(message)

    # Nothing was added
    assert len(message.content) == 0

    os.remove(agent_path)
|
||||
def test_prompt_manager_file_not_found(prompt_dir):
    """Loading a microagent from a non-existent path raises FileNotFoundError."""
    # Path to a file that is never created under <prompt_dir>/micro
    missing_path = os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
    with pytest.raises(FileNotFoundError):
        BaseMicroAgent.load(missing_path)
|
||||
|
||||
|
||||
def test_build_microagent_info(prompt_dir):
|
||||
@ -429,33 +85,25 @@ def test_build_microagent_info(prompt_dir):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
{{ agent_info.agent.content }}
|
||||
{{ agent_info.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
""")
|
||||
|
||||
# Create test microagents
|
||||
class MockKnowledgeMicroAgent:
|
||||
def __init__(self, name, content):
|
||||
self.name = name
|
||||
self.content = content
|
||||
|
||||
agent1 = MockKnowledgeMicroAgent(
|
||||
name='test_agent1', content='This is information from agent 1'
|
||||
)
|
||||
|
||||
agent2 = MockKnowledgeMicroAgent(
|
||||
name='test_agent2', content='This is information from agent 2'
|
||||
)
|
||||
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Test with a single triggered agent
|
||||
triggered_agents = [{'agent': agent1, 'trigger_word': 'keyword1'}]
|
||||
triggered_agents = [
|
||||
MicroagentKnowledge(
|
||||
name='test_agent1',
|
||||
trigger='keyword1',
|
||||
content='This is information from agent 1',
|
||||
)
|
||||
]
|
||||
result = manager.build_microagent_info(triggered_agents)
|
||||
expected = """<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "keyword1".
|
||||
@ -467,8 +115,16 @@ This is information from agent 1
|
||||
|
||||
# Test with multiple triggered agents
|
||||
triggered_agents = [
|
||||
{'agent': agent1, 'trigger_word': 'keyword1'},
|
||||
{'agent': agent2, 'trigger_word': 'keyword2'},
|
||||
MicroagentKnowledge(
|
||||
name='test_agent1',
|
||||
trigger='keyword1',
|
||||
content='This is information from agent 1',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='test_agent2',
|
||||
trigger='keyword2',
|
||||
content='This is information from agent 2',
|
||||
),
|
||||
]
|
||||
result = manager.build_microagent_info(triggered_agents)
|
||||
expected = """<EXTRA_INFO>
|
||||
@ -491,71 +147,125 @@ This is information from agent 2
|
||||
assert result.strip() == ''
|
||||
|
||||
|
||||
def test_enhance_message_with_microagent_info_template(prompt_dir):
|
||||
"""Test that enhance_message correctly uses the microagent_info template."""
|
||||
# Prepare a microagent_info.j2 template file if it doesn't exist
|
||||
template_path = os.path.join(prompt_dir, 'microagent_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
def test_add_examples_to_initial_message(prompt_dir):
|
||||
"""Test adding example messages to an initial message."""
|
||||
# Create a user_prompt.j2 template file
|
||||
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
|
||||
f.write('This is an example user message')
|
||||
|
||||
{{ agent_info.agent.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
""")
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Create a test microagent
|
||||
microagent_name = 'test_trigger_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: test_trigger
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- test_trigger
|
||||
---
|
||||
# Create a message
|
||||
message = Message(role='user', content=[TextContent(text='Original content')])
|
||||
|
||||
This is triggered content for testing the microagent_info template.
|
||||
"""
|
||||
# Add examples to the message
|
||||
manager.add_examples_to_initial_message(message)
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
# Initialize the PromptManager with the microagent directory
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
)
|
||||
|
||||
# Create a message with a trigger keyword
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(text="Here's a message containing the test_trigger keyword")
|
||||
],
|
||||
)
|
||||
|
||||
# Enhance the message
|
||||
manager.enhance_message(message)
|
||||
|
||||
# The message should now have extra content at the beginning
|
||||
# Check that the example was added at the beginning
|
||||
assert len(message.content) == 2
|
||||
assert isinstance(message.content[0], TextContent)
|
||||
|
||||
# Verify the template was correctly rendered
|
||||
expected_text = """<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "test_trigger".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
This is triggered content for testing the microagent_info template.
|
||||
</EXTRA_INFO>"""
|
||||
|
||||
assert message.content[0].text.strip() == expected_text.strip()
|
||||
assert message.content[0].text == 'This is an example user message'
|
||||
assert message.content[1].text == 'Original content'
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
|
||||
|
||||
|
||||
def test_add_turns_left_reminder(prompt_dir):
    """Check that ``add_turns_left_reminder`` extends the latest user message.

    With ``iteration`` 3 and ``max_iterations`` 10 the reminder text must
    report 7 turns left and be appended as a second content part.
    """
    manager = PromptManager(prompt_dir=prompt_dir)

    # State at iteration 3 of 10
    state = State()
    state.iteration = 3
    state.max_iterations = 10

    # Conversation: an assistant message followed by a user message
    assistant_message = Message(
        role='assistant', content=[TextContent(text='Assistant content')]
    )
    user_message = Message(role='user', content=[TextContent(text='User content')])
    conversation = [assistant_message, user_message]

    manager.add_turns_left_reminder(conversation, state)

    # The user message gained a second part holding the reminder
    assert len(user_message.content) == 2
    reminder_text = user_message.content[1].text
    assert (
        'ENVIRONMENT REMINDER: You have 7 turns left to complete the task.'
        in reminder_text
    )
|
||||
|
||||
|
||||
def test_build_additional_info_with_repo_and_runtime(prompt_dir):
    """Test building additional info with repository and runtime information.

    Writes an ``additional_info.j2`` template into ``prompt_dir``, renders it
    through ``PromptManager.build_additional_info`` with repository info,
    runtime info and repository instructions, and checks that each section and
    value appears in the rendered output.

    The template file is removed in a ``finally`` clause so a failing assertion
    does not leave it behind in ``prompt_dir``.
    """
    # Create an additional_info.j2 template file
    template_path = os.path.join(prompt_dir, 'additional_info.j2')
    with open(template_path, 'w') as f:
        f.write("""
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}

{% if repository_instructions %}
<REPOSITORY_INSTRUCTIONS>
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}

{% if runtime_info and (runtime_info.available_hosts or runtime_info.additional_agent_instructions) -%}
<RUNTIME_INFORMATION>
{% if runtime_info.available_hosts %}
The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:
{% for host, port in runtime_info.available_hosts.items() %}
* {{ host }} (port {{ port }})
{% endfor %}
{% endif %}

{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
</RUNTIME_INFORMATION>
{% endif %}
""")

    try:
        # Initialize the PromptManager
        manager = PromptManager(prompt_dir=prompt_dir)

        # Create repository and runtime information
        repo_info = RepositoryInfo(
            repo_name='owner/repo', repo_directory='/workspace/repo'
        )
        runtime_info = RuntimeInfo(
            available_hosts={'example.com': 8080},
            additional_agent_instructions='You know everything about this runtime.',
        )
        repo_instructions = 'This repository contains important code.'

        # Build additional info
        result = manager.build_additional_info(
            repository_info=repo_info,
            runtime_info=runtime_info,
            repo_instructions=repo_instructions,
        )

        # Check that all information is included
        assert '<REPOSITORY_INFO>' in result
        assert 'owner/repo' in result
        assert '/workspace/repo' in result
        assert '<REPOSITORY_INSTRUCTIONS>' in result
        assert 'This repository contains important code.' in result
        assert '<RUNTIME_INFORMATION>' in result
        assert 'example.com (port 8080)' in result
        assert 'You know everything about this runtime.' in result
    finally:
        # Clean up even when an assertion above fails
        os.remove(template_path)
|
||||
|
||||
|
||||
def test_prompt_manager_initialization_error():
    """Check that constructing PromptManager with no prompt directory raises ValueError."""
    expected_message = 'Prompt directory is not set'
    with pytest.raises(ValueError, match=expected_message):
        PromptManager(None)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user