diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index d40cf1bf37..0171959c5a 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -403,9 +403,20 @@ class AgentController: return if self._pending_action: + logger.debug( + f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration} awaiting pending action to get executed: {self._pending_action}' + ) await asyncio.sleep(1) return + # check if agent got stuck before taking any action + if self._is_stuck(): + # This need to go BEFORE report_error to sync metrics + self.event_stream.add_event( + FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER + ) + return + if self.delegate is not None: assert self.delegate != self if self.delegate.get_agent_state() == AgentState.PAUSED: @@ -467,11 +478,6 @@ class AgentController: await self.update_state_after_step() logger.info(action, extra={'msg_type': 'ACTION'}) - if self._is_stuck(): - # This need to go BEFORE report_error to sync metrics - await self.set_agent_state_to(AgentState.ERROR) - await self.report_error('Agent got stuck in a loop') - async def _delegate_step(self): """Executes a single step of the delegate agent.""" logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...') diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 2a4e6d86ac..9b0522302f 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -12,7 +12,11 @@ from openhands.core.main import run_controller from openhands.core.schema import AgentState from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction -from openhands.events.observation import FatalErrorObservation +from openhands.events.observation import ( + ErrorObservation, + FatalErrorObservation, +) +from openhands.events.serialization import event_to_dict from openhands.llm import LLM from openhands.llm.metrics import Metrics from openhands.runtime.base import Runtime @@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream): assert len(list(event_stream.get_events())) == 5 +@pytest.mark.asyncio +async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): + config = AppConfig() + file_store = get_file_store(config.file_store, config.file_store_path) + event_stream = EventStream(sid='test', file_store=file_store) + + agent = MagicMock(spec=Agent) + # a random message to send to the runtime + event = CmdRunAction(command='ls') + + def agent_step_fn(state): + print(f'agent_step_fn received state: {state}') + return event + + agent.step = agent_step_fn + agent.llm = MagicMock(spec=LLM) + agent.llm.metrics = Metrics() + agent.llm.config = config.get_llm_config() + runtime = MagicMock(spec=Runtime) + + async def on_event(event: Event): + if isinstance(event, CmdRunAction): + non_fatal_error_obs = ErrorObservation( + 'Non fatal error here to trigger loop' + ) + non_fatal_error_obs._cause = event.id + await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER) + + event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event) + runtime.event_stream = event_stream + + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + agent=agent, + fake_user_response_fn=lambda _: 'repeat', + ) + events = list(event_stream.get_events()) + print(f'state: {state}') + for i, event in enumerate(events): + print(f'event {i}: {event_to_dict(event)}') + + assert state.iteration == 4 + assert len(events) == 12 + # check the eventstream have 4 pairs of repeated actions and observations + repeating_actions_and_observations = events[2:10] + for action, observation in zip( + repeating_actions_and_observations[0::2], + repeating_actions_and_observations[1::2], + ): + action_dict = event_to_dict(action) + observation_dict = event_to_dict(observation) + assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls' + assert ( + observation_dict['observation'] == 'error' + and observation_dict['content'] == 'Non fatal error here to trigger loop' + ) + last_event = event_to_dict(events[-1]) + assert last_event['extras']['agent_state'] == 'error' + assert last_event['observation'] == 'agent_state_changed' + + # it will first become AgentState.ERROR, then become AgentState.STOPPED + # in side run_controller (since the while loop + sleep no longer loop) + assert state.agent_state == AgentState.STOPPED + assert ( + state.last_error + == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop' + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( 'delegate_state',