(enh) StuckDetector: fix+enhance syntax error loop detection (#3628)

* fix StuckDetector and add more errors for detection

* more stringent error detection and more unit tests
This commit is contained in:
tobitege
2024-08-29 17:33:54 +02:00
committed by GitHub
parent ae153aa8ab
commit a2d94c9cb1
2 changed files with 201 additions and 96 deletions

View File

@@ -1,5 +1,3 @@
from typing import cast
from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
@@ -16,6 +14,12 @@ from openhands.events.observation.observation import Observation
class StuckDetector:
SYNTAX_ERROR_MESSAGES = [
'SyntaxError: unterminated string literal (detected at line',
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
'SyntaxError: incomplete input',
]
def __init__(self, state: State):
self.state = state
@@ -119,37 +123,108 @@ class StuckDetector:
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
    """Detect scenario 2: the same action repeated, each time ending in an error.

    Args:
        last_actions: Sequence of recent actions. At least four entries are
            required by the guard below, although only the first three are
            compared (the surrounding comments call these "the last three").
        last_observations: Sequence of observations matching ``last_actions``;
            the same length requirement applies.

    Returns:
        True when the first three actions compare equal via ``_eq_no_pid`` and
        their observations are either all ``ErrorObservation`` instances or all
        ``IPythonRunCellObservation`` instances that end in the same known
        syntax error; False otherwise.
    """
    # Not enough history to decide anything.
    if len(last_actions) < 4 or len(last_observations) < 4:
        return False

    recent_actions = last_actions[:3]
    recent_observations = last_observations[:3]

    # Are the three inspected actions the "same" (ignoring what _eq_no_pid ignores)?
    if not all(
        self._eq_no_pid(last_actions[0], action) for action in recent_actions
    ):
        return False

    # Case A: every inspected observation is a plain error.
    if all(isinstance(obs, ErrorObservation) for obs in recent_observations):
        logger.warning('Action, ErrorObservation loop detected')
        return True

    # Case B: every inspected observation is an IPython cell result that ends
    # in the same syntax error.
    if not all(
        isinstance(obs, IPythonRunCellObservation) for obs in recent_observations
    ):
        return False

    line_error_prefix = 'SyntaxError: unterminated string literal (detected at line'
    invalid_syntax_messages = (
        'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
        'SyntaxError: incomplete input',
    )
    for error_message in self.SYNTAX_ERROR_MESSAGES:
        # Pick the checker that understands this message's output layout;
        # messages that match neither family are ignored.
        if error_message.startswith(line_error_prefix):
            checker = self._check_for_consistent_line_error
        elif error_message in invalid_syntax_messages:
            checker = self._check_for_consistent_invalid_syntax
        else:
            continue
        if checker(recent_observations, error_message):
            logger.warning('Action, IPythonRunCellObservation loop detected')
            return True
    return False
def _check_for_consistent_invalid_syntax(self, observations, error_message):
first_lines = []
valid_observations = []
for obs in observations:
content = obs.content
lines = content.strip().split('\n')
if len(lines) < 4:
return False
first_lines.append(lines[0]) # Store the first line of each observation
# Check last three lines
if lines[-2].startswith('[Jupyter current working directory:') and lines[
-1
].startswith('[Jupyter Python interpreter:'):
if error_message in lines[-3]:
valid_observations.append(obs)
break
# Check if:
# 1. All first lines are identical
# 2. We have exactly 3 valid observations
# 3. The error message line is identical in all valid observations
return (
len(set(first_lines)) == 1
and len(valid_observations) == 3
and len(
set(
obs.content.strip().split('\n')[:-2][-1]
for obs in valid_observations
)
)
== 1
)
def _check_for_consistent_line_error(self, observations, error_message):
error_lines = []
for obs in observations:
content = obs.content
lines = content.strip().split('\n')
if len(lines) < 3:
return False
last_lines = lines[-3:]
# Check if the last two lines are our own
if not (
last_lines[-2].startswith('[Jupyter current working directory:')
and last_lines[-1].startswith('[Jupyter Python interpreter:')
):
return False
# Check for the error message in the 3rd-to-last line
if error_message in last_lines[-3]:
error_lines.append(last_lines[-3])
# Check if we found the error message in all 3 observations
# and the 3rd-to-last line is identical across all occurrences
return len(error_lines) == 3 and len(set(error_lines)) == 1
def _is_stuck_monologue(self, filtered_history):
# scenario 3: monologue
# check for repeated MessageActions with source=AGENT