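"""Interactive SWE-Bench evaluation for OpenHands.

Runs SWE-Bench-style inference in an interactive setting: the agent's
instruction notes that some details of the issue are hidden, and a FakeUser
(an LLM prompted with the full issue text, hints, and file list) answers the
agent's clarifying questions in character as the issue reporter. Results are
written to an output.jsonl compatible with the standard swe_bench tooling.
"""
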
import asyncio
import json
import os

import pandas as pd
from datasets import load_dataset
from litellm import completion as litellm_completion

import openhands.agenthub
from evaluation.benchmarks.swe_bench.run_infer import (
    AgentFinishedCritic,
    complete_runtime,
    filter_dataset,
    get_config,
    initialize_runtime,
)
from evaluation.benchmarks.swe_bench.run_infer import (
    get_instruction as base_get_instruction,
)
from evaluation.utils.shared import (
    EvalException,
    EvalMetadata,
    EvalOutput,
    get_metrics,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    get_evaluation_parser,
    get_llm_config_arg,
)
from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.core.config.utils import get_condenser_config_arg
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.utils.async_utils import call_async_from_sync

USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'


class FakeUser:
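    """LLM-backed stand-in for the GitHub user who reported the issue.

    The full issue text, hint text, and file list are embedded in the system
    prompt, and generate_reply() answers the coding agent's clarifying
    questions in character, capping the interaction at a few turns.

    Usage sketch (`row` is illustrative; column names follow the dataset rows
    used in process_instance below):

        user = FakeUser(issue=row.original_issue, hints=row.hints_text, files=row.files)
        reply = user.generate_reply(last_agent_message)
    """
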
    def __init__(self, issue, hints, files):
        self.system_message = f"""
You are a GitHub user reporting an issue. Here are the details of your issue and environment:

Issue: {issue}

Hints: {hints}

Files relative to your current directory: {files}

Your task is to respond to questions from a coder who is trying to solve your issue. The coder has a summarized version of the issue you have. Follow these rules:
1. If the coder asks a question that is directly related to the information in the issue you have, provide that information.
2. Always stay in character as a user reporting an issue, not as an AI assistant.
3. Keep your responses concise and to the point.
4. The coder has limited turns to solve the issue. Do not interact with the coder beyond 3 turns.

Respond with "I don't have that information" if the question is unrelated or you're unsure.
"""
        self.chat_history = [{'role': 'system', 'content': self.system_message}]
        self.turns = 0
        # Get LLM config from config.toml
        self.llm_config = get_llm_config_arg(
            'llm.fake_user'
        )  # You can change 'fake_user' to any config name you want

    def generate_reply(self, question):
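        """Answer one question from the coding agent.

        `question` is expected to be a MessageAction-like object with a
        `.content` attribute. After more than 3 replies, the fake user stops
        engaging and tells the agent to continue on its own.
        """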
        if self.turns > 3:
            return 'Please continue working on the task. Do NOT ask for more help.'
        self.chat_history.append({'role': 'user', 'content': question.content})

        response = litellm_completion(
            model=self.llm_config.model,
            messages=self.chat_history,
            api_key=self.llm_config.api_key.get_secret_value(),
            temperature=self.llm_config.temperature,
            base_url=self.llm_config.base_url,
        )

        reply = response.choices[0].message.content
        self.chat_history.append({'role': 'assistant', 'content': reply})
        self.turns += 1
        return reply


# Global variable for fake user
fake_user = None


def get_fake_user_response(state: State) -> str:
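    """Fake-user response callback passed to run_controller.

    When the agent asks the user something, reply via the global FakeUser;
    if it has not been initialized yet (or there is no agent message), just
    tell the agent to continue.
    """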
    global fake_user
    if not fake_user:
        return 'Please continue working on the task.'
    last_agent_message = state.get_last_agent_message()
    if last_agent_message:
        return fake_user.generate_reply(last_agent_message)
    return 'Please continue working on the task.'


AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': get_fake_user_response,
}


def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageAction:
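    """Build the initial instruction for the agent.

    Reuses the standard SWE-Bench instruction, but appends a note that some
    details are hidden and that the agent should ask the user questions to
    gather the missing information.
    """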
    instance_copy = instance.copy()
    instance_copy.problem_statement = f'{instance.problem_statement}\n\nHints:\nThe user has not provided all the necessary details about the issue, and there are some hidden details that are helpful. Please ask the user specific questions using non-code commands to gather the relevant information that the user has to help you solve the issue. Ensure you have all the details you require to solve the issue.'
    return base_get_instruction(instance_copy, metadata)


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
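    """Run a single instance end-to-end and return its EvalOutput.

    Sets up the global FakeUser from the instance's original issue, creates a
    runtime, runs the agent controller with the fake-user response callback,
    and collects the resulting git patch, history, and metrics.
    """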
    config = get_config(instance, metadata)
    global fake_user
    original_issue = instance.original_issue
    issue = str(original_issue)
    fake_user = FakeUser(issue=issue, hints=instance.hints_text, files=instance.files)

    # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')

    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    try:
        initialize_runtime(runtime, instance, metadata)

        message_action = get_instruction(instance, metadata)

        # Run the agent (similar to the `main` function) and get the final task state
        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=message_action,
                runtime=runtime,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                    metadata.agent_class
                ],
            )
        )

        # If a fatal error occurred, raise EvalException to trigger a re-run
        if (
            state
            and state.last_error
            and 'fatal error during agent execution' in state.last_error
            and 'stuck in a loop' not in state.last_error
        ):
            raise EvalException('Fatal error detected: ' + state.last_error)

        # Get git patch
        return_val = complete_runtime(runtime, instance)
        git_patch = return_val['git_patch']
        logger.info(
            f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
        )
    finally:
        runtime.close()

    # Prepare test result
    test_result = {
        'git_patch': git_patch,
    }

    if state is None:
        raise ValueError('State should not be None.')

    histories = [event_to_dict(event) for event in state.history]
    metrics = get_metrics(state)

    # Save the output
    instruction = message_action.content
    if message_action.image_urls:
        instruction += (
            '\n\n<image_urls>' + '\n'.join(message_action.image_urls) + '</image_urls>'
        )
    output = EvalOutput(
        instance_id=instance.instance_id,
        instruction=instruction,
        instance=instance.to_dict(),
        test_result=test_result,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
    )
    return output


if __name__ == '__main__':
    parser = get_evaluation_parser()
    parser.add_argument(
        '--dataset',
        type=str,
        default='cmu-lti/interactive-swe',
        help='dataset to evaluate on',
    )
    parser.add_argument(
        '--split',
        type=str,
        default='test',
        help='split to evaluate on',
    )
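    # Example invocation (illustrative only; besides --dataset/--split defined
    # above, the remaining flags such as the LLM config, agent class, and
    # worker count come from get_evaluation_parser and may differ locally):
    #   python <this_script>.py --dataset cmu-lti/interactive-swe --split test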

    args, _ = parser.parse_known_args()

    # Load dataset from huggingface datasets
    dataset = load_dataset(args.dataset, split=args.split)
    swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
    logger.info(
        f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
    )
    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
        llm_config.log_completions = True
        # modify_params must be False for evaluation, for reproducibility and accuracy of results
        llm_config.modify_params = False

    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    # Get condenser config from environment variable
    condenser_name = os.environ.get('EVAL_CONDENSER')
    if condenser_name:
        condenser_config = get_condenser_config_arg(condenser_name)
        if condenser_config is None:
            raise ValueError(
                f'Could not find Condenser config: EVAL_CONDENSER={condenser_name}'
            )
    else:
        # If no specific condenser config is provided via env var, default to NoOpCondenser
        condenser_config = NoOpCondenserConfig()
        logger.debug(
            'No Condenser config provided via EVAL_CONDENSER, using NoOpCondenser.'
        )

    details = {'mode': 'interact'}
    _agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)

    dataset_description = (
        args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
    )
    metadata = make_metadata(
        llm_config,
        dataset_description,
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        details=details,
        condenser_config=condenser_config,
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    print(f'### OUTPUT FILE: {output_file} ###')

    # Run evaluation in iterative mode:
    # If a rollout fails to produce an AgentFinishAction, we retry until it succeeds
    # or ITERATIVE_EVAL_MODE_MAX_ATTEMPTS (default 3) attempts have been made.
    ITERATIVE_EVAL_MODE = (
        os.environ.get('ITERATIVE_EVAL_MODE', 'false').lower() == 'true'
    )
    ITERATIVE_EVAL_MODE_MAX_ATTEMPTS = int(
        os.environ.get('ITERATIVE_EVAL_MODE_MAX_ATTEMPTS', '3')
    )

    if not ITERATIVE_EVAL_MODE:
        # load the dataset
        instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
        if len(instances) > 0 and not isinstance(
            instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
        ):
            for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
                instances[col] = instances[col].apply(lambda x: str(x))
        run_evaluation(
            instances,
            metadata,
            output_file,
            args.eval_num_workers,
            process_instance,
            timeout_seconds=8
            * 60
            * 60,  # 8 hours per instance should be more than enough
            max_retries=5,
        )
    else:
        critic = AgentFinishedCritic()

        def get_cur_output_file_path(attempt: int) -> str:
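            """Return the per-attempt output path, e.g. `<output>.critic_attempt_1.jsonl`."""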
            return (
                f'{output_file.removesuffix(".jsonl")}.critic_attempt_{attempt}.jsonl'
            )

        eval_ids = None
        for attempt in range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1):
            cur_output_file = get_cur_output_file_path(attempt)
            logger.info(
                f'Running evaluation with critic {critic.__class__.__name__} for attempt {attempt} of {ITERATIVE_EVAL_MODE_MAX_ATTEMPTS}.'
            )

            # If temperature is 0 (deterministic), bump it to 0.1 on retry attempts (>1)
            # so we hopefully get slightly different results.
            if attempt > 1 and metadata.llm_config.temperature == 0:
                logger.info(
                    f'Detected temperature is 0 for (>1) attempt {attempt}. Setting temperature to 0.1...'
                )
                metadata.llm_config.temperature = 0.1

            # Load instances - on the first attempt, we evaluate all instances.
            # On subsequent attempts, we only evaluate the instances that failed the previous attempt, as determined by the critic.
            instances = prepare_dataset(
                swe_bench_tests, cur_output_file, args.eval_n_limit, eval_ids=eval_ids
            )
            if len(instances) > 0 and not isinstance(
                instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
            ):
                for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
                    instances[col] = instances[col].apply(lambda x: str(x))

            # Run evaluation - but save the results to cur_output_file
            logger.info(
                f'Evaluating {len(instances)} instances for attempt {attempt}...'
            )
            run_evaluation(
                instances,
                metadata,
                cur_output_file,
                args.eval_num_workers,
                process_instance,
                timeout_seconds=8
                * 60
                * 60,  # 8 hours per instance should be more than enough
                max_retries=5,
            )

            # When eval is done, we update eval_ids to the instances that failed the current attempt
            instances_failed = []
            logger.info(
                f'Use critic {critic.__class__.__name__} to check {len(instances)} instances for attempt {attempt}...'
            )
            with open(cur_output_file, 'r') as f:
                for line in f:
                    instance = json.loads(line)
                    try:
                        history = [
                            event_from_dict(event) for event in instance['history']
                        ]
                        critic_result = critic.evaluate(
                            history, instance['test_result'].get('git_patch', '')
                        )
                        if not critic_result.success:
                            instances_failed.append(instance['instance_id'])
                    except Exception as e:
                        logger.error(
                            f'Error loading history for instance {instance["instance_id"]}: {e}'
                        )
                        instances_failed.append(instance['instance_id'])
            logger.info(
                f'{len(instances_failed)} instances failed the current attempt {attempt}: {instances_failed}'
            )
            eval_ids = instances_failed

            # If no instances failed, we break
            if len(instances_failed) == 0:
                break

        # Aggregate the results from all attempts into the original output file
        # (the intermediate per-attempt files are left on disk).
        logger.info(
            'Aggregating results from all attempts into the original output file...'
        )
        fout = open(output_file, 'w')
        added_instance_ids = set()
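        # Walk attempts from last to first so the most recent result with a
        # non-empty git patch wins; earlier attempts only fill in instances
        # that are still missing.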
        for attempt in reversed(range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1)):
            cur_output_file = get_cur_output_file_path(attempt)
            if not os.path.exists(cur_output_file):
                logger.warning(
                    f'Intermediate output file {cur_output_file} does not exist. Skipping...'
                )
                continue

            with open(cur_output_file, 'r') as f:
                for line in f:
                    instance = json.loads(line)
                    # Also make sure git_patch is not empty - otherwise we fall back to previous attempt (empty patch is worse than anything else)
                    if (
                        instance['instance_id'] not in added_instance_ids
                        and instance['test_result'].get('git_patch', '').strip()
                    ):
                        fout.write(line)
                        added_instance_ids.add(instance['instance_id'])
            logger.info(
                f'Aggregated instances from {cur_output_file}. Total instances added so far: {len(added_instance_ids)}'
            )
        fout.close()
        logger.info(
            f'Done! Total {len(added_instance_ids)} instances added to {output_file}'
        )