mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix(controller): stop when run into loop (#4579)
This commit is contained in:
parent
be3cbb045e
commit
98d4884ced
@ -403,9 +403,20 @@ class AgentController:
|
||||
return
|
||||
|
||||
if self._pending_action:
|
||||
logger.debug(
|
||||
f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration} awaiting pending action to get executed: {self._pending_action}'
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
return
|
||||
|
||||
# check if agent got stuck before taking any action
|
||||
if self._is_stuck():
|
||||
# This needs to go BEFORE report_error to sync metrics
|
||||
self.event_stream.add_event(
|
||||
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
|
||||
)
|
||||
return
|
||||
|
||||
if self.delegate is not None:
|
||||
assert self.delegate != self
|
||||
if self.delegate.get_agent_state() == AgentState.PAUSED:
|
||||
@ -467,11 +478,6 @@ class AgentController:
|
||||
await self.update_state_after_step()
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
|
||||
if self._is_stuck():
|
||||
# This needs to go BEFORE report_error to sync metrics
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
await self.report_error('Agent got stuck in a loop')
|
||||
|
||||
async def _delegate_step(self):
|
||||
"""Executes a single step of the delegate agent."""
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
|
||||
|
||||
@ -12,7 +12,11 @@ from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.observation import FatalErrorObservation
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.runtime.base import Runtime
|
||||
@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
||||
assert len(list(event_stream.get_events())) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
    """Check that ``run_controller`` stops when the agent repeats itself.

    The stub agent returns the same ``CmdRunAction`` on every step, and a
    runtime subscriber answers each of those actions with the same non-fatal
    ``ErrorObservation``. The test then verifies the iteration count, the
    recorded events, the final agent state and the reported error message.
    """
    config = AppConfig()
    file_store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid='test', file_store=file_store)

    agent = MagicMock(spec=Agent)
    # The single action the stub agent hands back on every step. It has its
    # own name so that no later loop variable can rebind what the closure
    # below returns.
    repeated_action = CmdRunAction(command='ls')

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        return repeated_action

    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()
    runtime = MagicMock(spec=Runtime)

    async def on_event(event: Event):
        # Stand in for the runtime: reply to every command with the same
        # non-fatal error so the action/observation pair keeps repeating.
        if isinstance(event, CmdRunAction):
            non_fatal_error_obs = ErrorObservation(
                'Non fatal error here to trigger loop'
            )
            non_fatal_error_obs._cause = event.id
            await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    events = list(event_stream.get_events())
    print(f'state: {state}')
    for index, recorded_event in enumerate(events):
        print(f'event {index}: {event_to_dict(recorded_event)}')

    assert state.iteration == 4
    assert len(events) == 12

    # Events 2..9 must be four identical (action, observation) pairs.
    repeating_actions_and_observations = events[2:10]
    for action, observation in zip(
        repeating_actions_and_observations[0::2],
        repeating_actions_and_observations[1::2],
    ):
        action_dict = event_to_dict(action)
        observation_dict = event_to_dict(observation)
        assert action_dict['action'] == 'run'
        assert action_dict['args']['command'] == 'ls'
        assert observation_dict['observation'] == 'error'
        assert observation_dict['content'] == 'Non fatal error here to trigger loop'

    last_event = event_to_dict(events[-1])
    assert last_event['extras']['agent_state'] == 'error'
    assert last_event['observation'] == 'agent_state_changed'

    # The state first becomes AgentState.ERROR and then AgentState.STOPPED
    # inside run_controller, because its while loop + sleep stops looping.
    assert state.agent_state == AgentState.STOPPED
    assert (
        state.last_error
        == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'delegate_state',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user