(Hotfix): Track reason for Error AgentState (#7584)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-03-31 17:24:42 -04:00 committed by GitHub
parent abaf0da9fe
commit 9adfcede31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 3 deletions

View File

@ -521,6 +521,11 @@ def compatibility_for_eval_history_pairs(
def is_fatal_evaluation_error(error: str | None) -> bool:
"""
The AgentController class overrides last error for certain exceptions
We want to ensure those exeption do not overlap with fatal exceptions defined here
This is because we do a comparisino against the stringified error
"""
if not error:
return False

View File

@ -2078,6 +2078,7 @@
"tr": "Ajan hız sınırına ulaştı",
"ja": "エージェントがレート制限中"
},
"CHAT_INTERFACE$AGENT_PAUSED_MESSAGE": {
"en": "Agent has paused.",
"de": "Agent pausiert.",

View File

@ -227,11 +227,14 @@ class AgentController:
e: Exception,
):
"""React to an exception by setting the agent state to error and sending a status message."""
await self.set_agent_state_to(AgentState.ERROR)
# Store the error reason before setting the agent state
self.state.last_error = f'{type(e).__name__}: {str(e)}'
if self.status_callback is not None:
err_id = ''
if isinstance(e, AuthenticationError):
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
self.state.last_error = err_id
elif isinstance(
e,
(
@ -241,14 +244,21 @@ class AgentController:
),
):
err_id = 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE'
self.state.last_error = err_id
elif isinstance(e, InternalServerError):
err_id = 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR'
self.state.last_error = err_id
elif isinstance(e, BadRequestError) and 'ExceededBudget' in str(e):
err_id = 'STATUS$ERROR_LLM_OUT_OF_CREDITS'
# Set error reason for budget exceeded
self.state.last_error = err_id
elif isinstance(e, RateLimitError):
await self.set_agent_state_to(AgentState.RATE_LIMITED)
return
self.status_callback('error', err_id, type(e).__name__ + ': ' + str(e))
self.status_callback('error', err_id, self.state.last_error)
# Set the agent state to ERROR after storing the reason
await self.set_agent_state_to(AgentState.ERROR)
def step(self):
asyncio.create_task(self._step_with_exception_handling())
@ -581,8 +591,14 @@ class AgentController:
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
self.state.agent_state = new_state
# Create observation with reason field if it's an error state
reason = ''
if new_state == AgentState.ERROR:
reason = self.state.last_error
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state),
AgentStateChangedObservation('', self.state.agent_state, reason),
EventSource.ENVIRONMENT,
)

View File

@ -10,6 +10,7 @@ class AgentStateChangedObservation(Observation):
"""This data class represents the result from delegating to another agent"""
agent_state: str
reason: str = ''
observation: str = ObservationType.AGENT_STATE_CHANGED
@property

View File

@ -17,6 +17,7 @@ from openhands.events.action import ChangeAgentStateAction, CmdRunAction, Messag
from openhands.events.action.agent import RecallAction
from openhands.events.event import RecallType
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
)
from openhands.events.observation.agent import RecallObservation
@ -217,9 +218,17 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
print(f'state: {state}')
events = list(test_event_stream.get_events())
print(f'event_stream: {events}')
error_observations = test_event_stream.get_matching_events(
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
)
assert len(error_observations) == 1
error_observation = error_observations[0]
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
assert (
error_observation.reason == 'AgentStuckInLoopError: Agent got stuck in a loop'
)
assert len(events) == 11
@ -622,6 +631,17 @@ async def test_run_controller_max_iterations_has_metrics(
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
error_observations = test_event_stream.get_matching_events(
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
)
assert len(error_observations) == 1
error_observation = error_observations[0]
assert (
error_observation.reason
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
assert (
state.metrics.accumulated_cost == 10.0 * 3
), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
@ -896,6 +916,16 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
== 'LLMContextWindowExceedError: Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error'
)
error_observations = test_event_stream.get_matching_events(
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
)
assert len(error_observations) == 1
error_observation = error_observations[0]
assert (
error_observation.reason
== 'LLMContextWindowExceedError: Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored