From 4998b5de32d8b21110207904ebb3a51e9d7e4771 Mon Sep 17 00:00:00 2001 From: OpenHands Date: Mon, 16 Dec 2024 16:28:23 -0500 Subject: [PATCH] Fix issue #5559: The turn limit should be measured from the last user interaction (#5560) Co-authored-by: Graham Neubig Co-authored-by: Engel Nyst --- evaluation/benchmarks/swe_bench/run_infer.py | 3 +- openhands/controller/agent_controller.py | 16 +++++ pyproject.toml | 3 +- tests/unit/test_agent_controller.py | 45 +++++++++++--- tests/unit/test_iteration_limit.py | 62 ++++++++++++++++++++ 5 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 tests/unit/test_iteration_limit.py diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index 01111f75d1..134c98cb96 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -9,7 +9,6 @@ import toml from datasets import load_dataset import openhands.agenthub - from evaluation.utils.shared import ( EvalException, EvalMetadata, @@ -76,7 +75,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata): '4. Rerun your reproduce script and confirm that the error is fixed!\n' '5. Think about edgecases and make sure your fix handles them as well\n' "Your thinking should be thorough and so it's fine if it's very long.\n" - ) + ) if RUN_WITH_BROWSING: instruction += ( diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index b2de6dd7d5..d3cde88f63 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -312,6 +312,20 @@ class AgentController: str(action), extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}, ) + # Extend max iterations when the user sends a message (only in non-headless mode) + if self._initial_max_iterations is not None and not self.headless_mode: + self.state.max_iterations = ( + self.state.iteration + self._initial_max_iterations + ) + if ( + self.state.traffic_control_state == TrafficControlState.THROTTLING + or self.state.traffic_control_state == TrafficControlState.PAUSED + ): + self.state.traffic_control_state = TrafficControlState.NORMAL + self.log( + 'debug', + f'Extended max iterations to {self.state.max_iterations} after user message', + ) if self.get_agent_state() != AgentState.RUNNING: await self.set_agent_state_to(AgentState.RUNNING) elif action.source == EventSource.AGENT and action.wait_for_response: @@ -342,6 +356,7 @@ class AgentController: elif ( new_state == AgentState.RUNNING and self.state.agent_state == AgentState.PAUSED + # TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely? and self.state.traffic_control_state == TrafficControlState.THROTTLING ): # user intends to interrupt traffic control and let the task resume temporarily @@ -351,6 +366,7 @@ class AgentController: self.state.iteration is not None and self.state.max_iterations is not None and self._initial_max_iterations is not None + and not self.headless_mode ): if self.state.iteration >= self.state.max_iterations: self.state.max_iterations += self._initial_max_iterations diff --git a/pyproject.toml b/pyproject.toml index 2b0d3ca1e8..2ed685a7c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -107,7 +108,6 @@ jupyter_kernel_gateway = "*" flake8 = "*" opencv-python = "*" - [build-system] build-backend = "poetry.core.masonry.api" requires = [ @@ -130,6 +130,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 9c07969bd0..08fe0e0f55 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -6,7 +6,7 @@ import pytest from openhands.controller.agent import Agent from openhands.controller.agent_controller import AgentController -from openhands.controller.state.state import TrafficControlState +from openhands.controller.state.state import State, TrafficControlState from openhands.core.config import AppConfig from openhands.core.main import run_controller from openhands.core.schema import AgentState @@ -41,7 +41,9 @@ def mock_agent(): @pytest.fixture def mock_event_stream(): - return MagicMock(spec=EventStream) + mock = MagicMock(spec=EventStream) + mock.get_latest_event_id.return_value = 0 + return mock @pytest.fixture @@ -278,7 +280,9 @@ async def test_delegate_step_different_states( @pytest.mark.asyncio -async def test_step_max_iterations(mock_agent, mock_event_stream): +async def test_max_iterations_extension(mock_agent, mock_event_stream): + # Test with headless_mode=False - should extend max_iterations + initial_state = State(max_iterations=10) controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, @@ -286,18 +290,34 @@ async def test_step_max_iterations(mock_agent, mock_event_stream): sid='test', confirmation_mode=False, headless_mode=False, + initial_state=initial_state, ) controller.state.agent_state = AgentState.RUNNING controller.state.iteration = 10 assert controller.state.traffic_control_state == TrafficControlState.NORMAL + + # Trigger throttling by calling _step() when we hit max_iterations await controller._step() assert controller.state.traffic_control_state == TrafficControlState.THROTTLING assert controller.state.agent_state == AgentState.ERROR + + # Simulate a new user message + message_action = MessageAction(content='Test message') + message_action._source = EventSource.USER + await controller.on_event(message_action) + + # Max iterations should be extended to current iteration + initial max_iterations + assert ( + controller.state.max_iterations == 20 + ) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10) + assert controller.state.traffic_control_state == TrafficControlState.NORMAL + assert controller.state.agent_state == AgentState.RUNNING + + # Close the controller to clean up await controller.close() - -@pytest.mark.asyncio -async def test_step_max_iterations_headless(mock_agent, mock_event_stream): + # Test with headless_mode=True - should NOT extend max_iterations + initial_state = State(max_iterations=10) controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, @@ -305,13 +325,24 @@ async def test_step_max_iterations_headless(mock_agent, mock_event_stream): sid='test', confirmation_mode=False, headless_mode=True, + initial_state=initial_state, ) controller.state.agent_state = AgentState.RUNNING controller.state.iteration = 10 assert controller.state.traffic_control_state == TrafficControlState.NORMAL + + # Simulate a new user message + message_action = MessageAction(content='Test message') + message_action._source = EventSource.USER + await controller.on_event(message_action) + + # Max iterations should NOT be extended in headless mode + assert controller.state.max_iterations == 10 # Original value unchanged + + # Trigger throttling by calling _step() when we hit max_iterations await controller._step() + assert controller.state.traffic_control_state == TrafficControlState.THROTTLING - # In headless mode, throttling results in an error assert controller.state.agent_state == AgentState.ERROR await controller.close() diff --git a/tests/unit/test_iteration_limit.py b/tests/unit/test_iteration_limit.py new file mode 100644 index 0000000000..4520231a0d --- /dev/null +++ b/tests/unit/test_iteration_limit.py @@ -0,0 +1,62 @@ +import asyncio + +import pytest + +from openhands.controller.agent_controller import AgentController +from openhands.core.schema import AgentState +from openhands.events import EventStream +from openhands.events.action import MessageAction +from openhands.events.event import EventSource + + +class DummyAgent: + def __init__(self): + self.name = 'dummy' + self.llm = type( + 'DummyLLM', + (), + {'metrics': type('DummyMetrics', (), {'merge': lambda x: None})()}, + )() + + def reset(self): + pass + + +@pytest.mark.asyncio +async def test_iteration_limit_extends_on_user_message(): + # Initialize test components + from openhands.storage.memory import InMemoryFileStore + + file_store = InMemoryFileStore() + event_stream = EventStream(sid='test', file_store=file_store) + agent = DummyAgent() + initial_max_iterations = 100 + controller = AgentController( + agent=agent, + event_stream=event_stream, + max_iterations=initial_max_iterations, + sid='test', + headless_mode=False, + ) + + # Set initial state + await controller.set_agent_state_to(AgentState.RUNNING) + controller.state.iteration = 90 # Close to the limit + assert controller.state.max_iterations == initial_max_iterations + + # Simulate user message + user_message = MessageAction('test message', EventSource.USER) + event_stream.add_event(user_message, EventSource.USER) + await asyncio.sleep(0.1) # Give time for event to be processed + + # Verify max_iterations was extended + assert controller.state.max_iterations == 90 + initial_max_iterations + + # Simulate more iterations and another user message + controller.state.iteration = 180 # Close to new limit + user_message2 = MessageAction('another message', EventSource.USER) + event_stream.add_event(user_message2, EventSource.USER) + await asyncio.sleep(0.1) # Give time for event to be processed + + # Verify max_iterations was extended again + assert controller.state.max_iterations == 180 + initial_max_iterations