[eval] fix evaluation git patch post-processing (#3979)

This commit is contained in:
Xingyao Wang
2024-09-20 09:55:43 -05:00
committed by GitHub
parent caa0f03c7b
commit b24a7821ec
2 changed files with 38 additions and 39 deletions

View File

@@ -3,7 +3,6 @@ import tempfile
import time
import pandas as pd
from pydantic import BaseModel
from swebench.harness.grading import get_eval_report
from swebench.harness.run_evaluation import (
APPLY_PATCH_FAIL,
@@ -35,6 +34,36 @@ DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xing
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
def process_git_patch(patch):
if not isinstance(patch, str):
return ''
if not patch.strip():
# skip empty patches
return ''
patch = patch.replace('\r\n', '\n')
# There might be some weird characters at the beginning of the patch
# due to some OpenHands inference command outputs
# FOR EXAMPLE:
# git diff --no-color --cached 895f28f9cbed817c00ab68770433170d83132d90
# 0
# diff --git a/django/db/models/sql/.backup.query.py b/django/db/models/sql/.backup.query.py
# new file mode 100644
# index 0000000000..fc13db5948
# We "find" the first line that starts with "diff" and then we remove lines before it
lines = patch.split('\n')
for i, line in enumerate(lines):
if line.startswith('diff --git'):
patch = '\n'.join(lines[i:])
break
patch = patch.rstrip() + '\n' # Make sure the last line ends with a newline
return patch
def get_config(instance: pd.Series) -> AppConfig:
# We use a different instance image for the each instance of swe-bench eval
base_container_image = get_instance_docker_image(instance['instance_id'])
@@ -60,13 +89,6 @@ def get_config(instance: pd.Series) -> AppConfig:
return config
class SWEBenchEvalResult(BaseModel):
instance_id: str
apply_patch_output: str
test_output: str
resolved: bool
def process_instance(
instance: pd.Series,
metadata: EvalMetadata | None = None,
@@ -94,6 +116,7 @@ def process_instance(
'resolved': False,
'failed_apply_patch': False,
'error_eval': False,
'test_timeout': False,
}
if model_patch == '':
@@ -170,13 +193,14 @@ def process_instance(
# Poll for completion
start_time = time.time()
timeout = 900 # 15 minutes
timeout = 1800 # 30 minutes
while True:
seconds_elapsed = time.time() - start_time
if seconds_elapsed > timeout:
logger.info(
f'[{instance_id}] Evaluation timed out after {timeout} seconds'
)
instance['test_result']['report']['test_timeout'] = True
break
check_action = CmdRunAction(
command=f'ps -p {pid} > /dev/null; echo $?', keep_prompt=False
@@ -315,6 +339,9 @@ if __name__ == '__main__':
set(predictions.columns)
), 'Input file must contain instance_id and model_patch columns.'
# Process model_patch
predictions['model_patch'] = predictions['model_patch'].apply(process_git_patch)
# Merge predictions with dataset
predictions['instance'] = predictions['instance_id'].apply(
lambda x: instance_id_to_instance[x]

View File

@@ -3,6 +3,8 @@ import os
import pandas as pd
from evaluation.swe_bench.eval_infer import process_git_patch
parser = argparse.ArgumentParser()
parser.add_argument('oh_output_file', type=str)
args = parser.parse_args()
@@ -14,36 +16,6 @@ oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)
model_name = os.path.basename(os.path.dirname(args.oh_output_file))
def process_git_patch(patch):
if not isinstance(patch, str):
return ''
if not patch.strip():
# skip empty patches
return ''
patch = patch.replace('\r\n', '\n')
# There might be some weird characters at the beginning of the patch
# due to some OpenHands inference command outputs
# FOR EXAMPLE:
# git diff --no-color --cached 895f28f9cbed817c00ab68770433170d83132d90
# 0
# diff --git a/django/db/models/sql/.backup.query.py b/django/db/models/sql/.backup.query.py
# new file mode 100644
# index 0000000000..fc13db5948
# We "find" the first line that starts with "diff" and then we remove lines before it
lines = patch.split('\n')
for i, line in enumerate(lines):
if line.startswith('diff --git'):
patch = '\n'.join(lines[i:])
break
patch = patch.rstrip() + '\n' # Make sure the last line ends with a newline
return patch
def convert_row_to_swebench_format(row):
if 'git_patch' in row:
model_patch = row['git_patch']