OpenHands/tests/unit/test_memory.py
Engel Nyst cc45f5d9c3
Add RecallActions and observations for retrieval of prompt extensions (#6909)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
2025-03-15 21:48:37 +01:00

261 lines
8.7 KiB
Python

import asyncio
import os
import shutil
import time
from unittest.mock import MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
from openhands.events.action.agent import RecallAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.agent import (
MicroagentObservation,
RecallType,
)
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def file_store():
    """Provide a fresh, empty in-memory file store for each test."""
    store = InMemoryFileStore()
    return store
@pytest.fixture
def event_stream(file_store):
    """Provide an EventStream for session ``test_sid`` persisted to ``file_store``."""
    stream = EventStream(file_store=file_store, sid='test_sid')
    return stream
@pytest.fixture
def memory(event_stream):
    """Yield a Memory attached to ``event_stream`` under session ``test_sid``.

    A brand-new event loop is installed as the current loop before the
    Memory is constructed, and that loop is closed during fixture teardown.
    """
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    yield Memory(event_stream, 'test_sid')
    fresh_loop.close()
@pytest.fixture
def prompt_dir(tmp_path):
    """Return ``tmp_path`` populated with a copy of the CodeActAgent prompts.

    The source directory is a relative path, so it is resolved against the
    current working directory (expected to be the repository root).
    """
    source_dir = 'openhands/agenthub/codeact_agent/prompts'
    shutil.copytree(source_dir, tmp_path, dirs_exist_ok=True)
    return tmp_path
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream):
    """An exception raised while Memory handles an event puts the run in ERROR.

    The failure is injected by making ``Memory._on_first_microagent_action``
    raise. The controller is expected to stop before completing any
    iteration and to record ``'Error: Exception'`` as its last error.
    """
    # Minimal agent double: only the LLM metrics/config the controller reads.
    fake_agent = MagicMock(spec=Agent)
    fake_agent.llm = MagicMock(spec=LLM)
    fake_agent.llm.metrics = Metrics()
    fake_agent.llm.config = AppConfig().get_llm_config()

    # Runtime double that shares the test's event stream.
    fake_runtime = MagicMock(spec=Runtime)
    fake_runtime.event_stream = event_stream

    failing_hook = patch.object(
        memory, '_on_first_microagent_action', side_effect=Exception('Test error')
    )
    with failing_hook:
        final_state = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=fake_runtime,
            sid='test',
            agent=fake_agent,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    assert final_state.iteration == 0
    assert final_state.agent_state == AgentState.ERROR
    assert final_state.last_error == 'Error: Exception'
@pytest.mark.asyncio
async def test_memory_on_first_microagent_action_exception_handling(
    memory, event_stream
):
    """A failing ``Memory._on_first_microagent_action`` surfaces as a controller error.

    After the run, the state must show zero iterations, the ERROR agent
    state, and ``'Error: Exception'`` as the recorded last error.
    """
    runtime_double = MagicMock(spec=Runtime)
    runtime_double.event_stream = event_stream

    llm_double = MagicMock(spec=LLM)
    llm_double.metrics = Metrics()
    llm_double.config = AppConfig().get_llm_config()
    agent_double = MagicMock(spec=Agent)
    agent_double.llm = llm_double

    boom = Exception('Test error from _on_first_microagent_action')
    with patch.object(memory, '_on_first_microagent_action', side_effect=boom):
        result = await run_controller(
            config=AppConfig(),
            initial_user_action=MessageAction(content='Test message'),
            runtime=runtime_double,
            sid='test',
            agent=agent_double,
            fake_user_response_fn=lambda _: 'repeat',
            memory=memory,
        )

    # The controller must have recorded the memory failure as its last error.
    assert result.iteration == 0
    assert result.agent_state == AgentState.ERROR
    assert result.last_error == 'Error: Exception'
