feat: clean-up retries RemoteRuntime & add FatalErrorObservation (#4485)

This commit is contained in:
Xingyao Wang 2024-10-18 12:23:13 -05:00 committed by GitHub
parent b660aa99b8
commit 91308ba4dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 146 additions and 60 deletions

View File

@ -11,6 +11,7 @@ from datasets import load_dataset
import openhands.agenthub
from evaluation.swe_bench.prompt import CODEACT_SWE_PROMPT
from evaluation.utils.shared import (
EvalException,
EvalMetadata,
EvalOutput,
assert_and_raise,
@ -384,6 +385,13 @@ def process_instance(
)
)
# if fatal error, throw EvalError to trigger re-run
if (
state.last_error
and 'fatal error during agent execution' in state.last_error
):
raise EvalException('Fatal error detected: ' + state.last_error)
# ======= THIS IS SWE-Bench specific =======
# Get git patch
return_val = complete_runtime(runtime, instance)

View File

@ -35,6 +35,7 @@ from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
ErrorObservation,
FatalErrorObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
@ -249,6 +250,12 @@ class AgentController:
elif isinstance(observation, ErrorObservation):
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)
elif isinstance(observation, FatalErrorObservation):
await self.report_error(
'There was a fatal error during agent execution: ' + str(observation)
)
await self.set_agent_state_to(AgentState.ERROR)
self.state.metrics.merge(self.state.local_metrics)
async def _handle_message_action(self, action: MessageAction):
"""Handles message actions from the event stream.

View File

@ -6,8 +6,11 @@ from openhands.events.observation.commands import (
)
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.files import FileReadObservation, FileWriteObservation
from openhands.events.observation.error import ErrorObservation, FatalErrorObservation
from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
@ -21,6 +24,7 @@ __all__ = [
'FileReadObservation',
'FileWriteObservation',
'ErrorObservation',
'FatalErrorObservation',
'AgentStateChangedObservation',
'AgentDelegateObservation',
'SuccessObservation',

View File

@ -6,10 +6,31 @@ from openhands.events.observation.observation import Observation
@dataclass
class ErrorObservation(Observation):
"""This data class represents an error encountered by the agent."""
"""This data class represents an error encountered by the agent.
This is the type of error that LLM can recover from.
E.g., Linter error after editing a file.
"""
observation: str = ObservationType.ERROR
@property
def message(self) -> str:
return self.content
def __str__(self) -> str:
return f'**ErrorObservation**\n{self.content}'
@dataclass
class FatalErrorObservation(Observation):
"""This data class represents a fatal error encountered by the agent.
This is the type of error that LLM CANNOT recover from, and the agent controller should stop the execution and report the error to the user.
E.g., Remote runtime action execution failure: 503 Server Error: Service Unavailable for url OR 404 Not Found.
"""
observation: str = ObservationType.ERROR
def __str__(self) -> str:
return f'**FatalErrorObservation**\n{self.content}'

View File

@ -23,7 +23,7 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.observation import (
ErrorObservation,
FatalErrorObservation,
NullObservation,
Observation,
UserRejectObservation,
@ -126,7 +126,13 @@ class EventStreamRuntime(Runtime):
attach_to_existing: bool = False,
):
super().__init__(
config, event_stream, sid, plugins, env_vars, status_message_callback, attach_to_existing
config,
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
attach_to_existing,
)
def __init__(
@ -192,7 +198,13 @@ class EventStreamRuntime(Runtime):
# Will initialize both the event stream and the env vars
self.init_base_runtime(
config, event_stream, sid, plugins, env_vars, status_message_callback, attach_to_existing
config,
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
attach_to_existing,
)
logger.info('Waiting for client to become ready...')
@ -431,9 +443,9 @@ class EventStreamRuntime(Runtime):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return ErrorObservation(f'Action {action_type} does not exist.')
return FatalErrorObservation(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
return FatalErrorObservation(
f'Action {action_type} is not supported in the current runtime.'
)
if (
@ -465,15 +477,17 @@ class EventStreamRuntime(Runtime):
logger.debug(f'response: {response}')
error_message = response.text
logger.error(f'Error from server: {error_message}')
obs = ErrorObservation(f'Action execution failed: {error_message}')
obs = FatalErrorObservation(
f'Action execution failed: {error_message}'
)
except requests.Timeout:
logger.error('No response received within the timeout period.')
obs = ErrorObservation(
obs = FatalErrorObservation(
f'Action execution timed out after {action.timeout} seconds.'
)
except Exception as e:
logger.error(f'Error during action execution: {e}')
obs = ErrorObservation(f'Action execution failed: {str(e)}')
obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
self._refresh_logs()
return obs

View File

@ -21,7 +21,7 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.observation import (
ErrorObservation,
FatalErrorObservation,
NullObservation,
Observation,
)
@ -31,8 +31,8 @@ from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils.request import (
DEFAULT_RETRY_EXCEPTIONS,
is_404_error,
is_503_error,
send_request_with_retry,
)
from openhands.runtime.utils.runtime_build import build_runtime_image
@ -90,7 +90,6 @@ class RemoteRuntime(Runtime):
status_message_callback,
attach_to_existing,
)
self._wait_until_alive()
self.setup_initial_env()
def _start_or_attach_to_runtime(
@ -232,10 +231,12 @@ class RemoteRuntime(Runtime):
timeout=300,
)
if response.status_code != 201:
raise RuntimeError(f'Failed to start sandbox: {response.text}')
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}'
)
self._parse_runtime_response(response)
logger.info(
f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
f'[Runtime (ID={self.runtime_id})] Runtime started. URL: {self.runtime_url}'
)
def _resume_runtime(self):
@ -247,8 +248,10 @@ class RemoteRuntime(Runtime):
timeout=30,
)
if response.status_code != 200:
raise RuntimeError(f'Failed to resume sandbox: {response.text}')
logger.info(f'Sandbox resumed. Runtime ID: {self.runtime_id}')
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}'
)
logger.info(f'[Runtime (ID={self.runtime_id})] Runtime resumed.')
def _parse_runtime_response(self, response: requests.Response):
start_response = response.json()
@ -298,7 +301,7 @@ class RemoteRuntime(Runtime):
# clean up the runtime
self.close()
raise RuntimeError(
f'Runtime pod failed to start. Current status: {pod_status}'
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
)
# Pending otherwise - add proper sleep
time.sleep(10)
@ -307,15 +310,15 @@ class RemoteRuntime(Runtime):
self.session,
'GET',
f'{self.runtime_url}/alive',
# Retry 404 errors for the /alive endpoint
# Retry 404 & 503 errors for the /alive endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
retry_fns=[is_404_error, is_503_error],
# leave enough time for the runtime to start up
timeout=600,
)
if response.status_code != 200:
msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.'
logger.warning(msg)
raise RuntimeError(msg)
@ -333,9 +336,11 @@ class RemoteRuntime(Runtime):
timeout=timeout,
)
if response.status_code != 200:
logger.error(f'Failed to stop sandbox: {response.text}')
logger.error(
f'[Runtime (ID={self.runtime_id})] Failed to stop runtime: {response.text}'
)
else:
logger.info(f'Sandbox stopped. Runtime ID: {self.runtime_id}')
logger.info(f'[Runtime (ID={self.runtime_id})] Runtime stopped.')
except Exception as e:
raise e
finally:
@ -349,16 +354,17 @@ class RemoteRuntime(Runtime):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return ErrorObservation(f'Action {action_type} does not exist.')
return FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.'
)
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.'
return FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.'
)
assert action.timeout is not None
try:
logger.info('Executing action')
request_body = {'action': event_to_dict(action)}
logger.debug(f'Request body: {request_body}')
response = send_request_with_retry(
@ -367,13 +373,6 @@ class RemoteRuntime(Runtime):
f'{self.runtime_url}/execute_action',
json=request_body,
timeout=action.timeout,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
# Retry 404 errors for the /execute_action endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
)
if response.status_code == 200:
output = response.json()
@ -383,13 +382,19 @@ class RemoteRuntime(Runtime):
else:
error_message = response.text
logger.error(f'Error from server: {error_message}')
obs = ErrorObservation(f'Action execution failed: {error_message}')
obs = FatalErrorObservation(
f'Action execution failed: {error_message}'
)
except Timeout:
logger.error('No response received within the timeout period.')
obs = ErrorObservation('Action execution timed out')
obs = FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action execution timed out'
)
except Exception as e:
logger.error(f'Error during action execution: {e}')
obs = ErrorObservation(f'Action execution failed: {str(e)}')
obs = FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}'
)
return obs
def run(self, action: CmdRunAction) -> Observation:
@ -444,9 +449,6 @@ class RemoteRuntime(Runtime):
f'{self.runtime_url}/upload_file',
files=upload_data,
params=params,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
timeout=300,
)
if response.status_code == 200:
@ -456,11 +458,17 @@ class RemoteRuntime(Runtime):
return
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
raise Exception(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
)
except TimeoutError:
raise TimeoutError('Copy operation timed out')
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
)
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
)
finally:
if recursive:
os.unlink(temp_zip_path)
@ -477,9 +485,6 @@ class RemoteRuntime(Runtime):
'POST',
f'{self.runtime_url}/list_files',
json=data,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
timeout=30,
)
if response.status_code == 200:
@ -488,15 +493,20 @@ class RemoteRuntime(Runtime):
return response_json
else:
error_message = response.text
raise Exception(f'List files operation failed: {error_message}')
raise Exception(
f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}'
)
except TimeoutError:
raise TimeoutError('List files operation timed out')
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] List files operation timed out'
)
except Exception as e:
raise RuntimeError(f'List files operation failed: {str(e)}')
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}'
)
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
self._wait_until_alive()
try:
params = {'path': path}
response = send_request_with_retry(
@ -505,19 +515,22 @@ class RemoteRuntime(Runtime):
f'{self.runtime_url}/download_files',
params=params,
timeout=30,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
)
if response.status_code == 200:
return response.content
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
raise Exception(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
)
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
)
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
)
def send_status_message(self, message: str):
"""Sends a status message if the callback function was provided."""

View File

@ -1,7 +1,10 @@
from typing import Any, Callable, Type
import requests
from requests.exceptions import ConnectionError, Timeout
from requests.exceptions import (
ChunkedEncodingError,
ConnectionError,
)
from tenacity import (
retry,
retry_if_exception,
@ -9,6 +12,7 @@ from tenacity import (
stop_after_delay,
wait_exponential,
)
from urllib3.exceptions import IncompleteRead
from openhands.utils.tenacity_stop import stop_if_should_exit
@ -27,9 +31,24 @@ def is_404_error(exception):
)
def is_503_error(exception):
return (
isinstance(exception, requests.HTTPError)
and exception.response.status_code == 503
)
def is_502_error(exception):
return (
isinstance(exception, requests.HTTPError)
and exception.response.status_code == 502
)
DEFAULT_RETRY_EXCEPTIONS = [
ConnectionError,
Timeout,
IncompleteRead,
ChunkedEncodingError,
]
@ -45,7 +64,7 @@ def send_request_with_retry(
exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
retry_condition = retry_if_exception_type(
tuple(exceptions_to_catch)
) | retry_if_exception(is_server_error)
) | retry_if_exception(is_502_error)
if retry_fns is not None:
for fn in retry_fns:
retry_condition |= retry_if_exception(fn)