support evaluating task success right after inference.

This commit is contained in:
Xingyao Wang
2024-05-04 03:36:35 +08:00
parent 7b6dfef3db
commit 149f73b15e
3 changed files with 136 additions and 2 deletions

View File

@@ -4,6 +4,7 @@ import os
import pathlib
import time
import whatthepatch
from datasets import load_dataset
from tqdm import tqdm
@@ -37,6 +38,118 @@ AGENT_CLS_TO_INST_SUFFIX = {
'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
def _parse_instance_report(report):
    """Parse the ``key: value`` lines of an instance report into a dict.

    Leading/trailing dashes are stripped from every line. Lines that do not
    contain a colon (blank lines, banners, warnings) are skipped instead of
    raising, and only the first colon separates key from value, so values may
    themselves contain colons. Values that look like integers are converted
    to ``int``; everything else stays a stripped string.
    """
    parsed = {}
    for line in report.strip().split('\n'):
        key, sep, value = line.strip('-').partition(':')
        key = key.strip()
        if not sep or not key:
            # Not a ``key: value`` line; ignore it rather than abort the run.
            continue
        value = value.strip()
        try:
            value = int(value)
        except ValueError:
            pass
        parsed[key] = value
    return parsed


def get_test_result(instance, sandbox, workspace_dir_name):
    """Apply the instance's test patch in the sandbox and collect the results.

    The evaluation is best-effort: every step records its success (and error
    output on failure) under a numbered key in ``metadata`` and the following
    steps still run, so a partial failure never loses the record.

    Steps:
      1. Parse ``instance.test_patch`` to find the files it touches.
      2. Reset those files to ``instance["base_commit"]`` so edits made to
         them during inference do not conflict with the test patch.
      3. Apply ``$SWE_TASK_DIR/test.patch``.
      4. Run ``$TEST_CMD`` (output redirected to a log file) and read the log.
      5. Wrap ``$SWE_TASK_DIR/instance.json`` in a JSON list.
      6. Run the instance report script and parse its ``key: value`` output.

    NOTE(review): assumes ``$SWE_TASK_DIR``, ``$TEST_CMD`` and
    ``$SWE_INSTANCE_ID`` are already exported inside the sandbox shell and
    that shell state (the ``cd``) persists between ``execute`` calls --
    confirm against the sandbox setup.

    Args:
        instance: Task record exposing ``test_patch`` and ``instance_id`` as
            attributes and ``"base_commit"`` via item access.
        sandbox: Object whose ``execute(cmd)`` returns ``(err_code, output)``.
        workspace_dir_name: Directory under ``/workspace`` holding the repo.

    Returns:
        dict with keys ``result`` (parsed report), ``metadata`` (per-step
        status), and, when the corresponding step succeeded, ``test_output``
        and ``result_raw``.
    """
    # Local import keeps this block self-contained; shlex is stdlib.
    import shlex

    test_result = {'result': {}, 'metadata': {}}
    metadata = test_result['metadata']

    # Step 1: find the files touched by the test patch.
    try:
        test_patch_parsed = whatthepatch.parse_patch(instance.test_patch)
        involved = set()
        for patch in test_patch_parsed:
            involved.add(patch.header.old_path.removeprefix('a/'))
            involved.add(patch.header.new_path.removeprefix('b/'))
        # Added/deleted files have /dev/null on one side of the diff; it is
        # not a path inside the repository and cannot be checked out.
        involved.discard('/dev/null')
        # Sorted so the recorded list (and the per-file status list that
        # parallels it) is deterministic across runs.
        involved_filepaths = sorted(involved)
        metadata['1_test_patch_parse_success'] = True
        metadata['1_test_involved_filepaths'] = involved_filepaths
    except Exception as e:
        logger.error(
            f'Error parsing test patch for instance {instance.instance_id}: {e}'
        )
        metadata['1_test_patch_parse_success'] = False
        metadata['1_test_patch_parse_error'] = str(e)
        metadata['1_test_involved_filepaths'] = None
        involved_filepaths = []

    # Step 2: try to revert the changes for involved filepaths.
    err_code, output = sandbox.execute(f'cd /workspace/{workspace_dir_name}')
    if err_code != 0:
        # Every later git command depends on this; surface the failure.
        logger.error(
            f'Error changing directory to /workspace/{workspace_dir_name}: {output}'
        )
    revert_status = []
    metadata['2_revert_test_involved_filepaths_success'] = revert_status
    for filepath in involved_filepaths:
        # Quote the path: it comes from patch text and may contain spaces or
        # shell metacharacters.
        err_code, output = sandbox.execute(
            f'git checkout {instance["base_commit"]} -- {shlex.quote(filepath)}'
        )
        if err_code != 0:
            logger.error(f'Error reverting changes for {filepath}: {output}')
        revert_status.append(err_code == 0)

    # Step 3: apply the testcase.
    err_code, output = sandbox.execute('git apply $SWE_TASK_DIR/test.patch')
    if err_code != 0:
        logger.error(f'Error applying test patch: {output}')
        metadata['3_apply_test_patch_success'] = False
        metadata['3_apply_test_patch_error'] = output
    else:
        metadata['3_apply_test_patch_success'] = True

    # Step 4a: run the test command, capturing stdout and stderr in a log.
    err_code, output = sandbox.execute(
        '$TEST_CMD > /workspace/$SWE_INSTANCE_ID.log 2>&1'
    )
    if err_code != 0:
        logger.error(f'Error running test command: {output}')
        metadata['4_run_test_command_success'] = False
        metadata['4_run_test_command_error'] = output
    else:
        metadata['4_run_test_command_success'] = True

    # Step 4b: get the test output.
    err_code, output = sandbox.execute('cat /workspace/$SWE_INSTANCE_ID.log')
    if err_code != 0:
        logger.error(f'Error getting test output: {output}')
        metadata['4_get_test_output_success'] = False
        metadata['4_get_test_output_error'] = output
    else:
        metadata['4_get_test_output_success'] = True
        test_result['test_output'] = output

    # Step 5: reformat instance.json.
    # $SWE_TASK_DIR/instance.json is a dict {"XXX": "YYY"}; add a [ before
    # and a ] after so the report script receives a list.
    err_code, output = sandbox.execute(
        'cat $SWE_TASK_DIR/instance.json | sed "s/^{/[{/" | sed "s/}$/}]/" > /workspace/instance.json'
    )
    if err_code != 0:
        logger.error(f'Error creating instance.json: {output}')
        metadata['5_reformat_instance_json_success'] = False
        metadata['5_reformat_instance_json_error'] = output
    else:
        metadata['5_reformat_instance_json_success'] = True

    # Step 6: get the instance report.
    err_code, output = sandbox.execute(
        'cd /swe_util/OD-SWE-bench '
        '&& export PYTHONPATH=$(pwd):$PYTHONPATH '
        '&& conda run -n swe-bench-eval python swebench/metrics/get_instance_report.py --swe_bench_task /workspace/instance.json --log_path /workspace/$SWE_INSTANCE_ID.log'
    )
    if err_code != 0:
        logger.error(f'Error getting instance report: {output}')
        metadata['6_get_instance_report_success'] = False
        metadata['6_get_instance_report_error'] = output
    else:
        metadata['6_get_instance_report_success'] = True
        test_result['result_raw'] = output
        # Tolerant parse: malformed lines are skipped, never fatal.
        test_result['result'].update(_parse_instance_report(output))

    return test_result
if __name__ == '__main__':
# Load the dataset
dataset = load_dataset('princeton-nlp/SWE-bench_Lite')
@@ -127,8 +240,10 @@ if __name__ == '__main__':
if instance.hints_text:
instruction += f'# Hints\n{instance.hints_text}\n\n'
instruction += (
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK. \n'
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK \n'
'You should ONLY interact with the environment provided to you.\n'
'You should NOT modify any existing test case files. '
'If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
)
instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
@@ -141,6 +256,12 @@ if __name__ == '__main__':
git_patch = sandbox.get_diff_patch()
logger.info(f'Got git diff for instance {instance.instance_id}')
# ======= Attempt to evaluate the agent's edits =======
# Attempt to analyze the test patch to get involved filepaths
test_result = get_test_result(instance, sandbox, workspace_dir_name)
pbar.update()
pbar.set_postfix_str(f'Test Result: {test_result["result"]}')
# Save the output
output = {
'instance_id': instance.instance_id,
@@ -151,6 +272,7 @@ if __name__ == '__main__':
'history': [
(action.to_dict(), obs.to_dict()) for action, obs in state.history
],
'test_result': test_result,
}
output_fp.write(json.dumps(output) + '\n')
output_fp.flush()

13
poetry.lock generated
View File

@@ -6696,6 +6696,17 @@ MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]]
name = "whatthepatch"
version = "1.0.5"
description = "A patch parsing and application library."
optional = false
python-versions = ">=3.7"
files = [
{file = "whatthepatch-1.0.5-py3-none-any.whl", hash = "sha256:6bc41f9f48a63384be4478d8b2d5b22185aac75be853cdcb150a2dc174ede7e1"},
{file = "whatthepatch-1.0.5.tar.gz", hash = "sha256:7f374c172812581bc3763587525d14a143aac7fe4220bc4676ecce0d86cb8f08"},
]
[[package]]
name = "wrapt"
version = "1.16.0"
@@ -7013,4 +7024,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "134dad10d40d1357a1f022bef59404bbbde6fb292da95ea1809d9f36d60dba3f"
content-hash = "e2616a3d4000c73443b222cd09424f5d3fc067ab1ab59fb6a1f44325f7c05d99"

View File

@@ -49,6 +49,7 @@ pytest-asyncio = "*"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
[build-system]
build-backend = "poetry.core.masonry.api"