Iterative evaluation with rule-based critic (#7293)
evaluation/benchmarks/swe_bench/README.md
@@ -18,6 +18,20 @@ Please follow instruction [here](../../README.md#setup) to setup your local deve

## Run Inference (Rollout) on SWE-Bench Instances: Generate Patch from Problem Statement

+> [!NOTE]
+> **Iterative Evaluation Protocol**
+>
+> We use an iterative approach for more stable and reproducible results:
+> - For each instance, we attempt to generate a solution up to 3 times
+> - Each attempt continues until either:
+>   1. The agent successfully produces a patch with `AgentFinishAction`, or
+>   2. The attempt reaches the maximum iteration limit
+> - If an attempt fails, we retry with a fresh attempt (up to the 3-attempt maximum)
+> - If your LLM config has temperature=0, we will automatically use temperature=0.1 for the 2nd and 3rd attempts
+>
+> To enable this iterative protocol, set `export ITERATIVE_EVAL_MODE=true`
+
### Running Locally with Docker

Make sure your Docker daemon is running, and that you have ample disk space (at least 200-500GB, depending on which SWE-Bench subset you are running) for the instance-level docker images.
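The protocol described in the note above reduces to a small retry loop. Below is a minimal runnable sketch of that loop; `LLMConfig` and `rollout` here are illustrative stand-ins, not OpenHands APIs (the real implementation appears in the `run_infer.py` hunk further down).

```python
import random
from dataclasses import dataclass


@dataclass
class LLMConfig:
    """Illustrative stand-in for the LLM config; only temperature matters here."""

    temperature: float = 0.0


def rollout(llm_config: LLMConfig) -> bool:
    """Stand-in for one agent attempt; True means it ended with AgentFinishAction."""
    return random.random() < 0.5


MAX_ATTEMPTS = 3  # mirrors ITERATIVE_EVAL_MODE_MAX_ATTEMPTS below
llm_config = LLMConfig(temperature=0.0)

for attempt in range(1, MAX_ATTEMPTS + 1):
    if attempt > 1 and llm_config.temperature == 0:
        llm_config.temperature = 0.1  # nudge sampling so a retry can differ
    if rollout(llm_config):
        break  # this attempt produced a finished rollout; keep it
```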
@@ -45,7 +59,7 @@ to `CodeActAgent`.
default, the script evaluates the entire SWE-bench_Lite test set (300 issues). Note:
in order to use `eval_limit`, you must also set `agent`.
- `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
-default, it is set to 30.
+default, it is set to 60.
- `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
default, it is set to 1.
- `dataset`, a huggingface dataset name, e.g. `princeton-nlp/SWE-bench`, `princeton-nlp/SWE-bench_Lite`, or `princeton-nlp/SWE-bench_Verified`, specifies which dataset to evaluate on.
evaluation/benchmarks/swe_bench/run_infer.py
@@ -37,9 +37,10 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
+from openhands.critic import AgentFinishedCritic
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
-from openhands.events.serialization.event import event_to_dict
+from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.shutdown_listener import sleep_if_should_continue
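The newly imported `event_from_dict` is the inverse of `event_to_dict`: rollout histories are written to `output.jsonl` as dicts, and the iterative-mode critic needs them back as `Event` objects. A minimal round-trip sketch, assuming the usual serialization behavior of these two helpers:

```python
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_from_dict, event_to_dict

action = MessageAction(content='hello')
payload = event_to_dict(action)      # JSON-serializable dict, as stored in the output history
restored = event_from_dict(payload)  # back to an Event the critic can inspect

assert isinstance(restored, MessageAction)
```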
@@ -122,7 +123,9 @@ You SHOULD NEVER attempt to browse the web.


# TODO: migrate all swe-bench docker to ghcr.io/openhands
-DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/')
+DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get(
+    'EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/'
+)
logger.info(f'Default docker image prefix: {DEFAULT_DOCKER_IMAGE_PREFIX}')
@@ -637,20 +640,132 @@ if __name__ == '__main__':

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    print(f'### OUTPUT FILE: {output_file} ###')
-    instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
-
-    if len(instances) > 0 and not isinstance(
-        instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
-    ):
-        for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
-            instances[col] = instances[col].apply(lambda x: str(x))
-
-    run_evaluation(
-        instances,
-        metadata,
-        output_file,
-        args.eval_num_workers,
-        process_instance,
-        timeout_seconds=8 * 60 * 60,  # 8 hours PER instance should be more than enough
-        max_retries=5,
-    )
+    # Run evaluation in iterative mode:
+    # If a rollout fails to output AgentFinishAction, we will try again until it succeeds OR a total of 3 attempts have been made.
+    ITERATIVE_EVAL_MODE = (
+        os.environ.get('ITERATIVE_EVAL_MODE', 'false').lower() == 'true'
+    )
+    ITERATIVE_EVAL_MODE_MAX_ATTEMPTS = int(
+        os.environ.get('ITERATIVE_EVAL_MODE_MAX_ATTEMPTS', '3')
+    )
+
+    if not ITERATIVE_EVAL_MODE:
+        # load the dataset
+        instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
+        if len(instances) > 0 and not isinstance(
+            instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+        ):
+            for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                instances[col] = instances[col].apply(lambda x: str(x))
+
+        run_evaluation(
+            instances,
+            metadata,
+            output_file,
+            args.eval_num_workers,
+            process_instance,
+            timeout_seconds=8
+            * 60
+            * 60,  # 8 hours PER instance should be more than enough
+            max_retries=5,
+        )
+    else:
+        critic = AgentFinishedCritic()
+
+        def get_cur_output_file_path(attempt: int) -> str:
+            return (
+                f'{output_file.removesuffix(".jsonl")}.critic_attempt_{attempt}.jsonl'
+            )
+
+        eval_ids = None
+        for attempt in range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1):
+            cur_output_file = get_cur_output_file_path(attempt)
+            logger.info(
+                f'Running evaluation with critic {critic.__class__.__name__} for attempt {attempt} of {ITERATIVE_EVAL_MODE_MAX_ATTEMPTS}.'
+            )
+
+            # For deterministic eval, we set temperature to 0.1 for attempts (>1)
+            # so hopefully we get slightly different results
+            if attempt > 1 and metadata.llm_config.temperature == 0:
+                logger.info(
+                    f'Detected temperature is 0 for (>1) attempt {attempt}. Setting temperature to 0.1...'
+                )
+                metadata.llm_config.temperature = 0.1
+
+            # Load instances - on the first attempt, we evaluate all instances.
+            # On subsequent attempts, we only evaluate the instances that failed the previous attempt, as determined by the critic.
+            instances = prepare_dataset(
+                swe_bench_tests, cur_output_file, args.eval_n_limit, eval_ids=eval_ids
+            )
+            if len(instances) > 0 and not isinstance(
+                instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+            ):
+                for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                    instances[col] = instances[col].apply(lambda x: str(x))
+
+            # Run evaluation - but save the results to cur_output_file
+            logger.info(
+                f'Evaluating {len(instances)} instances for attempt {attempt}...'
+            )
+            run_evaluation(
+                instances,
+                metadata,
+                cur_output_file,
+                args.eval_num_workers,
+                process_instance,
+                timeout_seconds=8
+                * 60
+                * 60,  # 8 hours PER instance should be more than enough
+                max_retries=5,
+            )
+
+            # When eval is done, we update eval_ids to the instances that failed the current attempt
+            instances_failed = []
+            logger.info(
+                f'Use critic {critic.__class__.__name__} to check {len(instances)} instances for attempt {attempt}...'
+            )
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    history = [event_from_dict(event) for event in instance['history']]
+                    critic_result = critic.evaluate(history)
+                    if not critic_result.success:
+                        instances_failed.append(instance['instance_id'])
+            logger.info(
+                f'{len(instances_failed)} instances failed the current attempt {attempt}: {instances_failed}'
+            )
+            eval_ids = instances_failed
+
+            # If no instances failed, we break
+            if len(instances_failed) == 0:
+                break
+
+        # Then we aggregate the results from all attempts into the original output file
+        logger.info(
+            'Aggregating results from all attempts into the original output file...'
+        )
+        fout = open(output_file, 'w')
+        added_instance_ids = set()
+        for attempt in reversed(range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1)):
+            cur_output_file = get_cur_output_file_path(attempt)
+            if not os.path.exists(cur_output_file):
+                logger.warning(
+                    f'Intermediate output file {cur_output_file} does not exist. Skipping...'
+                )
+                continue
+
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    if instance['instance_id'] not in added_instance_ids:
+                        fout.write(line)
+                        added_instance_ids.add(instance['instance_id'])
+            logger.info(
+                f'Aggregated instances from {cur_output_file}. Total instances added so far: {len(added_instance_ids)}'
+            )
+        fout.close()
+        logger.info(
+            f'Done! Total {len(added_instance_ids)} instances added to {output_file}'
+        )
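One subtlety in the aggregation pass above is worth spelling out: attempts are read in reverse order and only the first occurrence of each `instance_id` is kept, so the latest attempt wins for any instance that was retried. A distilled sketch of just that merge step; `merge_attempts` is a hypothetical standalone helper, not part of the commit:

```python
import json


def merge_attempts(attempt_files: list[str], merged_file: str) -> int:
    """Merge per-attempt JSONL files so that later attempts take precedence."""
    seen: set[str] = set()
    with open(merged_file, 'w') as out:
        # Walk attempts from last to first: the most recent rollout of each
        # instance is written first, and earlier ones are skipped as duplicates.
        for path in reversed(attempt_files):
            with open(path) as f:
                for line in f:
                    instance_id = json.loads(line)['instance_id']
                    if instance_id not in seen:
                        out.write(line)
                        seen.add(instance_id)
    return len(seen)
```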
evaluation/benchmarks/swe_bench/scripts/run_infer.sh
@@ -25,8 +25,8 @@ if [ -z "$AGENT" ]; then
fi

if [ -z "$MAX_ITER" ]; then
-  echo "MAX_ITER not specified, use default 100"
-  MAX_ITER=100
+  echo "MAX_ITER not specified, use default 60"
+  MAX_ITER=60
fi

if [ -z "$RUN_WITH_BROWSING" ]; then
openhands/critic/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .base import BaseCritic, CriticResult
from .finish_critic import AgentFinishedCritic

__all__ = ['CriticResult', 'BaseCritic', 'AgentFinishedCritic']
openhands/critic/base.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import abc

from pydantic import BaseModel

from openhands.events import Event


class CriticResult(BaseModel):
    """
    A critic result is a score and a message.
    """

    score: float
    message: str

    @property
    def success(self) -> bool:
        """
        Whether the agent is successful.
        """
        return self.score >= 0.5


class BaseCritic(abc.ABC):
    """
    A critic is a function that takes in a list of events and returns a score about the quality of those events.
    """

    @abc.abstractmethod
    def evaluate(self, events: list[Event]) -> CriticResult:
        pass
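Subclassing `BaseCritic` only requires implementing `evaluate`. Below is a hypothetical critic, not part of this commit, shown as a minimal sketch of the extension point; it also illustrates how `CriticResult.success` follows from the 0.5 score threshold.

```python
from openhands.critic.base import BaseCritic, CriticResult
from openhands.events import Event


class NonEmptyHistoryCritic(BaseCritic):
    """Hypothetical example critic: succeed only if the rollout produced any events."""

    def evaluate(self, events: list[Event]) -> CriticResult:
        if not events:
            return CriticResult(score=0.0, message='Empty history.')
        return CriticResult(score=1.0, message=f'{len(events)} events recorded.')


# success is derived from the score threshold in CriticResult:
assert CriticResult(score=1.0, message='ok').success
assert not CriticResult(score=0.4, message='below threshold').success
```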
openhands/critic/finish_critic.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from openhands.critic.base import BaseCritic, CriticResult
from openhands.events import Event
from openhands.events.action import Action, AgentFinishAction


class AgentFinishedCritic(BaseCritic):
    """This is a simple rule-based critic that checks if the last action in the history is an AgentFinishAction.

    If not, it will return a score of 0 and a message indicating that the agent did not finish.
    """

    def __init__(self):
        pass

    def evaluate(self, events: list[Event]) -> CriticResult:
        last_action = next((h for h in reversed(events) if isinstance(h, Action)), None)

        if isinstance(last_action, AgentFinishAction):
            return CriticResult(score=1, message='Agent finished.')
        else:
            return CriticResult(score=0, message='Agent did not finish.')
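A quick usage sketch for the new critic. The two histories below are minimal placeholders built from actions this commit already imports elsewhere; constructing them directly like this is an assumption for illustration.

```python
from openhands.critic import AgentFinishedCritic
from openhands.events.action import AgentFinishAction, CmdRunAction

critic = AgentFinishedCritic()

finished_history = [CmdRunAction(command='ls'), AgentFinishAction()]
unfinished_history = [CmdRunAction(command='ls')]

assert critic.evaluate(finished_history).success        # last action is AgentFinishAction
assert not critic.evaluate(unfinished_history).success  # agent never finished
```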