OpenHands/tests/unit/test_agent_session.py
2025-07-15 19:03:28 -04:00

448 lines
16 KiB
Python

from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
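# NOTE(review): the original note here read "We'll use the DeprecatedState class from the main codebase", but nothing visible in this file references DeprecatedState — presumably stale; confirm before removing.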
@pytest.fixture
def mock_agent():
    """Return an ``Agent`` mock wired with nested LLM, metrics and config mocks.

    Every mock is created with ``spec=`` so that reading an attribute the real
    class does not define raises ``AttributeError``.
    """
    # LLM configuration stub carrying the values used by these tests.
    llm_settings = MagicMock(spec=LLMConfig)
    llm_settings.configure_mock(
        model='test-model',
        base_url='http://test',
        max_message_chars=1000,
    )

    # Agent configuration stub: no disabled microagents, MCP enabled.
    agent_settings = MagicMock(spec=AgentConfig)
    agent_settings.configure_mock(disabled_microagents=[], enable_mcp=True)

    # LLM stub holding its own metrics mock plus the config above.
    fake_llm = MagicMock(spec=LLM)
    fake_llm.configure_mock(metrics=MagicMock(spec=Metrics), config=llm_settings)

    # The agent ties everything together. ``name`` is set after construction
    # (via configure_mock) because ``name=`` in the Mock constructor names the
    # mock itself instead of creating a ``name`` attribute.
    fake_agent = MagicMock(spec=Agent)
    fake_agent.configure_mock(
        llm=fake_llm,
        name='test-agent',
        sandbox_plugins=[],
        config=agent_settings,
        prompt_manager=MagicMock(),
    )
    return fake_agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """AgentSession.start() builds a fresh controller state when restoring fails.

    ``State.restore_from_session`` is patched to raise, so the controller is
    expected to receive ``state=None`` in ``set_initial_state``.
    """
    agent_session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Stand-in runtime; the replaced _create_runtime simply attaches it.
    fake_runtime = MagicMock(spec=ActionExecutionClient)

    async def attach_fake_runtime(*args, **kwargs):
        agent_session.runtime = fake_runtime
        return True

    agent_session._create_runtime = AsyncMock(side_effect=attach_fake_runtime)

    # Event stream double: no events and a latest event id of 0.
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 0
    agent_session.event_stream = stream

    # Spy subclass: records how set_initial_state was invoked, then delegates.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # A real Memory bound to the stream double.
    real_memory = Memory(event_stream=stream, sid='test-session')
    real_memory.microagents_dir = 'test-dir'

    # Build the patchers up front; they only take effect inside the ``with``.
    controller_patch = patch(
        'openhands.server.session.agent_session.AgentController', SpyAgentController
    )
    stream_patch = patch(
        'openhands.server.session.agent_session.EventStream',
        return_value=stream,
    )
    restore_patch = patch(
        'openhands.controller.state.state.State.restore_from_session',
        side_effect=Exception('No state found'),
    )
    memory_patch = patch(
        'openhands.server.session.agent_session.Memory', return_value=real_memory
    )

    with controller_patch, stream_patch, restore_patch, memory_patch:
        await agent_session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        controller = agent_session.controller

        # Both the controller and the memory must have been subscribed.
        stream.subscribe.assert_any_call(
            EventStreamSubscriber.AGENT_CONTROLLER,
            controller.on_event,
            controller.id,
        )
        stream.subscribe.assert_any_call(
            EventStreamSubscriber.MEMORY,
            agent_session.memory.on_event,
            controller.id,
        )

        # set_initial_state ran exactly once and was given no state.
        assert controller.set_initial_state_call_count == 1
        assert controller.test_initial_state is None

        # The resulting state reflects the arguments passed to start().
        state = controller.state
        assert state.iteration_flag.max_value == 10
        assert controller.agent.name == 'test-agent'
        assert state.start_id == 0
        assert state.end_id == -1
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """AgentSession.start() passes a restored State through to the controller.

    ``State.restore_from_session`` is patched to return a prepared mock; the
    controller must end up holding that very object.
    """
    agent_session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Stand-in runtime; the replaced _create_runtime simply attaches it.
    fake_runtime = MagicMock(spec=ActionExecutionClient)

    async def attach_fake_runtime(*args, **kwargs):
        agent_session.runtime = fake_runtime
        return True

    agent_session._create_runtime = AsyncMock(side_effect=attach_fake_runtime)

    # Event stream double: get_events is still empty, but the latest event id
    # is 5 to indicate that earlier events exist.
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 5
    agent_session.event_stream = stream

    # The state that "restoring" will hand back. iteration_flag (rather than a
    # max_iterations field) carries the iteration limit.
    restored = MagicMock(spec=State)
    restored.configure_mock(
        start_id=-1,
        end_id=-1,
        iteration_flag=MagicMock(max_value=5),
        metrics=MagicMock(spec=Metrics),
    )

    # Spy subclass: records how set_initial_state was invoked, then delegates.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.set_initial_state_call_count += 1
            self.test_initial_state = state
            super().set_initial_state(*args, state=state, **kwargs)

    # Memory is replaced wholesale by this mock inside AgentSession.
    mock_memory = MagicMock(spec=Memory)

    # Build the patchers up front; they only take effect inside the ``with``.
    controller_patch = patch(
        'openhands.server.session.agent_session.AgentController', SpyAgentController
    )
    stream_patch = patch(
        'openhands.server.session.agent_session.EventStream',
        return_value=stream,
    )
    restore_patch = patch(
        'openhands.controller.state.state.State.restore_from_session',
        return_value=restored,
    )
    memory_patch = patch('openhands.server.session.agent_session.Memory', mock_memory)

    with controller_patch, stream_patch, restore_patch, memory_patch:
        await agent_session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        controller = agent_session.controller

        # set_initial_state ran exactly once.
        assert controller.set_initial_state_call_count == 1

        # The most recent subscribe call is the controller's.
        stream.subscribe.assert_called_with(
            EventStreamSubscriber.AGENT_CONTROLLER,
            controller.on_event,
            controller.id,
        )

        # The restored object is both what the spy saw and what the controller keeps.
        assert controller.test_initial_state is restored
        assert controller.state is restored

        # The restored limit of 5 is kept (start() was given 10); start_id is
        # expected to be 0 although the restored mock carried -1.
        assert controller.state.iteration_flag.max_value == 5
        assert controller.state.start_id == 0
        assert controller.state.end_id == -1
@pytest.mark.asyncio
async def test_metrics_centralization_and_sharing(mock_agent):
    """Test that metrics are centralized and shared between controller and agent.

    After ``AgentSession.start()`` the agent's LLM metrics and the controller's
    state metrics must be the same object, so costs added through either path
    are visible through both, and resetting the agent must not clear them.

    Accumulated costs are compared with ``pytest.approx``: the expected values
    are float sums, and exact ``==`` would only hold while the implementation
    happens to add the terms in the same order as the test does.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Mock the runtime creation to set up the runtime attribute
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch necessary components; restoring is forced to fail so the session
    # starts without a previously saved state.
    with (
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify that the agent's LLM metrics and controller's state metrics
        # are the same object
        assert session.controller.agent.llm.metrics is session.controller.state.metrics

        # Add some metrics to the agent's LLM
        test_cost = 0.05
        session.controller.agent.llm.metrics.add_cost(test_cost)

        # Verify that the cost is reflected in the controller's state metrics
        assert session.controller.state.metrics.accumulated_cost == pytest.approx(
            test_cost
        )

        # Create a test metrics object to simulate an observation with metrics
        test_observation_metrics = Metrics()
        test_observation_metrics.add_cost(0.1)

        # Get the current accumulated cost before merging
        current_cost = session.controller.state.metrics.accumulated_cost

        # Simulate merging metrics from an observation
        session.controller.state_tracker.merge_metrics(test_observation_metrics)

        # Verify that the merged metrics are reflected in both agent and controller
        expected_after_merge = pytest.approx(current_cost + 0.1)
        assert session.controller.state.metrics.accumulated_cost == expected_after_merge
        assert (
            session.controller.agent.llm.metrics.accumulated_cost
            == expected_after_merge
        )

        # Reset the agent and verify that metrics are not reset
        session.controller.agent.reset()

        # Metrics should still be the same after reset
        expected_total = pytest.approx(test_cost + 0.1)
        assert session.controller.state.metrics.accumulated_cost == expected_total
        assert session.controller.agent.llm.metrics.accumulated_cost == expected_total
        assert session.controller.agent.llm.metrics is session.controller.state.metrics
@pytest.mark.asyncio
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
    """Test that BudgetControlFlag's current value matches the accumulated costs.

    The session is started with ``max_budget_per_task=1.0``; the test then adds
    cost through the agent's LLM metrics and through a merged observation, and
    checks that ``state.budget_flag.current_value`` follows.

    Accumulated values are compared with ``pytest.approx``: they are float
    sums, and exact ``==`` would only hold while the implementation happens to
    add the terms in the same order as the test does.
    """
    # Setup
    file_store = InMemoryFileStore({})
    session = AgentSession(
        sid='test-session',
        file_store=file_store,
    )

    # Create a mock runtime and set it up
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Mock the runtime creation to set up the runtime attribute
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

    # Create a mock EventStream with no events
    mock_event_stream = MagicMock(spec=EventStream)
    mock_event_stream.get_events.return_value = []
    mock_event_stream.subscribe = MagicMock()
    mock_event_stream.get_latest_event_id.return_value = 0

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch necessary components; restoring is forced to fail so the session
    # starts without a previously saved state.
    with (
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        # Start the session with a budget limit
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
            max_budget_per_task=1.0,  # Set a budget limit
        )

        # Verify that the budget control flag was created. These two values
        # involve no arithmetic, so exact comparison is appropriate.
        assert session.controller.state.budget_flag is not None
        assert session.controller.state.budget_flag.max_value == 1.0
        assert session.controller.state.budget_flag.current_value == 0.0

        # Add some metrics to the agent's LLM
        test_cost = 0.05
        session.controller.agent.llm.metrics.add_cost(test_cost)

        # Verify that the budget control flag's current value is updated
        # This happens through the state_tracker.sync_budget_flag_with_metrics method
        session.controller.state_tracker.sync_budget_flag_with_metrics()
        assert session.controller.state.budget_flag.current_value == pytest.approx(
            test_cost
        )

        # Create a test metrics object to simulate an observation with metrics
        test_observation_metrics = Metrics()
        test_observation_metrics.add_cost(0.1)

        # Simulate merging metrics from an observation
        session.controller.state_tracker.merge_metrics(test_observation_metrics)

        # Verify that the budget control flag's current value is updated to
        # match the new accumulated cost
        expected_total = pytest.approx(test_cost + 0.1)
        assert session.controller.state.budget_flag.current_value == expected_total

        # Reset the agent and verify that metrics and budget flag are not reset
        session.controller.agent.reset()

        # Budget control flag should still reflect the accumulated cost after reset
        assert session.controller.state.budget_flag.current_value == expected_total
def test_override_provider_tokens_with_custom_secret():
    """Check override_provider_tokens_with_custom_secret drops overridden providers.

    Provider tokens with a matching custom secret must be absent from the
    result and the rest must be kept. The call must also complete without the
    'dictionary changed size during iteration' RuntimeError this test was
    written to guard against.
    """
    agent_session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )

    # Three provider tokens; only two of them have a custom secret below.
    provider_tokens = {
        ProviderType.GITHUB: 'github_token_123',
        ProviderType.GITLAB: 'gitlab_token_456',
        ProviderType.BITBUCKET: 'bitbucket_token_789',
    }

    # One lowercase and one uppercase secret name, to cover both spellings.
    secret_overrides = dict(
        github_token='custom_github_token',
        GITLAB_TOKEN='custom_gitlab_token',
    )

    # Must not raise RuntimeError: dictionary changed size during iteration.
    remaining = agent_session.override_provider_tokens_with_custom_secret(
        provider_tokens, secret_overrides
    )

    # Providers that have a custom secret are gone from the result.
    for overridden in (ProviderType.GITHUB, ProviderType.GITLAB):
        assert overridden not in remaining

    # Bitbucket has no custom secret, so its token is kept unchanged.
    assert ProviderType.BITBUCKET in remaining
    assert remaining[ProviderType.BITBUCKET] == 'bitbucket_token_789'