fix(eval): support setting hard timeout per evaluation instance (#5110)

This commit is contained in:
Xingyao Wang 2024-11-18 20:22:55 -06:00 committed by GitHub
parent 422104c877
commit a531413d86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 7 deletions

View File

@ -145,7 +145,7 @@ def get_config(
platform='linux/amd64',
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
remote_runtime_init_timeout=3600,
),
# do not mount workspace

View File

@ -3,9 +3,11 @@ import logging
import multiprocessing as mp
import os
import pathlib
import signal
import subprocess
import time
import traceback
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, TextIO
import pandas as pd
@ -92,6 +94,27 @@ class EvalException(Exception):
pass
class EvalTimeoutException(Exception):
pass
@contextmanager
def timeout(seconds: int):
def timeout_handler(signum, frame):
raise EvalTimeoutException(f'Function timed out after {seconds} seconds')
# Set up the signal handler
original_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(seconds)
try:
yield
finally:
# Restore the original handler and disable the alarm
signal.alarm(0)
signal.signal(signal.SIGALRM, original_handler)
def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
@ -280,15 +303,33 @@ def _process_instance_wrapper(
metadata: EvalMetadata,
use_mp: bool,
max_retries: int = 5,
timeout_seconds: int | None = None,
) -> EvalOutput:
"""Wrap the process_instance_func to handle retries and errors.
Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
"""
"""Wrap the process_instance_func to handle retries and errors."""
for attempt in range(max_retries + 1):
try:
result = process_instance_func(instance, metadata, use_mp)
if timeout_seconds is not None:
with timeout(timeout_seconds):
result = process_instance_func(instance, metadata, use_mp)
else:
result = process_instance_func(instance, metadata, use_mp)
return result
except EvalTimeoutException as e:
error = f'Timeout after {timeout_seconds} seconds'
stacktrace = traceback.format_exc()
msg = (
'-' * 10
+ '\n'
+ f'Timeout ({timeout_seconds} seconds) in instance [{instance.instance_id}], Stopped evaluation for this instance.'
+ '\n'
+ '-' * 10
)
logger.exception(e)
return EvalOutput(
instance_id=instance.instance_id,
test_result={},
error=error,
)
except Exception as e:
error = str(e)
stacktrace = traceback.format_exc()
@ -337,6 +378,7 @@ def run_evaluation(
[pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
],
max_retries: int = 5, # number of retries for each instance
timeout_seconds: int | None = None,
):
use_multiprocessing = num_workers > 1
@ -357,7 +399,14 @@ def run_evaluation(
if use_multiprocessing:
with mp.Pool(num_workers) as pool:
args_iter = (
(process_instance_func, instance, metadata, True, max_retries)
(
process_instance_func,
instance,
metadata,
True,
max_retries,
timeout_seconds,
)
for _, instance in dataset.iterrows()
)
results = pool.imap_unordered(_process_instance_wrapper_mp, args_iter)