# Mirror of https://github.com/OpenHands/OpenHands.git (synced 2025-12-26 05:48:36 +08:00)
# Co-authored-by: openhands <openhands@all-hands.dev>
# Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
# Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
# 225 lines, 7.9 KiB, Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.controller.agent_controller import AgentController
|
|
from openhands.controller.state.state import State
|
|
from openhands.core.config import AppConfig, LLMConfig
|
|
from openhands.core.config.agent_config import AgentConfig
|
|
from openhands.events import EventStream, EventStreamSubscriber
|
|
from openhands.llm import LLM
|
|
from openhands.llm.metrics import Metrics
|
|
from openhands.memory.memory import Memory
|
|
from openhands.runtime.impl.action_execution.action_execution_client import (
|
|
ActionExecutionClient,
|
|
)
|
|
from openhands.server.session.agent_session import AgentSession
|
|
from openhands.storage.memory import InMemoryFileStore
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Create a properly configured mock agent with all required nested attributes.

    Returns:
        A ``MagicMock(spec=Agent)`` whose ``llm``, ``llm.metrics``,
        ``llm.config``, ``config``, ``name``, ``sandbox_plugins`` and
        ``prompt_manager`` attributes are explicitly populated.
    """
    # Create the base mocks. Each one is spec'd against the real class so that
    # reading an attribute the real class does not expose raises AttributeError
    # instead of silently returning a fresh mock.
    agent = MagicMock(spec=Agent)
    llm = MagicMock(spec=LLM)
    metrics = MagicMock(spec=Metrics)
    llm_config = MagicMock(spec=LLMConfig)
    agent_config = MagicMock(spec=AgentConfig)

    # Configure the LLM config
    llm_config.model = 'test-model'
    llm_config.base_url = 'http://test'
    llm_config.max_message_chars = 1000

    # Configure the agent config
    agent_config.disabled_microagents = []
    agent_config.enable_mcp = True

    # Set up the chain of mocks: agent -> llm -> (metrics, config)
    llm.metrics = metrics
    llm.config = llm_config
    agent.llm = llm
    agent.name = 'test-agent'
    agent.sandbox_plugins = []
    agent.config = agent_config
    # Unspec'd on purpose: any attribute access on the prompt manager succeeds.
    agent.prompt_manager = MagicMock()

    return agent
|
|
|
|
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's no state to restore.

    ``State.restore_from_session`` is patched to raise, and the test then
    checks that the controller's ``set_initial_state`` was called exactly once
    with ``state=None`` and that the resulting state carries the
    ``max_iterations`` passed to ``start()``.

    Args:
        mock_agent: The mocked agent provided by the ``mock_agent`` fixture.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Mock the runtime creation to set up the runtime attribute
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a spy on set_initial_state
    class SpyAgentController(AgentController):
        # Class-level defaults; the `+=` / assignment in the method below
        # rebinds them on the instance, so each controller tracks its own calls.
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            # Record the call, then defer to the real implementation.
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=AppConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify EventStream.subscribe was called with correct parameters.
        # assert_any_call is used because more than one subscriber registers.
        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )

        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.MEMORY,
            session.memory.on_event,
            session.controller.id,
        )

        # Verify set_initial_state was called once with None as state
        assert session.controller.set_initial_state_call_count == 1
        assert session.controller.test_initial_state is None
        assert session.controller.state.max_iterations == 10
        assert session.controller.agent.name == 'test-agent'
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1
|
|
|
|
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's a state to restore.

    ``State.restore_from_session`` is patched to return a mocked ``State``, and
    the test then checks that this exact object was handed to the controller's
    ``set_initial_state`` and became ``session.controller.state``, keeping the
    restored ``max_iterations`` (5) rather than the value passed to ``start()``.

    Args:
        mock_agent: The mocked agent provided by the ``mock_agent`` fixture.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Mock the runtime creation to set up the runtime attribute
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with some events
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 5  # Indicate some events exist

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a mock restored state
    mock_restored_state = MagicMock(spec=State)
    mock_restored_state.start_id = -1
    mock_restored_state.end_id = -1
    mock_restored_state.max_iterations = 5

    # Create a spy on set_initial_state by subclassing AgentController
    class SpyAgentController(AgentController):
        # Class-level defaults; the `+=` / assignment in the method below
        # rebinds them on the instance, so each controller tracks its own calls.
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            # Record the call, then defer to the real implementation.
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # create a mock Memory
    mock_memory = MagicMock(spec=Memory)

    # Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            return_value=mock_restored_state,
        ),
        patch('openhands.server.session.agent_session.Memory', mock_memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=AppConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify set_initial_state was called once with the restored state
        assert session.controller.set_initial_state_call_count == 1

        # Verify EventStream.subscribe was called with correct parameters.
        # assert_called_with only inspects the most recent subscribe() call.
        mock_event_stream.subscribe.assert_called_with(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )
        assert session.controller.test_initial_state is mock_restored_state
        assert session.controller.state is mock_restored_state
        assert session.controller.state.max_iterations == 5
        # start_id was -1 on the restored mock; it is expected to read 0 after start().
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1