[eval] Fix multi-processing bug (again^3) & allow setting EXP_NAME for each run_infer (#3907)

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
Xingyao Wang 2024-09-17 09:07:58 -05:00 committed by GitHub
parent fa0d9cfa42
commit f996b31d64
3 changed files with 78 additions and 112 deletions

View File

@@ -218,7 +218,7 @@ def initialize_runtime(
assert obs.exit_code == 0
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
action.timeout = 1800
action.timeout = 3600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
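Only the timeout changes in this hunk: the 1800-second value is removed and replaced by 3600 seconds, doubling the time allowed for the instance setup command.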

View File

@@ -66,6 +66,11 @@ if [ "$USE_HINT_TEXT" = false ]; then
EVAL_NOTE="$EVAL_NOTE-no-hint"
fi
if [ -n "$EXP_NAME" ]; then
EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
fi
echo "EVAL_NOTE: $EVAL_NOTE"
unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
COMMAND="poetry run python evaluation/swe_bench/run_infer.py \
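With this addition, an EXP_NAME environment variable, when set, is appended to EVAL_NOTE, so repeated or concurrent runs of the same configuration produce distinguishable evaluation notes. For example, invoking the script with EXP_NAME=retry-fix in the environment would yield an EVAL_NOTE ending in "-retry-fix" (a hypothetical name; any non-empty string works).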

View File

@@ -6,7 +6,7 @@ import pathlib
import subprocess
import time
import traceback
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any, Awaitable, Callable, TextIO
import pandas as pd
@@ -78,12 +78,6 @@ class EvalOutput(BaseModel):
return json.dumps(dumped_dict)
class EvalError(BaseModel):
instance_id: str
error: str
stacktrace: str
def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
@@ -235,65 +229,58 @@ def prepare_dataset(
def update_progress(
result_or_future: Future | EvalOutput | EvalError,
instance: pd.Series,
result: EvalOutput,
pbar: tqdm,
output_fp: TextIO,
instance_queue: mp.Queue,
):
"""Update the progress bar and write the result to the output file."""
try:
if isinstance(result_or_future, Future):
result = result_or_future.result()
else:
result = result_or_future
except Exception as e:
# Handle the error
# Exception may be raised in the process_instance_func and will
# be raised here when we try to access the .result() of the future
handle_error(
EvalError(
instance_id=instance.instance_id,
error=str(e),
stacktrace=traceback.format_exc(),
),
instance,
pbar,
instance_queue,
)
return
# Update the progress bar and write the result to the output file
if isinstance(result, EvalOutput):
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
elif isinstance(result, EvalError):
handle_error(result, instance, pbar, instance_queue)
else:
raise ValueError(f'Unexpected result type: {type(result)}')
def handle_error(
error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
):
"""Handle an error that occurred during evaluation."""
logger.error(
f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
+ '\n'
+ '-' * 10
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+ '-' * 10
+ '\n'
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
instance_queue.put(instance)
pbar.total += 1
pbar.refresh()
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
def _process_instance_wrapper(
process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
instance: pd.Series,
metadata: EvalMetadata,
use_mp: bool,
max_retries: int = 5,
) -> EvalOutput:
"""Wrap the process_instance_func to handle retries and errors.
Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
"""
for attempt in range(max_retries + 1):
try:
result = process_instance_func(instance, metadata, use_mp)
return result
except Exception as e:
if attempt == max_retries:
# Raise an error after all retries & stop the evaluation
raise RuntimeError(
f'Maximum error retries reached for instance {instance.instance_id}'
) from e
error = str(e)
stacktrace = traceback.format_exc()
msg = (
'-' * 10
+ '\n'
+ f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
+ '\n'
+ '-' * 10
+ '[This error occurred after maximum retries]'
+ '-' * 10
+ '\n'
)
logger.error(msg)
if use_mp:
print(msg) # use print to directly print to console
time.sleep(1) # Add a small delay before retrying
def run_evaluation(
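The retry logic now lives inside the worker via _process_instance_wrapper: a failing instance is retried in place up to max_retries times and only then raises, instead of being pushed back onto the shared instance_queue and inflating pbar.total as the removed handle_error path did. This in-worker retry appears to be the multi-processing fix referenced in the commit title.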
@@ -304,6 +291,7 @@ def run_evaluation(
process_instance_func: Callable[
[pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
],
max_retries: int = 5, # number of retries for each instance
):
use_multiprocessing = num_workers > 1
logger.info(
@@ -311,10 +299,6 @@
f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
)
instance_queue = mp.Queue()
for _, instance in dataset.iterrows():
instance_queue.put(instance)
total_instances = len(dataset)
pbar = tqdm(total=total_instances, desc='Instances processed')
output_fp = open(output_file, 'a')
@@ -322,53 +306,30 @@
try:
if use_multiprocessing:
with ProcessPoolExecutor(num_workers) as executor:
batch_futures = []
# Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
# since a running future may add new instances to the queue when error occurs
while not instance_queue.empty() or batch_futures:
# Submit new tasks if there are instances to be processed and available workers
while (
not instance_queue.empty() and len(batch_futures) < num_workers
):
try:
instance = instance_queue.get(block=False)
future = executor.submit(
process_instance_func, instance, metadata, True
)
future.instance = (
instance # Attach the instance to the future
)
batch_futures.append(future)
except mp.queues.Empty:
logger.warning(
'Queue is empty - This should not happen. This is a bug.'
)
break # Queue is empty, stop submitting new tasks
# Continue to wait for the futures to be done & remove completed futures
new_batch_futures = []
for future in batch_futures:
if future.done():
update_progress(
future, future.instance, pbar, output_fp, instance_queue
)
else:
new_batch_futures.append(future)
batch_futures = new_batch_futures
# Short sleep to prevent busy-waiting
time.sleep(1)
assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
assert (
len(batch_futures) == 0
), 'batch_futures should be empty after all futures are done. This is a bug.'
futures = [
executor.submit(
_process_instance_wrapper,
process_instance_func=process_instance_func,
instance=instance,
metadata=metadata,
use_mp=True,
max_retries=max_retries,
)
for _, instance in dataset.iterrows()
]
for future in as_completed(futures):
result = future.result()
update_progress(result, pbar, output_fp)
else:
while not instance_queue.empty():
instance = instance_queue.get()
result = process_instance_func(instance, metadata, False)
update_progress(result, instance, pbar, output_fp, instance_queue)
for _, instance in dataset.iterrows():
result = _process_instance_wrapper(
process_instance_func=process_instance_func,
instance=instance,
metadata=metadata,
use_mp=False,
max_retries=max_retries,
)
update_progress(result, pbar, output_fp)
except KeyboardInterrupt:
print('\nKeyboardInterrupt received. Cleaning up...\n')
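
For reference, a minimal standalone sketch of the orchestration pattern this hunk adopts: submit every instance up front, retry inside the worker, and consume results with as_completed instead of polling a queue of futures. Names such as process_one and process_with_retries are hypothetical stand-ins, not identifiers from the repository.

import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed


def process_one(item):
    """Stand-in for process_instance_func; pretend it may fail transiently."""
    return {'instance_id': item, 'ok': True}


def process_with_retries(item, max_retries=5):
    """Retry inside the worker process, mirroring the idea behind _process_instance_wrapper."""
    for attempt in range(max_retries + 1):
        try:
            return process_one(item)
        except Exception:
            if attempt == max_retries:
                raise RuntimeError(f'Maximum retries reached for {item}')
            traceback.print_exc()
            time.sleep(1)  # brief pause before the next attempt


if __name__ == '__main__':
    with ProcessPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_with_retries, i) for i in range(10)]
        for future in as_completed(futures):  # results arrive as workers finish
            print(future.result())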