mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[eval] refactor process instance logic into update_progress (#3875)
This commit is contained in:
parent
ecf4aed28b
commit
2b3925278d
@ -6,8 +6,8 @@ import pathlib
|
||||
import subprocess
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Any, Awaitable, Callable
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from typing import Any, Awaitable, Callable, TextIO
|
||||
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel
|
||||
@ -234,18 +234,66 @@ def prepare_dataset(
|
||||
return pd.DataFrame(new_dataset)
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance, metadata, use_multiprocessing, process_instance_func
|
||||
) -> EvalOutput | EvalError:
|
||||
def update_progress(
    result_or_future: Future | EvalOutput | EvalError,
    instance: pd.Series,
    pbar: tqdm,
    output_fp: TextIO,
    instance_queue: mp.Queue,
):
    """Record the outcome of one evaluated instance.

    Args:
        result_or_future: Either an already-computed result, or a ``Future``
            whose ``.result()`` yields one.
        instance: The instance that was evaluated; its ``instance_id`` is
            used when an error record has to be built here.
        pbar: Progress bar advanced by one on a successful result.
        output_fp: Open text stream that receives one JSON line per
            successful result and is flushed after each write.
        instance_queue: Queue forwarded to ``handle_error`` on failure.

    Raises:
        ValueError: If the resolved result is neither an ``EvalOutput`` nor
            an ``EvalError``.
    """
    try:
        # Resolving a future re-raises, in this process, any exception that
        # was raised while the submitted callable was running.
        result = (
            result_or_future.result()
            if isinstance(result_or_future, Future)
            else result_or_future
        )
    except Exception as e:
        failure = EvalError(
            instance_id=instance.instance_id,
            error=str(e),
            stacktrace=traceback.format_exc(),
        )
        handle_error(failure, instance, pbar, instance_queue)
        return

    if isinstance(result, EvalOutput):
        # Success: advance the bar, log a truncated summary, persist the result.
        pbar.update(1)
        pbar.set_description(f'Instance {result.instance_id}')
        pbar.set_postfix_str(f'Test Result: {result.test_result}')
        logger.info(
            f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
        )
        serialized = json.dumps(result.model_dump())
        output_fp.write(serialized + '\n')
        output_fp.flush()
        return

    if isinstance(result, EvalError):
        handle_error(result, instance, pbar, instance_queue)
        return

    raise ValueError(f'Unexpected result type: {type(result)}')
|
||||
|
||||
|
||||
def handle_error(
    error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
):
    """Log a failed evaluation and put the instance back on the queue.

    Args:
        error: Error record whose ``error`` and ``stacktrace`` fields are
            included in the log message.
        instance: The instance that failed; it is put onto ``instance_queue``.
        pbar: Progress bar whose ``total`` is incremented by one and which is
            then refreshed.
        instance_queue: Queue that receives the failed instance.
    """
    rule = '-' * 10
    retry_notice = '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
    logger.error(
        f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
        f'\n{rule}{retry_notice}{rule}\n'
    )

    instance_queue.put(instance)
    pbar.total += 1
    pbar.refresh()
|
||||
|
||||
|
||||
def run_evaluation(
|
||||
@ -271,29 +319,6 @@ def run_evaluation(
|
||||
pbar = tqdm(total=total_instances, desc='Instances processed')
|
||||
output_fp = open(output_file, 'a')
|
||||
|
||||
def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
|
||||
if isinstance(result, EvalOutput):
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Instance {result.instance_id}')
|
||||
pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
||||
logger.info(
|
||||
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
||||
)
|
||||
output_fp.write(json.dumps(result.model_dump()) + '\n')
|
||||
output_fp.flush()
|
||||
else:
|
||||
logger.error(
|
||||
f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
|
||||
+ '\n'
|
||||
+ '-' * 10
|
||||
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
||||
+ '-' * 10
|
||||
+ '\n'
|
||||
)
|
||||
instance_queue.put(instance)
|
||||
pbar.total += 1
|
||||
pbar.refresh()
|
||||
|
||||
try:
|
||||
if use_multiprocessing:
|
||||
with ProcessPoolExecutor(num_workers) as executor:
|
||||
@ -302,23 +327,17 @@ def run_evaluation(
|
||||
# Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
|
||||
# since a running future may add new instances to the queue when error occurs
|
||||
while not instance_queue.empty() or batch_futures:
|
||||
# Submit new tasks if **there are instances to be processed** and **available workers**
|
||||
# Submit new tasks if there are instances to be processed and available workers
|
||||
while (
|
||||
not instance_queue.empty() and len(batch_futures) < num_workers
|
||||
):
|
||||
try:
|
||||
instance = instance_queue.get(block=False)
|
||||
future = executor.submit(
|
||||
process_instance,
|
||||
instance,
|
||||
metadata,
|
||||
True,
|
||||
process_instance_func,
|
||||
process_instance_func, instance, metadata, True
|
||||
)
|
||||
future.add_done_callback(
|
||||
lambda f, inst=instance: update_progress(
|
||||
f.result(), inst
|
||||
)
|
||||
future.instance = (
|
||||
instance # Attach the instance to the future
|
||||
)
|
||||
batch_futures.append(future)
|
||||
except mp.queues.Empty:
|
||||
@ -328,12 +347,19 @@ def run_evaluation(
|
||||
break # Queue is empty, stop submitting new tasks
|
||||
|
||||
# Continue to wait for the futures to be done & remove completed futures
|
||||
batch_futures = [f for f in batch_futures if not f.done()]
|
||||
new_batch_futures = []
|
||||
for future in batch_futures:
|
||||
if future.done():
|
||||
update_progress(
|
||||
future, future.instance, pbar, output_fp, instance_queue
|
||||
)
|
||||
else:
|
||||
new_batch_futures.append(future)
|
||||
batch_futures = new_batch_futures
|
||||
|
||||
# Short sleep to prevent busy-waiting
|
||||
time.sleep(1)
|
||||
|
||||
# Ensure all futures are done
|
||||
assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
|
||||
assert (
|
||||
len(batch_futures) == 0
|
||||
@ -341,10 +367,8 @@ def run_evaluation(
|
||||
else:
|
||||
while not instance_queue.empty():
|
||||
instance = instance_queue.get()
|
||||
result = process_instance(
|
||||
instance, metadata, False, process_instance_func
|
||||
)
|
||||
update_progress(result, instance)
|
||||
result = process_instance_func(instance, metadata, False)
|
||||
update_progress(result, instance, pbar, output_fp, instance_queue)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('\nKeyboardInterrupt received. Cleaning up...\n')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user