[eval] refactor process instance logic into update_progress (#3875)

This commit is contained in:
Xingyao Wang 2024-09-15 17:47:15 -05:00 committed by GitHub
parent ecf4aed28b
commit 2b3925278d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,8 +6,8 @@ import pathlib
import subprocess
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Awaitable, Callable
from concurrent.futures import Future, ProcessPoolExecutor
from typing import Any, Awaitable, Callable, TextIO
import pandas as pd
from pydantic import BaseModel
@ -234,18 +234,66 @@ def prepare_dataset(
return pd.DataFrame(new_dataset)
def process_instance(
instance, metadata, use_multiprocessing, process_instance_func
) -> EvalOutput | EvalError:
def update_progress(
result_or_future: Future | EvalOutput | EvalError,
instance: pd.Series,
pbar: tqdm,
output_fp: TextIO,
instance_queue: mp.Queue,
):
"""Update the progress bar and write the result to the output file."""
try:
return process_instance_func(instance, metadata, use_multiprocessing)
if isinstance(result_or_future, Future):
result = result_or_future.result()
else:
result = result_or_future
except Exception as e:
logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
return EvalError(
instance_id=instance.instance_id,
error=str(e),
stacktrace=traceback.format_exc(),
# Handle the error
# Exception may be raised in the process_instance_func and will
# be raised here when we try to access the .result() of the future
handle_error(
EvalError(
instance_id=instance.instance_id,
error=str(e),
stacktrace=traceback.format_exc(),
),
instance,
pbar,
instance_queue,
)
return
# Update the progress bar and write the result to the output file
if isinstance(result, EvalOutput):
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
elif isinstance(result, EvalError):
handle_error(result, instance, pbar, instance_queue)
else:
raise ValueError(f'Unexpected result type: {type(result)}')
def handle_error(
error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
):
"""Handle an error that occurred during evaluation."""
logger.error(
f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
+ '\n'
+ '-' * 10
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+ '-' * 10
+ '\n'
)
instance_queue.put(instance)
pbar.total += 1
pbar.refresh()
def run_evaluation(
@ -271,29 +319,6 @@ def run_evaluation(
pbar = tqdm(total=total_instances, desc='Instances processed')
output_fp = open(output_file, 'a')
def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
if isinstance(result, EvalOutput):
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
else:
logger.error(
f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
+ '\n'
+ '-' * 10
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+ '-' * 10
+ '\n'
)
instance_queue.put(instance)
pbar.total += 1
pbar.refresh()
try:
if use_multiprocessing:
with ProcessPoolExecutor(num_workers) as executor:
@ -302,23 +327,17 @@ def run_evaluation(
# Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
# since a running future may add new instances to the queue when error occurs
while not instance_queue.empty() or batch_futures:
# Submit new tasks if **there are instances to be processed** and **available workers**
# Submit new tasks if there are instances to be processed and available workers
while (
not instance_queue.empty() and len(batch_futures) < num_workers
):
try:
instance = instance_queue.get(block=False)
future = executor.submit(
process_instance,
instance,
metadata,
True,
process_instance_func,
process_instance_func, instance, metadata, True
)
future.add_done_callback(
lambda f, inst=instance: update_progress(
f.result(), inst
)
future.instance = (
instance # Attach the instance to the future
)
batch_futures.append(future)
except mp.queues.Empty:
@ -328,12 +347,19 @@ def run_evaluation(
break # Queue is empty, stop submitting new tasks
# Continue to wait for the futures to be done & remove completed futures
batch_futures = [f for f in batch_futures if not f.done()]
new_batch_futures = []
for future in batch_futures:
if future.done():
update_progress(
future, future.instance, pbar, output_fp, instance_queue
)
else:
new_batch_futures.append(future)
batch_futures = new_batch_futures
# Short sleep to prevent busy-waiting
time.sleep(1)
# Ensure all futures are done
assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
assert (
len(batch_futures) == 0
@ -341,10 +367,8 @@ def run_evaluation(
else:
while not instance_queue.empty():
instance = instance_queue.get()
result = process_instance(
instance, metadata, False, process_instance_func
)
update_progress(result, instance)
result = process_instance_func(instance, metadata, False)
update_progress(result, instance, pbar, output_fp, instance_queue)
except KeyboardInterrupt:
print('\nKeyboardInterrupt received. Cleaning up...\n')