mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: chuckbutkus <chuck@all-hands.dev>
411 lines
16 KiB
Python
411 lines
16 KiB
Python
"""Tests for agent controller loop recovery functionality."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from openhands.controller.agent_controller import AgentController
|
|
from openhands.controller.stuck import StuckDetector
|
|
from openhands.core.schema import AgentState
|
|
from openhands.events import EventStream
|
|
from openhands.events.action import LoopRecoveryAction, MessageAction
|
|
from openhands.events.observation import LoopDetectionObservation
|
|
from openhands.server.services.conversation_stats import ConversationStats
|
|
from openhands.storage.memory import InMemoryFileStore
|
|
|
|
|
|
class TestAgentControllerLoopRecovery:
    """Tests for agent controller loop recovery functionality."""

    @staticmethod
    def _added_events(controller):
        """Return the events passed positionally to ``event_stream.add_event``.

        ``add_event`` is a mock in these tests, so each recorded call's first
        positional argument is the event that was added. Calls recorded without
        positional arguments are skipped.
        """
        return [
            recorded.args[0]
            for recorded in controller.event_stream.add_event.call_args_list
            if recorded.args
        ]

    @staticmethod
    def _is_pause_event(event):
        """Return True if *event* has an ``agent_state`` equal to PAUSED."""
        return getattr(event, 'agent_state', None) == AgentState.PAUSED

    @pytest.fixture
    def mock_controller(self):
        """Create an AgentController wired to mocks for loop-recovery tests.

        The event stream, conversation stats, agent and stuck detector are all
        mocks; the stuck detector starts out reporting "not stuck" with no
        analysis, and the state is RUNNING at iteration 10 with empty history.
        """
        # NOTE(review): passing ``event_stream=EventStream(...)`` to MagicMock
        # only sets an ``event_stream`` attribute on the mock to a real stream;
        # it is unclear whether anything reads it — confirm before removing.
        mock_event_stream = MagicMock(
            spec=EventStream,
            event_stream=EventStream(
                sid='test-session-id', file_store=InMemoryFileStore({})
            ),
        )
        mock_event_stream.sid = 'test-session-id'
        mock_event_stream.get_latest_event_id.return_value = 0

        mock_conversation_stats = MagicMock(spec=ConversationStats)

        mock_agent = MagicMock()
        mock_agent.act = AsyncMock()

        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            conversation_stats=mock_conversation_stats,
            iteration_delta=100,
            headless_mode=True,
        )

        # Deterministic starting state; the iteration value (10) is what the
        # loop-detection message is expected to report.
        controller.state.history = []
        controller.state.agent_state = AgentState.RUNNING
        controller.state.iteration_flag = MagicMock()
        controller.state.iteration_flag.current_value = 10

        # Stuck detector defaults to "not stuck"; individual tests override it.
        controller._stuck_detector = MagicMock(spec=StuckDetector)
        controller._stuck_detector.stuck_analysis = None
        controller._stuck_detector.is_stuck = MagicMock(return_value=False)

        return controller

    @pytest.mark.asyncio
    async def test_controller_detects_loop_and_produces_observation(
        self, mock_controller
    ):
        """Test that controller detects loops and produces LoopDetectionObservation."""
        # Setup stuck detector to detect a loop
        mock_controller._stuck_detector.is_stuck.return_value = True
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_type = (
            'repeating_action_observation'
        )
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        result = mock_controller.attempt_loop_recovery()

        # Loop recovery was attempted and produced events on the stream
        assert result is True
        mock_controller.event_stream.add_event.assert_called()

        events = self._added_events(mock_controller)
        loop_observations = [
            event for event in events if isinstance(event, LoopDetectionObservation)
        ]
        assert loop_observations, 'LoopDetectionObservation should be created'
        for observation in loop_observations:
            assert 'Agent detected in a loop!' in observation.content
            assert 'repeating_action_observation' in observation.content
            assert 'Loop detected at iteration 10' in observation.content

        assert any(self._is_pause_event(event) for event in events), (
            'Agent should be paused'
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_1(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 1."""
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Replace the recovery routine so only the dispatch is under test
        mock_controller._perform_loop_recovery = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        mock_controller._perform_loop_recovery.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_2(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 2."""
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Replace the restart routine so only the dispatch is under test
        mock_controller._restart_with_last_user_message = AsyncMock()

        action = LoopRecoveryAction(option=2)
        await mock_controller._handle_loop_recovery_action(action)

        mock_controller._restart_with_last_user_message.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_3(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 3 (stop)."""
        mock_controller._stuck_detector.stuck_analysis = MagicMock()

        # Replace the state setter so the requested transition can be inspected
        mock_controller.set_agent_state_to = AsyncMock()

        action = LoopRecoveryAction(option=3)
        await mock_controller._handle_loop_recovery_action(action)

        mock_controller.set_agent_state_to.assert_called_once_with(AgentState.STOPPED)

    @pytest.mark.asyncio
    async def test_controller_ignores_loop_recovery_without_stuck_analysis(
        self, mock_controller
    ):
        """Test that controller ignores LoopRecoveryAction when no stuck analysis exists."""
        mock_controller._stuck_detector.stuck_analysis = None

        # Mock every recovery path so any accidental call is observable
        mock_controller._perform_loop_recovery = AsyncMock()
        mock_controller._restart_with_last_user_message = AsyncMock()
        mock_controller.set_agent_state_to = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        mock_controller._perform_loop_recovery.assert_not_called()
        mock_controller._restart_with_last_user_message.assert_not_called()
        mock_controller.set_agent_state_to.assert_not_called()

    @pytest.mark.asyncio
    async def test_controller_no_loop_recovery_when_not_stuck(self, mock_controller):
        """Test that controller doesn't attempt recovery when not stuck."""
        mock_controller._stuck_detector.stuck_analysis = None

        # Ignore events added during controller construction (e.g. system message)
        mock_controller.event_stream.add_event.reset_mock()

        result = mock_controller.attempt_loop_recovery()

        assert result is False

        # Other events may be added, but none specific to loop recovery
        loop_recovery_events = [
            event
            for event in self._added_events(mock_controller)
            if isinstance(event, LoopDetectionObservation)
            or self._is_pause_event(event)
        ]
        assert len(loop_recovery_events) == 0, (
            'No loop recovery events should be added when not stuck'
        )

    @pytest.mark.asyncio
    async def test_controller_state_transition_after_loop_recovery(
        self, mock_controller
    ):
        """Test that controller state transitions correctly after loop recovery."""
        mock_controller.state.agent_state = AgentState.RUNNING

        # Setup stuck detector to detect a loop
        mock_controller._stuck_detector.is_stuck.return_value = True
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_type = 'monologue'
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 3

        result = mock_controller.attempt_loop_recovery()

        assert result is True
        assert any(
            self._is_pause_event(event)
            for event in self._added_events(mock_controller)
        ), 'Agent should be paused after loop detection'

    @pytest.mark.asyncio
    async def test_controller_resumes_after_loop_recovery(self, mock_controller):
        """Test that controller can resume normal operation after loop recovery."""
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        mock_controller._perform_loop_recovery = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        mock_controller._perform_loop_recovery.assert_called_once()

        # Continuing normal operation afterwards is covered by integration tests
        # with real agent execution.

    @pytest.mark.asyncio
    async def test_controller_truncates_history_during_loop_recovery(
        self, mock_controller
    ):
        """Test that controller correctly truncates history during loop recovery."""
        from openhands.events.action import CmdRunAction
        from openhands.events.observation import CmdOutputObservation, NullObservation

        # Build a 10-event history with ids 1..10: a user message, an agent
        # observation, then alternating `ls -la` commands (odd ids) and their
        # outputs (even ids) to simulate a loop.
        mock_history = []

        user_msg = MessageAction(
            content='Hello, help me with this task', wait_for_response=False
        )
        user_msg._source = 'user'
        user_msg._id = 1
        mock_history.append(user_msg)

        agent_obs = NullObservation(content='')
        agent_obs._id = 2
        mock_history.append(agent_obs)

        for event_id in range(3, 11):
            if event_id % 2 == 1:  # Action
                event = CmdRunAction(command='ls -la')
            else:  # Observation caused by the preceding action
                event = CmdOutputObservation(
                    content='file1.txt file2.txt', command='ls -la'
                )
                event._cause = event_id - 1
            event._id = event_id
            mock_history.append(event)

        mock_controller.state.history = mock_history.copy()
        mock_controller.state.end_id = 10

        # Loop starts at history index 5 (the event with id 6)
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Exercise the real (unmocked) truncation path directly
        await mock_controller._perform_loop_recovery(
            mock_controller._stuck_detector.stuck_analysis
        )

        history = mock_controller.state.history
        remaining_ids = [event.id for event in history]

        # Recovery point is index 5, so indices 0-4 (ids 1-5) are kept
        assert len(history) == 5, (
            f'Expected 5 events after truncation, got {len(history)}'
        )
        assert remaining_ids == [1, 2, 3, 4, 5], (
            f'Expected remaining event ids [1, 2, 3, 4, 5], got {remaining_ids}'
        )
        assert mock_controller.state.end_id == 5, (
            f'Expected end_id to be 5, got {mock_controller.state.end_id}'
        )

    def test_stuck_detection_config_option_exists(self):
        """Test that the enable_stuck_detection config option exists and defaults to True."""
        from openhands.core.config.agent_config import AgentConfig

        config = AgentConfig()
        assert hasattr(config, 'enable_stuck_detection')
        assert config.enable_stuck_detection is True

        config_disabled = AgentConfig(enable_stuck_detection=False)
        assert config_disabled.enable_stuck_detection is False

    def test_stuck_detection_config_from_env(self):
        """Test that enable_stuck_detection can be set via environment variable."""
        import os
        from unittest.mock import patch

        from openhands.core.config.agent_config import AgentConfig

        # patch.dict snapshots os.environ and restores it on exit, so a failing
        # assertion cannot leak AGENT_ENABLE_STUCK_DETECTION into later tests and
        # any value set before this test is preserved.
        with patch.dict(os.environ):
            # Test with enabled (default)
            os.environ.pop('AGENT_ENABLE_STUCK_DETECTION', None)
            config = AgentConfig()
            assert config.enable_stuck_detection is True

            # Test with explicitly disabled
            os.environ['AGENT_ENABLE_STUCK_DETECTION'] = 'false'
            # Need to reload for env var to take effect in real usage
            # For this test, we just verify the config accepts the parameter
            config_disabled = AgentConfig(enable_stuck_detection=False)
            assert config_disabled.enable_stuck_detection is False
|