feat(eval): increase resource factor for remote runtime when previous run failed due to resource (#5709)

This commit is contained in:
Xingyao Wang 2024-12-20 12:47:06 -05:00 committed by GitHub
parent cfbe77b367
commit 581d5ec7a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 12 deletions

View File

@ -370,6 +370,7 @@ def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
runtime_failure_count: int = 0,
) -> EvalOutput:
config = get_config(instance, metadata)
@ -380,6 +381,15 @@ def process_instance(
else:
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
# Increase resource_factor as runtime_failure_count grows (capped at 2 below)
if runtime_failure_count > 0:
config.sandbox.remote_runtime_resource_factor = min(
config.sandbox.remote_runtime_resource_factor * (2**runtime_failure_count),
2, # hardcode maximum resource factor to 2
)
logger.warning(
f'This is the second attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
)
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)

View File

@ -8,6 +8,7 @@ import subprocess
import time
import traceback
from contextlib import contextmanager
from inspect import signature
from typing import Any, Awaitable, Callable, TextIO
import pandas as pd
@ -24,7 +25,6 @@ from openhands.core.exceptions import (
AgentRuntimeNotReadyError,
AgentRuntimeTimeoutError,
AgentRuntimeUnavailableError,
AgentStuckInLoopError,
)
from openhands.core.logger import get_console_handler
from openhands.core.logger import openhands_logger as logger
@ -316,13 +316,20 @@ def _process_instance_wrapper(
timeout_seconds: int | None = None,
) -> EvalOutput:
"""Wrap the process_instance_func to handle retries and errors."""
runtime_failure_count = 0
for attempt in range(max_retries + 1):
try:
kwargs = {}
# check if process_instance_func accepts the runtime_failure_count parameter
sig = signature(process_instance_func)
if 'runtime_failure_count' in sig.parameters:
kwargs['runtime_failure_count'] = runtime_failure_count
if timeout_seconds is not None:
with timeout(timeout_seconds):
result = process_instance_func(instance, metadata, use_mp)
result = process_instance_func(instance, metadata, use_mp, **kwargs)
else:
result = process_instance_func(instance, metadata, use_mp)
result = process_instance_func(instance, metadata, use_mp, **kwargs)
return result
except EvalTimeoutException as e:
error = f'Timeout after {timeout_seconds} seconds'
@ -368,6 +375,11 @@ def _process_instance_wrapper(
+ '-' * 10
+ '\n'
)
if isinstance(
e, (AgentRuntimeDisconnectedError, AgentRuntimeUnavailableError)
):
runtime_failure_count += 1
msg += f'Runtime disconnected error detected for instance {instance.instance_id}, runtime failure count: {runtime_failure_count}'
logger.error(msg)
if use_mp:
print(msg) # use print to directly print to console
@ -527,7 +539,6 @@ def is_fatal_evaluation_error(error: str | None) -> bool:
AgentRuntimeNotReadyError,
AgentRuntimeDisconnectedError,
AgentRuntimeNotFoundError,
AgentStuckInLoopError,
]
if any(exception.__name__ in error for exception in FATAL_EXCEPTIONS):

View File

@ -32,6 +32,8 @@ class SandboxConfig:
browsergym_eval_env: The BrowserGym environment to use for evaluation.
Default is None for general purpose browsing. Check evaluation/miniwob and evaluation/webarena for examples.
platform: The platform on which the image should be built. Default is None.
remote_runtime_resource_factor: Factor to scale the resource allocation for remote runtime.
Must be one of [1, 2, 4, 8]. Will only be used if the runtime is remote.
"""
remote_runtime_api_url: str = 'http://localhost:8000'
@ -56,6 +58,7 @@ class SandboxConfig:
browsergym_eval_env: str | None = None
platform: str | None = None
close_delay: int = 15
remote_runtime_resource_factor: int = 1
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

View File

@ -41,6 +41,7 @@ from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.command import get_remote_startup_command
from openhands.runtime.utils.request import (
RequestHTTPError,
send_request,
)
from openhands.runtime.utils.runtime_build import build_runtime_image
@ -246,6 +247,7 @@ class RemoteRuntime(Runtime):
'working_dir': '/openhands/code/',
'environment': {'DEBUG': 'true'} if self.config.debug else {},
'session_id': self.sid,
'resource_factor': self.config.sandbox.remote_runtime_resource_factor,
}
# Start the sandbox using the /start endpoint
@ -451,11 +453,11 @@ class RemoteRuntime(Runtime):
except requests.Timeout:
self.log('error', 'No response received within the timeout period.')
raise
except requests.HTTPError as e:
if is_runtime_request and e.response.status_code == 404:
except RequestHTTPError as e:
if is_runtime_request and e.response.status_code in (404, 502):
raise AgentRuntimeDisconnectedError(
f'404 error while connecting to {self.runtime_url}'
)
f'{e.response.status_code} error while connecting to {self.runtime_url}'
) from e
elif is_runtime_request and e.response.status_code == 503:
if not is_retry:
self.log('warning', 'Runtime appears to be paused. Resuming...')
@ -463,7 +465,9 @@ class RemoteRuntime(Runtime):
self._wait_until_alive()
return self._send_request(method, url, True, **kwargs)
else:
raise e
raise AgentRuntimeUnavailableError(
f'{e.response.status_code} error while connecting to {self.runtime_url}'
) from e
else:
raise e

View File

@ -1,3 +1,4 @@
import json
from typing import Any
import requests
@ -30,9 +31,11 @@ def send_request(
except requests.HTTPError as e:
try:
_json = response.json()
except requests.JSONDecodeError:
raise e
except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError):
_json = None
raise RequestHTTPError(
e, response=e.response, detail=_json.get('detail')
e,
response=e.response,
detail=_json.get('detail') if _json is not None else None,
) from e
return response