fix: allow to continue when the agent is stuck in interactive mode (#5597)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Engel Nyst 2024-12-14 20:49:04 +01:00 committed by GitHub
parent 7ef6fa666d
commit f0257c793b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 119 additions and 77 deletions

View File

@ -319,7 +319,7 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
self.almost_stuck = 0
self._pending_action = None
self.agent.reset()
@ -912,7 +912,7 @@ class AgentController:
if self.delegate and self.delegate._is_stuck():
return True
return self._stuck_detector.is_stuck()
return self._stuck_detector.is_stuck(self.headless_mode)
def __repr__(self):
return (

View File

@ -94,7 +94,7 @@ class State:
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.

View File

@ -24,16 +24,44 @@ class StuckDetector:
def __init__(self, state: State):
self.state = state
def is_stuck(self):
# filter out MessageAction with source='user' from history
def is_stuck(self, headless_mode: bool = True):
"""Checks if the agent is stuck in a loop.
Args:
headless_mode: Matches AgentController's headless_mode.
If True: Consider all history (automated/testing)
If False: Consider only history after last user message (interactive)
Returns:
bool: True if the agent is stuck in a loop, False otherwise.
"""
if not headless_mode:
# In interactive mode, only look at history after the last user message
last_user_msg_idx = -1
for i, event in enumerate(reversed(self.state.history)):
if (
isinstance(event, MessageAction)
and event.source == EventSource.USER
):
last_user_msg_idx = len(self.state.history) - i - 1
break
history_to_check = self.state.history[last_user_msg_idx + 1 :]
else:
# In headless mode, look at all history
history_to_check = self.state.history
# Filter out user messages and null events
filtered_history = [
event
for event in self.state.history
for event in history_to_check
if not (
# Filter works elegantly in both modes:
# - In headless: actively filters out user messages from full history
# - In non-headless: no-op since we already sliced after last user message
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, (NullAction, NullObservation))
or isinstance(event, (NullAction, NullObservation))
)
]
@ -81,43 +109,19 @@ class StuckDetector:
# it takes 4 actions and 4 observations to detect a loop
# assert len(last_actions) == 4 and len(last_observations) == 4
# reset almost_stuck reminder
self.state.almost_stuck = 0
# almost stuck? if two actions, obs are the same, we're almost stuck
if len(last_actions) >= 2 and len(last_observations) >= 2:
# Check for a loop of 4 identical action-observation pairs
if len(last_actions) == 4 and len(last_observations) == 4:
actions_equal = all(
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
self._eq_no_pid(last_actions[0], action) for action in last_actions
)
observations_equal = all(
self._eq_no_pid(last_observations[0], observation)
for observation in last_observations[:2]
for observation in last_observations
)
# the last two actions and obs are the same?
if actions_equal and observations_equal:
self.state.almost_stuck = 2
# the last three actions and observations are the same?
if len(last_actions) >= 3 and len(last_observations) >= 3:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[2])
and self._eq_no_pid(last_observations[0], last_observations[2])
):
self.state.almost_stuck = 1
if len(last_actions) == 4 and len(last_observations) == 4:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[3])
and self._eq_no_pid(last_observations[0], last_observations[3])
):
logger.warning('Action, Observation loop detected')
self.state.almost_stuck = 0
return True
logger.warning('Action, Observation loop detected')
return True
return False

View File

@ -112,7 +112,66 @@ class TestStuckDetector:
# cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_interactive_mode_resets_after_user_message(
self, stuck_detector: StuckDetector
):
state = stuck_detector.state
# First add some actions that would be stuck in non-UI mode
for i in range(4):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(
content='', command='ls', command_id=i
)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
# In headless mode, this should be stuck
assert stuck_detector.is_stuck(headless_mode=True) is True
# with the UI, it will ALSO be stuck initially
assert stuck_detector.is_stuck(headless_mode=False) is True
# Add a user message
message_action = MessageAction(content='Hello', wait_for_response=False)
message_action._source = EventSource.USER
state.history.append(message_action)
# In not-headless mode, this should not be stuck because we ignore history before user message
assert stuck_detector.is_stuck(headless_mode=False) is False
# But in headless mode, this should be still stuck because user messages do not count
assert stuck_detector.is_stuck(headless_mode=True) is True
# Add two more identical actions - still not stuck because we need at least 3
for i in range(2):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i + 4
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(
content='', command='ls', command_id=i + 4
)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck(headless_mode=False) is False
# Add two more identical actions - now it should be stuck
for i in range(2):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i + 6
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(
content='', command='ls', command_id=i + 6
)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck(headless_mode=False) is True
def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@ -148,8 +207,7 @@ class TestStuckDetector:
state.history.append(message_null_observation)
# 8 events
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 2
assert stuck_detector.is_stuck(headless_mode=True) is False
cmd_action_3 = CmdRunAction(command='ls')
cmd_action_3._id = 3
@ -160,20 +218,7 @@ class TestStuckDetector:
# 10 events
assert len(state.history) == 10
assert (
len(state.history) == 10
) # Adjusted since history is a list and the controller is not running
# FIXME are we still testing this without this test?
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 5
# )
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 1
assert stuck_detector.is_stuck(headless_mode=True) is False
cmd_action_4 = CmdRunAction(command='ls')
cmd_action_4._id = 4
@ -184,16 +229,9 @@ class TestStuckDetector:
# 12 events
assert len(state.history) == 12
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 6
# )
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.state.almost_stuck == 0
assert stuck_detector.is_stuck(headless_mode=True) is True
mock_warning.assert_called_once_with('Action, Observation loop detected')
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
@ -245,7 +283,7 @@ class TestStuckDetector:
# 12 events
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(headless_mode=True) is True
mock_warning.assert_called_once_with(
'Action, ErrorObservation loop detected'
)
@ -259,7 +297,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(headless_mode=True) is True
def test_is_not_stuck_invalid_syntax_error_random_lines(
self, stuck_detector: StuckDetector
@ -272,7 +310,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
self, stuck_detector: StuckDetector
@ -286,7 +324,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@ -297,7 +335,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(headless_mode=True) is True
def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@ -308,7 +346,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self, stuck_detector: StuckDetector
@ -317,7 +355,7 @@ class TestStuckDetector:
self._impl_unterminated_string_error_events(state, random_line=True)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
self, stuck_detector: StuckDetector
@ -328,7 +366,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector
@ -337,7 +375,7 @@ class TestStuckDetector:
self._impl_unterminated_string_error_events(state, random_line=False)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(headless_mode=True) is True
def test_is_not_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector
@ -382,7 +420,7 @@ class TestStuckDetector:
state.history.append(ipython_observation_4)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
mock_warning.assert_not_called()
def test_is_stuck_repeating_action_observation_pattern(
@ -451,7 +489,7 @@ class TestStuckDetector:
state.history.append(read_observation_3)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(headless_mode=True) is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
@ -517,7 +555,7 @@ class TestStuckDetector:
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_stuck_monologue(self, stuck_detector):
state = stuck_detector.state
@ -547,7 +585,7 @@ class TestStuckDetector:
message_action_6._source = EventSource.AGENT
state.history.append(message_action_6)
assert stuck_detector.is_stuck()
assert stuck_detector.is_stuck(headless_mode=True)
# Add an observation event between the repeated message actions
cmd_output_observation = CmdOutputObservation(
@ -567,7 +605,7 @@ class TestStuckDetector:
state.history.append(message_action_8)
with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()
assert not stuck_detector.is_stuck(headless_mode=True)
class TestAgentController: