mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
support evaluating task success right after inference.
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
import whatthepatch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -37,6 +38,118 @@ AGENT_CLS_TO_INST_SUFFIX = {
|
||||
'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n'
|
||||
}
|
||||
|
||||
|
||||
def _parse_instance_report(report_output):
    """Parse the ``key: value`` lines of an instance report into a dict.

    Leading/trailing dashes are stripped from every line before it is split
    on the first colon. Values that look like integers are converted to
    ``int``; everything else is kept as a stripped string. Lines without a
    colon (blank lines, banners, warnings, ...) are skipped instead of
    raising, and a value that itself contains colons is kept intact.
    """
    parsed = {}
    for raw_line in report_output.strip().split('\n'):
        key, separator, value = raw_line.strip('-').partition(':')
        if not separator:
            # Not a ``key: value`` line; ignore it rather than crash the
            # whole evaluation on an unpacking error.
            continue
        value = value.strip()
        try:
            value = int(value)
        except ValueError:
            # Non-numeric values are stored verbatim.
            pass
        parsed[key.strip()] = value
    return parsed


def get_test_result(instance, sandbox, workspace_dir_name):
    """Evaluate the agent's edits inside the sandbox right after inference.

    The evaluation is best-effort: every step records its outcome under a
    numbered key in ``test_result['metadata']`` and the following steps are
    still attempted when one fails.

    Steps:
        1. Parse ``instance.test_patch`` to find the files it touches.
        2. Revert those files to ``instance["base_commit"]``.
        3. Apply ``$SWE_TASK_DIR/test.patch``.
        4. Run ``$TEST_CMD`` (output captured in
           ``/workspace/$SWE_INSTANCE_ID.log``) and read the log back.
        5. Wrap ``$SWE_TASK_DIR/instance.json`` in a JSON list.
        6. Run ``get_instance_report.py`` and parse its output.

    Args:
        instance: Task record; ``test_patch`` and ``instance_id`` are read as
            attributes and ``base_commit`` by key.
        sandbox: Object whose ``execute(cmd)`` runs a shell command and
            returns ``(exit_code, output)``.
        workspace_dir_name: Directory name under ``/workspace`` that holds
            the repository checkout.

    Returns:
        A dict with ``'result'`` (parsed report, empty if unavailable) and
        ``'metadata'`` (per-step status), plus ``'test_output'`` and
        ``'result_raw'`` when the corresponding steps succeed.
    """
    # Function-scope import so the block stays self-contained.
    import shlex

    test_result = {'result': {}, 'metadata': {}}
    metadata = test_result['metadata']

    # Step 1: collect the file paths that the test patch touches.
    try:
        test_patch_parsed = whatthepatch.parse_patch(instance.test_patch)
        involved_filepaths = set()
        for patch in test_patch_parsed:
            involved_filepaths.add(patch.header.old_path.removeprefix('a/'))
            involved_filepaths.add(patch.header.new_path.removeprefix('b/'))
        # Added/deleted files presumably show up with the '/dev/null'
        # pseudo-path on one side; it is never a real file to revert.
        involved_filepaths.discard('/dev/null')
        # Sorted so that the saved metadata is deterministic across runs.
        involved_filepaths = sorted(involved_filepaths)
        metadata['1_test_patch_parse_success'] = True
        metadata['1_test_involved_filepaths'] = involved_filepaths
    except Exception as e:
        logger.error(
            f'Error parsing test patch for instance {instance.instance_id}: {e}'
        )
        metadata['1_test_patch_parse_success'] = False
        metadata['1_test_patch_parse_error'] = str(e)
        metadata['1_test_involved_filepaths'] = None
        involved_filepaths = []

    # Step 2: try to revert the changes for involved filepaths.
    workspace_path = f'/workspace/{workspace_dir_name}'
    err_code, output = sandbox.execute(f'cd {shlex.quote(workspace_path)}')
    if err_code != 0:
        # Keep going (best-effort), but do not hide that the following git
        # commands may run in the wrong directory.
        logger.error(f'Error entering workspace {workspace_path}: {output}')

    base_commit = shlex.quote(str(instance['base_commit']))
    revert_status = []
    metadata['2_revert_test_involved_filepaths_success'] = revert_status
    for filepath in involved_filepaths:
        # Quote the path: it comes from patch content and may contain spaces
        # or shell metacharacters.
        err_code, output = sandbox.execute(
            f'git checkout {base_commit} -- {shlex.quote(filepath)}'
        )
        if err_code != 0:
            logger.error(f'Error reverting changes for {filepath}: {output}')
        revert_status.append(err_code == 0)

    # Step 3: apply the testcase.
    err_code, output = sandbox.execute('git apply $SWE_TASK_DIR/test.patch')
    if err_code != 0:
        logger.error(f'Error applying test patch: {output}')
        metadata['3_apply_test_patch_success'] = False
        metadata['3_apply_test_patch_error'] = output
    else:
        metadata['3_apply_test_patch_success'] = True

    # Step 4a: run the test command, capturing stdout and stderr in the log.
    err_code, output = sandbox.execute(
        '$TEST_CMD > /workspace/$SWE_INSTANCE_ID.log 2>&1'
    )
    if err_code != 0:
        logger.error(f'Error running test command: {output}')
        metadata['4_run_test_command_success'] = False
        metadata['4_run_test_command_error'] = output
    else:
        metadata['4_run_test_command_success'] = True

    # Step 4b: get the test output.
    err_code, output = sandbox.execute('cat /workspace/$SWE_INSTANCE_ID.log')
    if err_code != 0:
        logger.error(f'Error getting test output: {output}')
        metadata['4_get_test_output_success'] = False
        metadata['4_get_test_output_error'] = output
    else:
        metadata['4_get_test_output_success'] = True
        test_result['test_output'] = output

    # Step 5: reformat instance.json.
    # $SWE_TASK_DIR/instance.json is a dict {"XXX": "YYY"}; add a [ before
    # and a ] after so that it becomes a one-element list.
    err_code, output = sandbox.execute(
        'cat $SWE_TASK_DIR/instance.json | sed "s/^{/[{/" | sed "s/}$/}]/" > /workspace/instance.json'
    )
    if err_code != 0:
        logger.error(f'Error creating instance.json: {output}')
        metadata['5_reformat_instance_json_success'] = False
        metadata['5_reformat_instance_json_error'] = output
    else:
        metadata['5_reformat_instance_json_success'] = True

    # Step 6: get the instance report.
    err_code, output = sandbox.execute(
        'cd /swe_util/OD-SWE-bench '
        '&& export PYTHONPATH=$(pwd):$PYTHONPATH '
        '&& conda run -n swe-bench-eval python swebench/metrics/get_instance_report.py --swe_bench_task /workspace/instance.json --log_path /workspace/$SWE_INSTANCE_ID.log'
    )
    if err_code != 0:
        logger.error(f'Error getting instance report: {output}')
        metadata['6_get_instance_report_success'] = False
        metadata['6_get_instance_report_error'] = output
    else:
        metadata['6_get_instance_report_success'] = True
        test_result['result_raw'] = output
        # Only a successful report is parsed; error output is not a report.
        test_result['result'].update(_parse_instance_report(output))

    return test_result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load the dataset
|
||||
dataset = load_dataset('princeton-nlp/SWE-bench_Lite')
|
||||
@@ -127,8 +240,10 @@ if __name__ == '__main__':
|
||||
if instance.hints_text:
|
||||
instruction += f'# Hints\n{instance.hints_text}\n\n'
|
||||
instruction += (
|
||||
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK. \n'
|
||||
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK \n'
|
||||
'You should ONLY interact with the environment provided to you.\n'
|
||||
'You should NOT modify any existing test case files. '
|
||||
'If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
|
||||
)
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
|
||||
|
||||
@@ -141,6 +256,12 @@ if __name__ == '__main__':
|
||||
git_patch = sandbox.get_diff_patch()
|
||||
logger.info(f'Got git diff for instance {instance.instance_id}')
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# Attempt to analyze the test patch to get involved filepaths
|
||||
test_result = get_test_result(instance, sandbox, workspace_dir_name)
|
||||
pbar.update()
|
||||
pbar.set_postfix_str(f'Test Result: {test_result["result"]}')
|
||||
|
||||
# Save the output
|
||||
output = {
|
||||
'instance_id': instance.instance_id,
|
||||
@@ -151,6 +272,7 @@ if __name__ == '__main__':
|
||||
'history': [
|
||||
(action.to_dict(), obs.to_dict()) for action, obs in state.history
|
||||
],
|
||||
'test_result': test_result,
|
||||
}
|
||||
output_fp.write(json.dumps(output) + '\n')
|
||||
output_fp.flush()
|
||||
|
||||
13
poetry.lock
generated
13
poetry.lock
generated
@@ -6696,6 +6696,17 @@ MarkupSafe = ">=2.1.1"
|
||||
[package.extras]
|
||||
watchdog = ["watchdog (>=2.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "whatthepatch"
|
||||
version = "1.0.5"
|
||||
description = "A patch parsing and application library."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "whatthepatch-1.0.5-py3-none-any.whl", hash = "sha256:6bc41f9f48a63384be4478d8b2d5b22185aac75be853cdcb150a2dc174ede7e1"},
|
||||
{file = "whatthepatch-1.0.5.tar.gz", hash = "sha256:7f374c172812581bc3763587525d14a143aac7fe4220bc4676ecce0d86cb8f08"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.16.0"
|
||||
@@ -7013,4 +7024,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "134dad10d40d1357a1f022bef59404bbbde6fb292da95ea1809d9f36d60dba3f"
|
||||
content-hash = "e2616a3d4000c73443b222cd09424f5d3fc067ab1ab59fb6a1f44325f7c05d99"
|
||||
|
||||
@@ -49,6 +49,7 @@ pytest-asyncio = "*"
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
[build-system]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
Reference in New Issue
Block a user