OpenHands/tests/unit/test_agent_session.py
Rohit Malhotra 2fd1fdcd7e
[Refactor, Fix]: Agent controller state/metrics management (#9012)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-06-16 11:24:13 -04:00

405 lines
15 KiB
Python

from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
# NOTE(review): nothing named DeprecatedState is referenced in the tests below; they use the State class imported above - confirm and drop this note if stale
@pytest.fixture
def mock_agent():
    """Return a MagicMock ``Agent`` with its nested llm/config attributes filled in.

    The returned mock exposes:

    * ``llm`` - a ``MagicMock(spec=LLM)`` carrying a ``Metrics`` mock and an
      ``LLMConfig`` mock (model ``'test-model'``, base URL ``'http://test'``,
      ``max_message_chars`` of 1000),
    * ``config`` - an ``AgentConfig`` mock with no disabled microagents and
      MCP enabled,
    * ``name``, ``sandbox_plugins`` and ``prompt_manager`` placeholders.
    """
    # LLM configuration values
    llm_settings = MagicMock(spec=LLMConfig)
    llm_settings.configure_mock(
        model='test-model',
        base_url='http://test',
        max_message_chars=1000,
    )

    # Agent configuration values
    agent_settings = MagicMock(spec=AgentConfig)
    agent_settings.configure_mock(disabled_microagents=[], enable_mcp=True)

    # LLM double wired to its metrics and config
    fake_llm = MagicMock(spec=LLM)
    fake_llm.configure_mock(
        metrics=MagicMock(spec=Metrics),
        config=llm_settings,
    )

    # Finally the agent itself; ``name`` must be set as an attribute (not as a
    # MagicMock constructor argument, where it would name the mock instead)
    fake_agent = MagicMock(spec=Agent)
    fake_agent.configure_mock(
        llm=fake_llm,
        name='test-agent',
        sandbox_plugins=[],
        config=agent_settings,
        prompt_manager=MagicMock(),
    )
    return fake_agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """Check AgentSession.start() when no saved state can be restored.

    ``State.restore_from_session`` is patched to raise, so the controller's
    ``set_initial_state`` is expected to be called exactly once with
    ``state=None``, and the resulting state should carry the iteration limit
    passed to ``start()``.
    """
    session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Runtime double; the replacement _create_runtime attaches it to the session
    fake_runtime = MagicMock(spec=ActionExecutionClient)

    async def attach_fake_runtime(*args, **kwargs):
        session.runtime = fake_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=attach_fake_runtime)

    # Event stream double: returns no events and reports 0 as the latest id
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 0
    session.event_stream = stream

    class SpyAgentController(AgentController):
        """AgentController that records each set_initial_state invocation."""

        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            # Record the call, then defer to the real implementation
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # A real Memory bound to the event stream double
    real_memory = Memory(event_stream=stream, sid='test-session')
    real_memory.microagents_dir = 'test-dir'

    # Swap in the spy controller, the stream double and the prepared Memory;
    # make restoring a saved state fail
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch(
            'openhands.server.session.agent_session.Memory', return_value=real_memory
        ),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        controller = session.controller

        # Both the controller and the memory must have been subscribed
        # under the controller's id
        expected_subscriptions = (
            (EventStreamSubscriber.AGENT_CONTROLLER, controller.on_event),
            (EventStreamSubscriber.MEMORY, session.memory.on_event),
        )
        for subscriber, callback in expected_subscriptions:
            stream.subscribe.assert_any_call(subscriber, callback, controller.id)

        # set_initial_state ran once and received no restored state
        assert controller.set_initial_state_call_count == 1
        assert controller.test_initial_state is None

        # The resulting state reflects the start() arguments
        state = controller.state
        assert state.iteration_flag.max_value == 10
        assert controller.agent.name == 'test-agent'
        assert state.start_id == 0
        assert state.end_id == -1
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """Check AgentSession.start() when a saved state is restored.

    ``State.restore_from_session`` is patched to return a prepared mock state.
    The controller must receive that exact object in ``set_initial_state``,
    keep it as its state, and leave the restored iteration limit (5) in place
    even though ``start()`` is given ``max_iterations=10``.
    """
    session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Runtime double; the replacement _create_runtime attaches it to the session
    fake_runtime = MagicMock(spec=ActionExecutionClient)

    async def attach_fake_runtime(*args, **kwargs):
        session.runtime = fake_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=attach_fake_runtime)

    # Event stream double: get_events returns an empty list, but the latest
    # event id is reported as 5
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 5
    session.event_stream = stream

    # The state object that the patched restore_from_session hands back
    restored_state = MagicMock(spec=State)
    restored_state.configure_mock(start_id=-1, end_id=-1)
    # The iteration limit lives on iteration_flag (not a max_iterations field)
    restored_state.iteration_flag = MagicMock(max_value=5)
    restored_state.metrics = MagicMock(spec=Metrics)

    class SpyAgentController(AgentController):
        """AgentController that records each set_initial_state invocation."""

        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            # Record the call, then defer to the real implementation
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # Memory is replaced wholesale by a mock in this test
    memory_double = MagicMock(spec=Memory)

    # Swap in the spy controller, the stream double and the Memory mock;
    # make restoring a saved state succeed
    with (
        patch(
            'openhands.server.session.agent_session.AgentController', SpyAgentController
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            return_value=restored_state,
        ),
        patch('openhands.server.session.agent_session.Memory', memory_double),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        controller = session.controller

        # set_initial_state ran exactly once
        assert controller.set_initial_state_call_count == 1

        # The most recent subscribe call registered the controller
        stream.subscribe.assert_called_with(
            EventStreamSubscriber.AGENT_CONTROLLER,
            controller.on_event,
            controller.id,
        )

        # The restored object was passed in and is the controller's state
        assert controller.test_initial_state is restored_state
        state = controller.state
        assert state is restored_state
        assert state.iteration_flag.max_value == 5
        assert state.start_id == 0
        assert state.end_id == -1
@pytest.mark.asyncio
async def test_metrics_centralization_and_sharing(mock_agent):
    """Test that metrics are centralized and shared between controller and agent.

    After ``AgentSession.start()`` the agent's LLM metrics and the controller
    state's metrics must be the same object, so that costs added through
    either side (directly, or by merging an observation's metrics through the
    state tracker) are visible through both, and survive ``agent.reset()``.

    Cost values are compared with ``pytest.approx``: they are sums of binary
    floats (0.05 + 0.1), so exact ``==`` would only hold while both sides of
    the comparison happen to perform the same additions in the same order.
    """
    # Setup
    session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Create a mock runtime; the replacement _create_runtime attaches it
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events and inject it into the session
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0
    session.event_stream = mock_event_stream

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch necessary components; restoring a saved state is made to fail
    with (
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        controller = session.controller

        # The agent's LLM metrics and the controller's state metrics are the
        # same object
        assert controller.agent.llm.metrics is controller.state.metrics

        # A cost added through the agent's LLM shows up in the state metrics
        test_cost = 0.05
        controller.agent.llm.metrics.add_cost(test_cost)
        assert controller.state.metrics.accumulated_cost == pytest.approx(test_cost)

        # Simulate an observation that carries its own metrics
        observation_cost = 0.1
        test_observation_metrics = Metrics()
        test_observation_metrics.add_cost(observation_cost)

        # Merge it through the state tracker and check both views of the total
        current_cost = controller.state.metrics.accumulated_cost
        controller.state_tracker.merge_metrics(test_observation_metrics)
        merged_total = pytest.approx(current_cost + observation_cost)
        assert controller.state.metrics.accumulated_cost == merged_total
        assert controller.agent.llm.metrics.accumulated_cost == merged_total

        # Resetting the agent must not reset or replace the metrics
        controller.agent.reset()
        total_after_reset = pytest.approx(test_cost + observation_cost)
        assert controller.state.metrics.accumulated_cost == total_after_reset
        assert controller.agent.llm.metrics.accumulated_cost == total_after_reset
        assert controller.agent.llm.metrics is controller.state.metrics
@pytest.mark.asyncio
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
    """Test that BudgetControlFlag's current value matches the accumulated costs.

    The session is started with ``max_budget_per_task=1.0``. The budget flag
    on the controller state must then track the accumulated cost: after an
    explicit ``sync_budget_flag_with_metrics()``, after merging an
    observation's metrics through the state tracker, and after
    ``agent.reset()``.

    Computed cost values are compared with ``pytest.approx``: they are sums
    of binary floats (0.05 + 0.1), so exact ``==`` would only hold while both
    sides of the comparison perform the same additions in the same order.
    Values that are assigned rather than summed (the 1.0 limit and the
    initial 0.0) are still compared exactly.
    """
    # Setup
    session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Create a mock runtime; the replacement _create_runtime attaches it
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events and inject it into the session
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0
    session.event_stream = mock_event_stream

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch necessary components; restoring a saved state is made to fail
    with (
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        # Start the session with a budget limit
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
            max_budget_per_task=1.0,  # Set a budget limit
        )

        controller = session.controller

        # The budget control flag exists, carries the limit, and starts at zero
        assert controller.state.budget_flag is not None
        assert controller.state.budget_flag.max_value == 1.0
        assert controller.state.budget_flag.current_value == 0.0

        # Add a cost through the agent's LLM metrics
        test_cost = 0.05
        controller.agent.llm.metrics.add_cost(test_cost)

        # An explicit sync copies the accumulated cost into the budget flag
        controller.state_tracker.sync_budget_flag_with_metrics()
        assert controller.state.budget_flag.current_value == pytest.approx(test_cost)

        # Simulate an observation that carries its own metrics
        observation_cost = 0.1
        test_observation_metrics = Metrics()
        test_observation_metrics.add_cost(observation_cost)
        controller.state_tracker.merge_metrics(test_observation_metrics)

        # After merging, the flag must match the new accumulated cost
        expected_total = pytest.approx(test_cost + observation_cost)
        assert controller.state.budget_flag.current_value == expected_total

        # Resetting the agent must not reset the budget flag
        controller.agent.reset()
        assert controller.state.budget_flag.current_value == expected_total