diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py
index ff846ed1a8..36eb70554e 100644
--- a/evaluation/swe_bench/run_infer.py
+++ b/evaluation/swe_bench/run_infer.py
@@ -218,7 +218,7 @@ def initialize_runtime(
     assert obs.exit_code == 0
 
     action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
-    action.timeout = 1800
+    action.timeout = 3600
     logger.info(action, extra={'msg_type': 'ACTION'})
     obs = runtime.run_action(action)
     logger.info(obs, extra={'msg_type': 'OBSERVATION'})
diff --git a/evaluation/swe_bench/scripts/run_infer.sh b/evaluation/swe_bench/scripts/run_infer.sh
index 0859e1f480..772a7d30e1 100755
--- a/evaluation/swe_bench/scripts/run_infer.sh
+++ b/evaluation/swe_bench/scripts/run_infer.sh
@@ -66,6 +66,11 @@ if [ "$USE_HINT_TEXT" = false ]; then
   EVAL_NOTE="$EVAL_NOTE-no-hint"
 fi
 
+if [ -n "$EXP_NAME" ]; then
+  EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
+fi
+echo "EVAL_NOTE: $EVAL_NOTE"
+
 unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
 
 COMMAND="poetry run python evaluation/swe_bench/run_infer.py \
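
Note: the EXP_NAME hook above lets callers tag an evaluation run from the environment. A standalone sketch of the resulting EVAL_NOTE composition; all values here are hypothetical, not defaults from the script:

    #!/bin/bash
    # Mirror of the suffix logic in run_infer.sh, with made-up inputs.
    EVAL_NOTE="eval-run"      # hypothetical base note
    USE_HINT_TEXT=false
    EXP_NAME="retry-exp"      # set by the caller in the environment

    if [ "$USE_HINT_TEXT" = false ]; then
      EVAL_NOTE="$EVAL_NOTE-no-hint"
    fi
    if [ -n "$EXP_NAME" ]; then
      EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
    fi
    echo "EVAL_NOTE: $EVAL_NOTE"  # prints: EVAL_NOTE: eval-run-no-hint-retry-exp
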
diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py
index 716fbc3986..709a0e2bab 100644
--- a/evaluation/utils/shared.py
+++ b/evaluation/utils/shared.py
@@ -6,7 +6,7 @@ import pathlib
 import subprocess
 import time
 import traceback
-from concurrent.futures import Future, ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Any, Awaitable, Callable, TextIO
 
 import pandas as pd
@@ -78,12 +78,6 @@ class EvalOutput(BaseModel):
         return json.dumps(dumped_dict)
 
 
-class EvalError(BaseModel):
-    instance_id: str
-    error: str
-    stacktrace: str
-
-
 def codeact_user_response(
     state: State,
     encapsulate_solution: bool = False,
@@ -235,65 +229,58 @@ def prepare_dataset(
 
 
 def update_progress(
-    result_or_future: Future | EvalOutput | EvalError,
-    instance: pd.Series,
+    result: EvalOutput,
     pbar: tqdm,
     output_fp: TextIO,
-    instance_queue: mp.Queue,
 ):
     """Update the progress bar and write the result to the output file."""
-    try:
-        if isinstance(result_or_future, Future):
-            result = result_or_future.result()
-        else:
-            result = result_or_future
-    except Exception as e:
-        # Handle the error
-        # Exception may be raised in the process_instance_func and will
-        # be raised here when we try to access the .result() of the future
-        handle_error(
-            EvalError(
-                instance_id=instance.instance_id,
-                error=str(e),
-                stacktrace=traceback.format_exc(),
-            ),
-            instance,
-            pbar,
-            instance_queue,
-        )
-        return
-
-    # Update the progress bar and write the result to the output file
-    if isinstance(result, EvalOutput):
-        pbar.update(1)
-        pbar.set_description(f'Instance {result.instance_id}')
-        pbar.set_postfix_str(f'Test Result: {result.test_result}')
-        logger.info(
-            f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
-        )
-        output_fp.write(json.dumps(result.model_dump()) + '\n')
-        output_fp.flush()
-    elif isinstance(result, EvalError):
-        handle_error(result, instance, pbar, instance_queue)
-    else:
-        raise ValueError(f'Unexpected result type: {type(result)}')
-
-
-def handle_error(
-    error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
-):
-    """Handle an error that occurred during evaluation."""
-    logger.error(
-        f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
-        + '\n'
-        + '-' * 10
-        + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
-        + '-' * 10
-        + '\n'
+    pbar.update(1)
+    pbar.set_description(f'Instance {result.instance_id}')
+    pbar.set_postfix_str(f'Test Result: {result.test_result}')
+    logger.info(
+        f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
     )
-    instance_queue.put(instance)
-    pbar.total += 1
-    pbar.refresh()
+    output_fp.write(json.dumps(result.model_dump()) + '\n')
+    output_fp.flush()
+
+
+def _process_instance_wrapper(
+    process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
+    instance: pd.Series,
+    metadata: EvalMetadata,
+    use_mp: bool,
+    max_retries: int = 5,
+) -> EvalOutput:
+    """Wrap the process_instance_func to handle retries and errors.
+
+    Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
+    """
+    for attempt in range(max_retries + 1):
+        try:
+            result = process_instance_func(instance, metadata, use_mp)
+            return result
+        except Exception as e:
+            if attempt == max_retries:
+                # Raise an error after all retries & stop the evaluation
+                raise RuntimeError(
+                    f'Maximum error retries reached for instance {instance.instance_id}'
+                ) from e
+            error = str(e)
+            stacktrace = traceback.format_exc()
+            msg = (
+                '-' * 10
+                + '\n'
+                + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
+                + '\n'
+                + '-' * 10
+                + f'[Retrying instance (attempt {attempt + 1} of {max_retries})...]'
+                + '-' * 10
+                + '\n'
+            )
+            logger.error(msg)
+            if use_mp:
+                print(msg)  # use print to directly print to console
+            time.sleep(1)  # Add a small delay before retrying
 
 
 def run_evaluation(
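
The wrapper above replaces the old queue-and-requeue error handling with bounded in-place retries. A minimal self-contained sketch of the same pattern, with a toy flaky function standing in for process_instance_func; all names here are illustrative, not part of the diff:

    import time
    import traceback


    def run_with_retries(fn, max_retries=5):
        # One initial attempt plus up to max_retries retries, mirroring
        # `for attempt in range(max_retries + 1)` in the wrapper above.
        for attempt in range(max_retries + 1):
            try:
                return fn()
            except Exception as e:
                if attempt == max_retries:
                    # Give up and propagate, as the wrapper does via RuntimeError.
                    raise RuntimeError('Maximum error retries reached') from e
                print(f'Attempt {attempt + 1} failed:\n{traceback.format_exc()}')
                time.sleep(1)  # small delay before retrying


    outcomes = iter([ValueError('transient failure'), 'ok'])


    def flaky():
        outcome = next(outcomes)
        if isinstance(outcome, Exception):
            raise outcome
        return outcome


    print(run_with_retries(flaky))  # -> 'ok' on the second attempt
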
@@ -304,6 +291,7 @@ def run_evaluation(
     process_instance_func: Callable[
         [pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
     ],
+    max_retries: int = 5,  # number of retries for each instance
 ):
     use_multiprocessing = num_workers > 1
     logger.info(
@@ -311,10 +299,6 @@ def run_evaluation(
         f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
     )
 
-    instance_queue = mp.Queue()
-    for _, instance in dataset.iterrows():
-        instance_queue.put(instance)
-
     total_instances = len(dataset)
     pbar = tqdm(total=total_instances, desc='Instances processed')
     output_fp = open(output_file, 'a')
@@ -322,53 +306,30 @@ def run_evaluation(
     try:
         if use_multiprocessing:
             with ProcessPoolExecutor(num_workers) as executor:
-                batch_futures = []
-
-                # Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
-                # since a running future may add new instances to the queue when error occurs
-                while not instance_queue.empty() or batch_futures:
-                    # Submit new tasks if there are instances to be processed and available workers
-                    while (
-                        not instance_queue.empty() and len(batch_futures) < num_workers
-                    ):
-                        try:
-                            instance = instance_queue.get(block=False)
-                            future = executor.submit(
-                                process_instance_func, instance, metadata, True
-                            )
-                            future.instance = (
-                                instance  # Attach the instance to the future
-                            )
-                            batch_futures.append(future)
-                        except mp.queues.Empty:
-                            logger.warning(
-                                'Queue is empty - This should not happen. This is a bug.'
-                            )
-                            break  # Queue is empty, stop submitting new tasks
-
-                    # Continue to wait for the futures to be done & remove completed futures
-                    new_batch_futures = []
-                    for future in batch_futures:
-                        if future.done():
-                            update_progress(
-                                future, future.instance, pbar, output_fp, instance_queue
-                            )
-                        else:
-                            new_batch_futures.append(future)
-                    batch_futures = new_batch_futures
-
-                    # Short sleep to prevent busy-waiting
-                    time.sleep(1)
-
-                assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
-                assert (
-                    len(batch_futures) == 0
-                ), 'batch_futures should be empty after all futures are done. This is a bug.'
+                futures = [
+                    executor.submit(
+                        _process_instance_wrapper,
+                        process_instance_func=process_instance_func,
+                        instance=instance,
+                        metadata=metadata,
+                        use_mp=True,
+                        max_retries=max_retries,
+                    )
+                    for _, instance in dataset.iterrows()
+                ]
+                for future in as_completed(futures):
+                    result = future.result()
+                    update_progress(result, pbar, output_fp)
         else:
-            while not instance_queue.empty():
-                instance = instance_queue.get()
-                result = process_instance_func(instance, metadata, False)
-                update_progress(result, instance, pbar, output_fp, instance_queue)
+            for _, instance in dataset.iterrows():
+                result = _process_instance_wrapper(
+                    process_instance_func=process_instance_func,
+                    instance=instance,
+                    metadata=metadata,
+                    use_mp=False,
+                    max_retries=max_retries,
+                )
+                update_progress(result, pbar, output_fp)
 
     except KeyboardInterrupt:
         print('\nKeyboardInterrupt received. Cleaning up...\n')
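
For reference, the executor pattern the new run_evaluation relies on, shown in isolation: as_completed yields futures as workers finish (completion order, not submission order), and future.result() re-raises any exception from the worker process, which is how a RuntimeError from _process_instance_wrapper surfaces and halts the run. A minimal runnable sketch with a toy task in place of the real evaluation function:

    from concurrent.futures import ProcessPoolExecutor, as_completed


    def square(x: int) -> int:
        return x * x


    if __name__ == '__main__':  # guard required for process pools on spawn platforms
        with ProcessPoolExecutor(4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            for future in as_completed(futures):
                # result() returns the worker's return value, or re-raises
                # whatever exception the worker raised.
                print(future.result())
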