OpenHands/tests/unit/controller/test_agent_controller_loop_recovery.py
Graham Neubig 1e513ad63f
feat: Add configurable stuck/loop detection (#11799)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: chuckbutkus <chuck@all-hands.dev>
2025-11-21 22:27:38 +00:00

411 lines
16 KiB
Python

"""Tests for agent controller loop recovery functionality."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from openhands.controller.agent_controller import AgentController
from openhands.controller.stuck import StuckDetector
from openhands.core.schema import AgentState
from openhands.events import EventStream
from openhands.events.action import LoopRecoveryAction, MessageAction
from openhands.events.observation import LoopDetectionObservation
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
class TestAgentControllerLoopRecovery:
    """Tests for agent controller loop recovery functionality.

    The tests drive an ``AgentController`` built on a mocked event stream and
    a mocked ``StuckDetector``, then inspect either the events handed to
    ``event_stream.add_event`` or the recovery helpers that were invoked.
    """

    @staticmethod
    def _added_events(controller):
        """Return the events passed positionally to ``event_stream.add_event``.

        The event is read from the first positional argument of each recorded
        call; calls recorded without positional arguments are skipped so that
        every test extracts events the same way.
        """
        return [
            call.args[0]
            for call in controller.event_stream.add_event.call_args_list
            if call.args
        ]

    @staticmethod
    def _is_pause_event(event):
        """Return True if ``event`` carries ``agent_state == AgentState.PAUSED``."""
        return getattr(event, 'agent_state', None) == AgentState.PAUSED

    @pytest.fixture
    def mock_controller(self):
        """Create an agent controller wired to mocked dependencies.

        Returns:
            An ``AgentController`` in the RUNNING state with an empty history,
            an iteration counter of 10 and a mocked stuck detector that
            reports "not stuck" until a test overrides it.
        """
        # Create mock dependencies. Extra keyword arguments to MagicMock are
        # set as attributes on the resulting mock.
        mock_event_stream = MagicMock(
            spec=EventStream,
            event_stream=EventStream(
                sid='test-session-id', file_store=InMemoryFileStore({})
            ),
        )
        mock_event_stream.sid = 'test-session-id'
        mock_event_stream.get_latest_event_id.return_value = 0

        mock_conversation_stats = MagicMock(spec=ConversationStats)

        mock_agent = MagicMock()
        mock_agent.act = AsyncMock()

        # Create controller with correct parameters
        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            conversation_stats=mock_conversation_stats,
            iteration_delta=100,
            headless_mode=True,
        )

        # Mock state properties
        controller.state.history = []
        controller.state.agent_state = AgentState.RUNNING
        controller.state.iteration_flag = MagicMock()
        controller.state.iteration_flag.current_value = 10

        # Mock stuck detector: no analysis and "not stuck" by default
        controller._stuck_detector = MagicMock(spec=StuckDetector)
        controller._stuck_detector.stuck_analysis = None
        controller._stuck_detector.is_stuck = MagicMock(return_value=False)

        return controller

    @pytest.mark.asyncio
    async def test_controller_detects_loop_and_produces_observation(
        self, mock_controller
    ):
        """Test that controller detects loops and produces LoopDetectionObservation."""
        # Setup stuck detector to detect a loop
        detector = mock_controller._stuck_detector
        detector.is_stuck.return_value = True
        detector.stuck_analysis = MagicMock()
        detector.stuck_analysis.loop_type = 'repeating_action_observation'
        detector.stuck_analysis.loop_start_idx = 5

        result = mock_controller.attempt_loop_recovery()

        # Verify that loop recovery was attempted
        assert result is True

        # Verify that events were added to the event stream
        mock_controller.event_stream.add_event.assert_called()
        events = self._added_events(mock_controller)

        # Every LoopDetectionObservation must describe the detected loop
        loop_observations = [
            event for event in events if isinstance(event, LoopDetectionObservation)
        ]
        assert loop_observations, 'LoopDetectionObservation should be created'
        for observation in loop_observations:
            assert 'Agent detected in a loop!' in observation.content
            assert 'repeating_action_observation' in observation.content
            # 10 is the iteration counter configured by the fixture
            assert 'Loop detected at iteration 10' in observation.content

        assert any(self._is_pause_event(event) for event in events), (
            'Agent should be paused'
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_1(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 1."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Replace the recovery method so only the dispatch is exercised
        mock_controller._perform_loop_recovery = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        # Verify that _perform_loop_recovery received the stuck analysis
        mock_controller._perform_loop_recovery.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_2(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 2."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Replace the restart method so only the dispatch is exercised
        mock_controller._restart_with_last_user_message = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=2)
        )

        # Verify that _restart_with_last_user_message received the stuck analysis
        mock_controller._restart_with_last_user_message.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_3(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 3 (stop)."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()

        # Replace the state setter so the requested state can be inspected
        mock_controller.set_agent_state_to = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=3)
        )

        # Verify that set_agent_state_to was called with STOPPED
        mock_controller.set_agent_state_to.assert_called_once_with(AgentState.STOPPED)

    @pytest.mark.asyncio
    async def test_controller_ignores_loop_recovery_without_stuck_analysis(
        self, mock_controller
    ):
        """Test that controller ignores LoopRecoveryAction when no stuck analysis exists."""
        # Ensure no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None

        # Mock all recovery methods for this test
        mock_controller._perform_loop_recovery = AsyncMock()
        mock_controller._restart_with_last_user_message = AsyncMock()
        mock_controller.set_agent_state_to = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        # Verify that no recovery methods were called
        mock_controller._perform_loop_recovery.assert_not_called()
        mock_controller._restart_with_last_user_message.assert_not_called()
        mock_controller.set_agent_state_to.assert_not_called()

    @pytest.mark.asyncio
    async def test_controller_no_loop_recovery_when_not_stuck(self, mock_controller):
        """Test that controller doesn't attempt recovery when not stuck."""
        # Setup no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None

        # Reset the mock to ignore any previous calls (like system message)
        mock_controller.event_stream.add_event.reset_mock()

        result = mock_controller.attempt_loop_recovery()

        # Verify that no recovery was attempted
        assert result is False

        # Verify that no loop recovery events were added to the stream
        # (Note: there might be other events, but no loop recovery specific ones)
        loop_recovery_events = [
            event
            for event in self._added_events(mock_controller)
            if isinstance(event, LoopDetectionObservation)
            or self._is_pause_event(event)
        ]
        assert not loop_recovery_events, (
            'No loop recovery events should be added when not stuck'
        )

    @pytest.mark.asyncio
    async def test_controller_state_transition_after_loop_recovery(
        self, mock_controller
    ):
        """Test that controller state transitions correctly after loop recovery."""
        # Setup initial state
        mock_controller.state.agent_state = AgentState.RUNNING

        # Setup stuck detector to detect a loop
        detector = mock_controller._stuck_detector
        detector.is_stuck.return_value = True
        detector.stuck_analysis = MagicMock()
        detector.stuck_analysis.loop_type = 'monologue'
        detector.stuck_analysis.loop_start_idx = 3

        result = mock_controller.attempt_loop_recovery()

        # Verify that recovery was attempted
        assert result is True

        # Verify that a pause event was pushed to the event stream
        events = self._added_events(mock_controller)
        assert any(self._is_pause_event(event) for event in events), (
            'Agent should be paused after loop detection'
        )

    @pytest.mark.asyncio
    async def test_controller_resumes_after_loop_recovery(self, mock_controller):
        """Test that controller can resume normal operation after loop recovery."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Replace the recovery method so only the dispatch is exercised
        mock_controller._perform_loop_recovery = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        # Verify that recovery was performed
        mock_controller._perform_loop_recovery.assert_called_once()
        # Verify that agent can continue normal operation
        # (This would be tested in integration tests with actual agent execution)

    @pytest.mark.asyncio
    async def test_controller_truncates_history_during_loop_recovery(
        self, mock_controller
    ):
        """Test that controller correctly truncates history during loop recovery."""
        from openhands.events.action import CmdRunAction
        from openhands.events.observation import CmdOutputObservation, NullObservation

        # Build a history of 10 events with ids 1..10
        mock_history = []

        # Initial user message (id 1)
        user_msg = MessageAction(
            content='Hello, help me with this task', wait_for_response=False
        )
        user_msg._source = 'user'
        user_msg._id = 1
        mock_history.append(user_msg)

        # Agent response (id 2)
        agent_obs = NullObservation(content='')
        agent_obs._id = 2
        mock_history.append(agent_obs)

        # Alternating command/observation pairs simulating a loop:
        # odd ids are actions, even ids are the observations they caused.
        for event_id in range(3, 11):
            if event_id % 2 == 1:
                cmd = CmdRunAction(command='ls -la')
                cmd._id = event_id
                mock_history.append(cmd)
            else:
                obs = CmdOutputObservation(
                    content='file1.txt file2.txt', command='ls -la'
                )
                obs._id = event_id
                obs._cause = event_id - 1
                mock_history.append(obs)

        # Hand the controller its own copy so mock_history stays intact
        mock_controller.state.history = list(mock_history)
        mock_controller.state.end_id = 10

        # Setup stuck analysis to indicate loop starts at index 5
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Call the real _perform_loop_recovery method directly
        await mock_controller._perform_loop_recovery(
            mock_controller._stuck_detector.stuck_analysis
        )

        # Verify that history was truncated to the recovery point
        # The recovery point is index 5, so we should keep events 0-4 (5 events)
        remaining = mock_controller.state.history
        assert len(remaining) == 5, (
            f'Expected 5 events after truncation, got {len(remaining)}'
        )

        # Verify the specific events that remain
        expected_ids = [1, 2, 3, 4, 5]
        for index, (event, expected_id) in enumerate(zip(remaining, expected_ids)):
            assert event.id == expected_id, (
                f'Event at index {index} should have id {expected_id}, got {event.id}'
            )

        # Verify end_id was updated
        assert mock_controller.state.end_id == 5, (
            f'Expected end_id to be 5, got {mock_controller.state.end_id}'
        )

    def test_stuck_detection_config_option_exists(self):
        """Test that the enable_stuck_detection config option exists and defaults to True."""
        from openhands.core.config.agent_config import AgentConfig

        # Create a default config
        config = AgentConfig()

        # Verify the attribute exists and defaults to True
        assert hasattr(config, 'enable_stuck_detection')
        assert config.enable_stuck_detection is True

        # Verify we can create a config with it disabled
        config_disabled = AgentConfig(enable_stuck_detection=False)
        assert config_disabled.enable_stuck_detection is False

    def test_stuck_detection_config_from_env(self):
        """Test enable_stuck_detection while the related env var is unset/set.

        This test does not load configuration from the environment: with the
        variable set to ``'false'`` it still passes the value explicitly, so
        it only verifies that ``AgentConfig`` accepts the parameter.
        """
        import os
        from unittest.mock import patch

        from openhands.core.config.agent_config import AgentConfig

        env_var = 'AGENT_ENABLE_STUCK_DETECTION'

        # patch.dict snapshots os.environ and restores it on exit, so a failing
        # assertion cannot leak the variable into other tests and a value that
        # was set before this test ran is not lost.
        with patch.dict(os.environ):
            # Test with enabled (default)
            os.environ.pop(env_var, None)
            config = AgentConfig()
            assert config.enable_stuck_detection is True

            # Test with explicitly disabled
            os.environ[env_var] = 'false'
            # Need to reload for env var to take effect in real usage
            # For this test, we just verify the config accepts the parameter
            config_disabled = AgentConfig(enable_stuck_detection=False)
            assert config_disabled.enable_stuck_detection is False