mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
implement almost stuck reminder
This commit is contained in:
@@ -275,8 +275,17 @@ class CodeActAgent(Agent):
|
||||
)
|
||||
|
||||
if latest_user_message:
|
||||
latest_user_message['content'] += (
|
||||
f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task.'
|
||||
)
|
||||
if state.almost_stuck == 1:
|
||||
latest_user_message['content'] += (
|
||||
'\n\nENVIRONMENT REMINDER: You are almost stuck repeating the same action. You have only 1 iteration left and you must change your approach. Now.'
|
||||
)
|
||||
elif state.almost_stuck == 2:
|
||||
latest_user_message['content'] += (
|
||||
'\n\nENVIRONMENT REMINDER: You are almost stuck repeating the same action. You have only 2 iterations left and you must change your approach.'
|
||||
)
|
||||
else:
|
||||
latest_user_message['content'] += (
|
||||
f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task.'
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
@@ -187,6 +187,7 @@ class AgentController:
|
||||
self.state.history.on_event(event)
|
||||
|
||||
def reset_task(self):
|
||||
self.almost_stuck = 0
|
||||
self.agent.reset()
|
||||
|
||||
async def set_agent_state_to(self, new_state: AgentState):
|
||||
|
||||
@@ -43,6 +43,7 @@ class State:
|
||||
delegate_level: int = 0
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
almost_stuck: int = 0
|
||||
summaries: dict[tuple[int, int], AgentSummarizeAction] = field(default_factory=dict)
|
||||
|
||||
def save_to_session(self, sid: str):
|
||||
|
||||
@@ -75,29 +75,50 @@ class StuckDetector:
|
||||
|
||||
def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
|
||||
# scenario 1: same action, same observation
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
last_three_actions = last_actions[-3:]
|
||||
last_three_observations = last_observations[-3:]
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# assert len(last_actions) == 4 and len(last_observations) == 4
|
||||
|
||||
# almost stuck? if two actions, obs are the same, we're almost stuck
|
||||
if len(last_actions) >= 2 and len(last_observations) >= 2:
|
||||
actions_equal = all(
|
||||
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
|
||||
)
|
||||
observations_equal = all(
|
||||
self._eq_no_pid(last_observations[0], observation)
|
||||
for observation in last_observations[:2]
|
||||
)
|
||||
|
||||
# the last two actions and obs are the same?
|
||||
if actions_equal and observations_equal:
|
||||
self.state.almost_stuck = 2
|
||||
|
||||
# the last three actions and observations are the same?
|
||||
if len(last_actions) >= 3 and len(last_observations) >= 3:
|
||||
if (
|
||||
actions_equal
|
||||
and observations_equal
|
||||
and self._eq_no_pid(last_actions[0], last_actions[2])
|
||||
and self._eq_no_pid(last_observations[0], last_observations[2])
|
||||
):
|
||||
self.state.almost_stuck = 1
|
||||
|
||||
if len(last_actions) == 4 and len(last_observations) == 4:
|
||||
if (
|
||||
actions_equal
|
||||
and observations_equal
|
||||
and self._eq_no_pid(last_actions[0], last_actions[3])
|
||||
and self._eq_no_pid(last_observations[0], last_observations[3])
|
||||
):
|
||||
logger.warning('Action, Observation loop detected')
|
||||
self.state.almost_stuck = 0
|
||||
return True
|
||||
|
||||
# are the last three actions the same?
|
||||
if len(last_three_actions) == 3 and all(
|
||||
self._eq_no_pid(last_three_actions[0], action)
|
||||
for action in last_three_actions
|
||||
):
|
||||
# and the last three observations the same?
|
||||
if len(last_three_observations) == 3 and all(
|
||||
self._eq_no_pid(last_three_observations[0], observation)
|
||||
for observation in last_three_observations
|
||||
):
|
||||
logger.warning('Action, Observation loop detected')
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
|
||||
# scenario 2: same action, errors
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# check if the last four actions are the same and result in errors
|
||||
# retrieve the last four actions and observations starting from the end of history, wherever they are
|
||||
|
||||
# are the last four actions the same?
|
||||
if len(last_actions) == 4 and all(
|
||||
|
||||
@@ -103,6 +103,9 @@ class TestStuckDetector:
|
||||
event_stream.add_event(message_null_observation, EventSource.USER)
|
||||
# 8 events
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
assert stuck_detector.state.almost_stuck == 2
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_3, EventSource.AGENT)
|
||||
cmd_observation_3 = CmdOutputObservation(
|
||||
@@ -112,14 +115,29 @@ class TestStuckDetector:
|
||||
event_stream.add_event(cmd_observation_3, EventSource.USER)
|
||||
# 10 events
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
assert len(collect_events(event_stream)) == 10
|
||||
assert len(list(stuck_detector.state.history.get_events())) == 8
|
||||
assert len(stuck_detector.state.history.get_tuples()) == 5
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
assert stuck_detector.state.almost_stuck == 1
|
||||
|
||||
cmd_action_4 = CmdRunAction(command='ls')
|
||||
event_stream.add_event(cmd_action_4, EventSource.AGENT)
|
||||
cmd_observation_4 = CmdOutputObservation(
|
||||
content='', command='ls', command_id=cmd_action_4._id
|
||||
)
|
||||
cmd_observation_4._cause = cmd_action_4._id
|
||||
event_stream.add_event(cmd_observation_4, EventSource.USER)
|
||||
# 12 events
|
||||
|
||||
assert len(collect_events(event_stream)) == 12
|
||||
assert len(list(stuck_detector.state.history.get_events())) == 10
|
||||
assert len(stuck_detector.state.history.get_tuples()) == 6
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
assert stuck_detector.state.almost_stuck == 0
|
||||
mock_warning.assert_called_once_with('Action, Observation loop detected')
|
||||
|
||||
def test_is_stuck_repeating_action_error(
|
||||
@@ -171,8 +189,6 @@ class TestStuckDetector:
|
||||
event_stream.add_event(error_observation_4, EventSource.USER)
|
||||
# 12 events
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
mock_warning.assert_called_once_with(
|
||||
@@ -290,8 +306,6 @@ class TestStuckDetector:
|
||||
ipython_observation_4._cause = ipython_action_4._id
|
||||
event_stream.add_event(ipython_observation_4, EventSource.USER)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is False
|
||||
mock_warning.assert_not_called()
|
||||
@@ -358,8 +372,6 @@ class TestStuckDetector:
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.USER)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
mock_warning.assert_called_once_with('Action, Observation pattern detected')
|
||||
@@ -428,8 +440,6 @@ class TestStuckDetector:
|
||||
read_observation_3._cause = read_action_3._id
|
||||
event_stream.add_event(read_observation_3, EventSource.USER)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
assert stuck_detector.is_stuck() is False
|
||||
|
||||
def test_is_stuck_four_tuples_cmd_kill_and_output(
|
||||
@@ -500,8 +510,6 @@ class TestStuckDetector:
|
||||
cmd_output_observation_4._cause = cmd_kill_action_4._id
|
||||
event_stream.add_event(cmd_output_observation_4, EventSource.USER)
|
||||
|
||||
# stuck_detector.state.history.set_event_stream(event_stream)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck() is True
|
||||
mock_warning.assert_called_once_with('Action, Observation loop detected')
|
||||
|
||||
Reference in New Issue
Block a user