fix(controller): stop when run into loop (#4579)

This commit is contained in:
Xingyao Wang 2024-10-26 19:40:58 -05:00 committed by GitHub
parent be3cbb045e
commit 98d4884ced
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 6 deletions

View File

@ -403,9 +403,20 @@ class AgentController:
return
if self._pending_action:
logger.debug(
f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration} awaiting pending action to get executed: {self._pending_action}'
)
await asyncio.sleep(1)
return
# check if agent got stuck before taking any action
if self._is_stuck():
# This need to go BEFORE report_error to sync metrics
self.event_stream.add_event(
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
)
return
if self.delegate is not None:
assert self.delegate != self
if self.delegate.get_agent_state() == AgentState.PAUSED:
@ -467,11 +478,6 @@ class AgentController:
await self.update_state_after_step()
logger.info(action, extra={'msg_type': 'ACTION'})
if self._is_stuck():
# This need to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
await self.report_error('Agent got stuck in a loop')
async def _delegate_step(self):
"""Executes a single step of the delegate agent."""
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')

View File

@ -12,7 +12,11 @@ from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.observation import FatalErrorObservation
from openhands.events.observation import (
ErrorObservation,
FatalErrorObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
assert len(list(event_stream.get_events())) == 5
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
config = AppConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
# a random message to send to the runtime
event = CmdRunAction(command='ls')
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return event
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
async def on_event(event: Event):
if isinstance(event, CmdRunAction):
non_fatal_error_obs = ErrorObservation(
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
)
events = list(event_stream.get_events())
print(f'state: {state}')
for i, event in enumerate(events):
print(f'event {i}: {event_to_dict(event)}')
assert state.iteration == 4
assert len(events) == 12
# check the eventstream have 4 pairs of repeated actions and observations
repeating_actions_and_observations = events[2:10]
for action, observation in zip(
repeating_actions_and_observations[0::2],
repeating_actions_and_observations[1::2],
):
action_dict = event_to_dict(action)
observation_dict = event_to_dict(observation)
assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
assert (
observation_dict['observation'] == 'error'
and observation_dict['content'] == 'Non fatal error here to trigger loop'
)
last_event = event_to_dict(events[-1])
assert last_event['extras']['agent_state'] == 'error'
assert last_event['observation'] == 'agent_state_changed'
# it will first become AgentState.ERROR, then become AgentState.STOPPED
# in side run_controller (since the while loop + sleep no longer loop)
assert state.agent_state == AgentState.STOPPED
assert (
state.last_error
== 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'delegate_state',