(fix): Resolve conditional imports in SWE-bench eval script when multiprocessing is enabled (#7244)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Calvin Smith 2025-03-13 13:29:11 -06:00 committed by GitHub
parent 78d185b102
commit 303b7ab180
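
Context for the fix: `process_instance` executes in worker processes, but `get_eval_report`, `APPLY_PATCH_FAIL`, and `APPLY_PATCH_PASS` are imported inside the `if __name__ == '__main__':` block (one set of imports for SWE-bench, another for SWE-Gym). A worker started with e.g. the 'spawn' start method re-imports the module with `__name__ != '__main__'`, so those guarded imports never run there and the names are unavailable. The commit packs the values into a picklable dataclass and threads it through `functools.partial`. Below is a minimal, self-contained sketch of the pattern; the names (`Deps`, `work`, `report_fn`) are illustrative, not the eval script's real API.

# Minimal sketch of the failure mode this commit fixes, assuming the
# 'spawn' start method: a worker re-imports this file with
# __name__ != '__main__', so imports guarded by the main block never run
# in the worker and any name they would have bound is missing there.
import multiprocessing as mp
from dataclasses import dataclass
from functools import partial
from typing import Callable


@dataclass
class Deps:
    """Picklable bundle of the conditionally imported values."""

    report_fn: Callable
    fail_marker: str


def work(item: int, deps: Deps | None = None) -> str:
    # Without the explicit deps argument, this function would have to
    # reach for module-level names that do not exist in the worker.
    assert deps is not None, 'deps must be provided under multiprocessing'
    return deps.report_fn(item)


if __name__ == '__main__':
    # Conditional import: only executed in the parent process.
    from json import dumps as report_fn

    deps = Deps(report_fn=report_fn, fail_marker='APPLY_PATCH_FAIL')
    # partial bakes deps into the callable; pickle ships it to workers
    # by reference (json.dumps) plus the dataclass field values.
    work_fn = partial(work, deps=deps)
    with mp.get_context('spawn').Pool(2) as pool:
        print(pool.map(work_fn, [1, 2, 3]))  # ['1', '2', '3']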


@@ -3,7 +3,9 @@ import os
 import subprocess
 import tempfile
 import time
+from dataclasses import dataclass
 from functools import partial
+from typing import Callable

 import pandas as pd
 from tqdm import tqdm
@@ -91,12 +93,22 @@ def get_config(metadata: EvalMetadata, instance: pd.Series) -> AppConfig:
     return config


+@dataclass
+class ConditionalImports:
+    """We instantiate the values in this dataclass differently if we're evaluating SWE-bench or SWE-Gym."""
+
+    get_eval_report: Callable
+    APPLY_PATCH_FAIL: str
+    APPLY_PATCH_PASS: str
+
+
 def process_instance(
     instance: pd.Series,
     metadata: EvalMetadata,
     reset_logger: bool = True,
     log_dir: str | None = None,
     runtime_failure_count: int = 0,
+    conditional_imports: ConditionalImports | None = None,
 ) -> EvalOutput:
     """
     Evaluate agent performance on a SWE-bench problem instance.
@@ -108,9 +120,18 @@ def process_instance(
         log_dir (str | None, default=None): Path to directory where log files will be written. Must
             be provided if `reset_logger` is set.
+        conditional_imports: A dataclass containing values that are imported differently based on
+            whether we're evaluating SWE-bench or SWE-Gym.

     Raises:
         AssertionError: if the `reset_logger` flag is set without a provided log directory.
+        AssertionError: if `conditional_imports` is not provided.
     """
+    assert (
+        conditional_imports is not None
+    ), 'conditional_imports must be provided to run process_instance using multiprocessing'
+
     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:
         assert (
@@ -124,7 +145,7 @@ def process_instance(
     config = get_config(metadata, instance)
     instance_id = instance.instance_id
     model_patch = instance['model_patch']
-    test_spec: TestSpec = instance['test_spec']
+    test_spec = instance['test_spec']
     logger.info(f'Starting evaluation for instance {instance_id}.')

     if 'test_result' not in instance.keys():
@@ -196,7 +217,9 @@ def process_instance(
     instance['test_result']['apply_patch_output'] = apply_patch_output

     if 'APPLY_PATCH_FAIL' in apply_patch_output:
-        logger.info(f'[{instance_id}] {APPLY_PATCH_FAIL}:\n{apply_patch_output}')
+        logger.info(
+            f'[{instance_id}] {conditional_imports.APPLY_PATCH_FAIL}:\n{apply_patch_output}'
+        )
         instance['test_result']['report']['failed_apply_patch'] = True

         return EvalOutput(
@@ -205,7 +228,9 @@ def process_instance(
             metadata=metadata,
         )
     elif 'APPLY_PATCH_PASS' in apply_patch_output:
-        logger.info(f'[{instance_id}] {APPLY_PATCH_PASS}:\n{apply_patch_output}')
+        logger.info(
+            f'[{instance_id}] {conditional_imports.APPLY_PATCH_PASS}:\n{apply_patch_output}'
+        )

         # Run eval script in background and save output to log file
         log_file = '/tmp/eval_output.log'
@@ -271,7 +296,7 @@ def process_instance(
         with open(test_output_path, 'w') as f:
             f.write(test_output)
         try:
-            _report = get_eval_report(
+            _report = conditional_imports.get_eval_report(
                 test_spec=test_spec,
                 prediction={
                     'model_patch': model_patch,
@@ -345,7 +370,6 @@ if __name__ == '__main__':
         )
         from swegym.harness.test_spec import (
             SWEbenchInstance,
-            TestSpec,
             make_test_spec,
         )
         from swegym.harness.utils import load_swebench_dataset
@@ -357,7 +381,6 @@ if __name__ == '__main__':
         )
         from swebench.harness.test_spec.test_spec import (
             SWEbenchInstance,
-            TestSpec,
             make_test_spec,
         )
         from swebench.harness.utils import load_swebench_dataset
@@ -445,7 +468,15 @@ if __name__ == '__main__':
     # The evaluation harness constrains the signature of `process_instance_func` but we need to
     # pass extra information. Build a new function object to avoid issues with multiprocessing.
     process_instance_func = partial(
-        process_instance, log_dir=output_file.replace('.jsonl', '.logs')
+        process_instance,
+        log_dir=output_file.replace('.jsonl', '.logs'),
+        # We have to explicitly pass these imports to the process_instance function, otherwise
+        # they won't be available in the multiprocessing context.
+        conditional_imports=ConditionalImports(
+            get_eval_report=get_eval_report,
+            APPLY_PATCH_FAIL=APPLY_PATCH_FAIL,
+            APPLY_PATCH_PASS=APPLY_PATCH_PASS,
+        ),
     )

     run_evaluation(
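
A side note on the `partial` in the last hunk: the comment says a new function object is built "to avoid issues with multiprocessing". The pool pickles the callable it sends to workers, and a `functools.partial` wrapping a module-level function pickles cleanly, whereas a lambda or nested closure does not. A quick standalone sketch (illustrative names, not the harness's code):

# Why a partial of a top-level function is safe to hand to
# multiprocessing while a lambda is not.
import pickle
from functools import partial


def process(item: int, log_dir: str | None = None) -> int:
    return item


good = partial(process, log_dir='out.logs')
pickle.dumps(good)  # fine: pickles as (process, kwargs), both importable

bad = lambda item: process(item, log_dir='out.logs')
try:
    pickle.dumps(bad)
except Exception as exc:
    # lambdas pickle by qualified-name lookup, which fails for <lambda>
    print(f'not picklable: {exc}')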