def test_memory_with_microagents():
    """Memory loads global microagents and answers a knowledge recall.

    Checks that:
    1. Microagents from GLOBAL_MICROAGENTS_DIR (which ships 'flarglebargle')
       are loaded when Memory is constructed.
    2. Handling a KNOWLEDGE RecallAction whose query contains the trigger
       word emits exactly one MicroagentObservation with source ENVIRONMENT.
    """
    event_stream = MagicMock(spec=EventStream)
    memory = Memory(event_stream=event_stream, sid='test-session')

    # The repository's global microagent directory must contribute at least
    # the 'flarglebargle' knowledge microagent.
    assert len(memory.knowledge_microagents) > 0
    assert 'flarglebargle' in memory.knowledge_microagents

    recall = RecallAction(
        query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
    )

    # Swap add_event for a recorder so emitted events can be inspected.
    captured = []

    def record_event(event, source):
        captured.append((event, source))

    event_stream.add_event = record_event

    # Simulate the action entering the stream, then discard it so only the
    # events Memory emits in response remain.
    event_stream.add_event(recall, EventSource.USER)
    captured.clear()

    memory.on_event(recall)

    assert len(captured) == 1
    observation, source = captured[0]
    assert isinstance(observation, MicroagentObservation)
    assert source == EventSource.ENVIRONMENT
    assert observation.recall_type == RecallType.KNOWLEDGE
    assert len(observation.microagent_knowledge) == 1
    knowledge = observation.microagent_knowledge[0]
    assert knowledge.name == 'flarglebargle'
    assert knowledge.trigger == 'flarglebargle'
    assert 'magic word' in knowledge.content
def test_memory_repository_info(prompt_dir):
    """Test that Memory adds repository info to MicroagentObservations.

    A repo-type microagent is written into a temporary directory that replaces
    ``GLOBAL_MICROAGENTS_DIR`` for the duration of the test. After repository
    info is set and a WORKSPACE_CONTEXT RecallAction is added to a real
    EventStream, the first MicroagentObservation in the stream must carry the
    repo name, the repo directory, and the microagent's instructions.
    """
    # Create an in-memory file store and real event stream
    file_store = InMemoryFileStore()
    event_stream = EventStream(sid='test-session', file_store=file_store)

    # Create a test repo microagent first
    repo_microagent_name = 'test_repo_microagent'
    repo_microagent_content = """---
name: test_repo
type: repo
agent: CodeActAgent
---
REPOSITORY INSTRUCTIONS: This is a test repository.
"""

    # Write the repo microagent into a dedicated directory under prompt_dir.
    test_microagents_dir = os.path.join(prompt_dir, 'micro')
    microagent_path = os.path.join(
        test_microagents_dir, f'{repo_microagent_name}.md'
    )
    os.makedirs(test_microagents_dir, exist_ok=True)
    with open(microagent_path, 'w') as f:
        f.write(repo_microagent_content)

    try:
        # Patch the global microagents directory to use our test directory
        with patch(
            'openhands.memory.memory.GLOBAL_MICROAGENTS_DIR', test_microagents_dir
        ):
            memory = Memory(
                event_stream=event_stream,
                sid='test-session',
            )

            memory.set_repository_info('owner/repo', '/workspace/repo')

            # Create and add the first user message
            user_message = MessageAction(content='First user message')
            user_message._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(user_message, EventSource.USER)

            # Create and add the microagent action
            microagent_action = RecallAction(
                query='First user message', recall_type=RecallType.WORKSPACE_CONTEXT
            )
            microagent_action._source = EventSource.USER  # type: ignore[attr-defined]
            event_stream.add_event(microagent_action, EventSource.USER)

            # The observation may be produced after add_event returns, so poll
            # with a deadline. A fixed short sleep is flaky on slow machines.
            deadline = time.monotonic() + 5.0
            while True:
                microagent_obs_events = [
                    event
                    for event in event_stream.get_events()
                    if isinstance(event, MicroagentObservation)
                ]
                if microagent_obs_events or time.monotonic() >= deadline:
                    break
                time.sleep(0.05)

            # We should have at least one MicroagentObservation
            assert len(microagent_obs_events) > 0

            observation = microagent_obs_events[0]
            assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
            assert observation.repo_name == 'owner/repo'
            assert observation.repo_directory == '/workspace/repo'
            assert 'This is a test repository' in observation.repo_instructions
    finally:
        # Remove the microagent file even when an assertion above fails.
        if os.path.exists(microagent_path):
            os.remove(microagent_path)