mirror of https://github.com/OpenHands/OpenHands.git (synced 2025-12-26 05:48:36 +08:00)
feat(evaluation): Filter task ids by difficulty for SWE Gym rollouts (#11490)
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent 38f2728cfa
commit 12d6da8130
evaluation/benchmarks/multi_swe_bench/compute_skip_ids.py (normal file, 79 lines added)
@@ -0,0 +1,79 @@
+import argparse
+import fnmatch
+import json
+from collections import Counter
+from pathlib import Path
+
+
+def find_final_reports(base_dir, pattern=None):
+    base_path = Path(base_dir)
+    if not base_path.exists():
+        raise FileNotFoundError(f'Base directory does not exist: {base_dir}')
+
+    # Find all final_report.json files
+    all_reports = list(base_path.rglob('final_report.json'))
+
+    if pattern is None:
+        return all_reports
+
+    # Filter by pattern
+    filtered_reports = []
+    for report in all_reports:
+        # Get relative path from base_dir for matching
+        rel_path = report.relative_to(base_path)
+        if fnmatch.fnmatch(str(rel_path), pattern):
+            filtered_reports.append(report)
+
+    return filtered_reports
+
+
+def collect_resolved_ids(report_files):
+    id_counter = Counter()
+
+    for report_file in report_files:
+        with open(report_file, 'r') as f:
+            data = json.load(f)
+            if 'resolved_ids' not in data:
+                raise KeyError(f"'resolved_ids' key not found in {report_file}")
+            resolved_ids = data['resolved_ids']
+            id_counter.update(resolved_ids)
+
+    return id_counter
+
+
+def get_skip_ids(id_counter, threshold):
+    return [id_str for id_str, count in id_counter.items() if count >= threshold]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Compute SKIP_IDS from resolved IDs in final_report.json files'
+    )
+    parser.add_argument(
+        'threshold',
+        type=int,
+        help='Minimum number of times an ID must be resolved to be skipped',
+    )
+    parser.add_argument(
+        '--base-dir',
+        default='evaluation/evaluation_outputs/outputs',
+        help='Base directory to search for final_report.json files (default: evaluation/evaluation_outputs/outputs)',
+    )
+    parser.add_argument(
+        '--pattern',
+        default=None,
+        help='Glob pattern to filter paths (e.g., "*Multi-SWE-RL*/**/*gpt*")',
+    )
+
+    args = parser.parse_args()
+    report_files = find_final_reports(args.base_dir, args.pattern)
+    id_counter = collect_resolved_ids(report_files)
+
+    skip_ids = get_skip_ids(id_counter, args.threshold)
+    skip_ids = [s.replace('/', '__').replace(':pr-', '-') for s in skip_ids]
+    skip_ids = ','.join(sorted(skip_ids))
+    print(skip_ids)
+
+
+if __name__ == '__main__':
+    main()
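Illustration (not part of the diff): the script tallies how often each instance id appears in `resolved_ids` across the matched final_report.json files, keeps ids seen at least `threshold` times, and normalizes them into the instance-id form used by the rollout filter. Assuming resolved ids follow the `org/repo:pr-N` shape implied by the string replacements above, a toy run looks like this (sample ids are made up):

from collections import Counter

# Stand-in for collect_resolved_ids() over two final_report.json files,
# each of which contains a "resolved_ids" list.
id_counter = Counter()
id_counter.update(['astropy/astropy:pr-12345', 'fmtlib/fmt:pr-100'])  # report 1
id_counter.update(['astropy/astropy:pr-12345'])                       # report 2

threshold = 2
skip_ids = [i for i, c in id_counter.items() if c >= threshold]       # get_skip_ids()
skip_ids = [s.replace('/', '__').replace(':pr-', '-') for s in skip_ids]
print(','.join(sorted(skip_ids)))  # prints: astropy__astropy-12345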
@@ -747,10 +747,14 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
                 subset = dataset[dataset[filter_column].isin(selected_ids)]
                 logger.info(f'Retained {subset.shape[0]} tasks after filtering')
                 return subset
-    skip_ids = os.environ.get('SKIP_IDS', '').split(',')
+    skip_ids = [id for id in os.environ.get('SKIP_IDS', '').split(',') if id]
     if len(skip_ids) > 0:
+        logger.info(f'Dataset size before filtering: {dataset.shape[0]} tasks')
         logger.info(f'Filtering {len(skip_ids)} tasks from "SKIP_IDS"...')
-        return dataset[~dataset[filter_column].isin(skip_ids)]
+        logger.info(f'SKIP_IDS:\n{skip_ids}')
+        filtered_dataset = dataset[~dataset[filter_column].isin(skip_ids)]
+        logger.info(f'Dataset size after filtering: {filtered_dataset.shape[0]} tasks')
+        return filtered_dataset
     return dataset
 
 
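For context (illustrative, not repository code): with the stricter parsing above, an empty or unset SKIP_IDS no longer yields a single empty-string id, so the skip branch only runs when there is actually something to skip. A minimal standalone sketch of the behaviour on a toy DataFrame:

import os
import pandas as pd

dataset = pd.DataFrame({'instance_id': ['a-1', 'b-2', 'c-3']})

os.environ['SKIP_IDS'] = 'b-2'  # comma-separated ids, as the rollout script exports them
skip_ids = [i for i in os.environ.get('SKIP_IDS', '').split(',') if i]
if len(skip_ids) > 0:
    dataset = dataset[~dataset['instance_id'].isin(skip_ids)]
print(dataset['instance_id'].tolist())  # ['a-1', 'c-3']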
@@ -768,6 +772,11 @@ if __name__ == '__main__':
         default='test',
         help='split to evaluate on',
     )
+    parser.add_argument(
+        '--filter_dataset_after_sampling',
+        action='store_true',
+        help='if provided, filter dataset after sampling instead of before',
+    )
     args, _ = parser.parse_known_args()
 
     # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
@@ -777,10 +786,24 @@ if __name__ == '__main__':
     logger.info(f'Loading dataset {args.dataset} with split {args.split} ')
     dataset = load_dataset('json', data_files=args.dataset)
     dataset = dataset[args.split]
-    swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
-    logger.info(
-        f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
-    )
+    swe_bench_tests = dataset.to_pandas()
+
+    # Determine filter strategy based on flag
+    filter_func = None
+    if args.filter_dataset_after_sampling:
+        # Pass filter as callback to apply after sampling
+        def filter_func(df):
+            return filter_dataset(df, 'instance_id')
+
+        logger.info(
+            f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks (filtering will occur after sampling)'
+        )
+    else:
+        # Apply filter before sampling
+        swe_bench_tests = filter_dataset(swe_bench_tests, 'instance_id')
+        logger.info(
+            f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
+        )
 
     llm_config = None
     if args.llm_config:
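Aside (illustrative, not part of the diff): the new flag only changes where filter_dataset runs relative to the eval_n_limit sampling inside prepare_dataset. A toy sketch of the difference, with made-up numbers:

import pandas as pd

df = pd.DataFrame({'instance_id': [f'task-{i}' for i in range(10)]})
skip = {'task-0', 'task-1', 'task-2'}
limit = 5

# Filter before sampling: the sample is drawn from the 7 surviving tasks, so `limit` remain.
before = df[~df['instance_id'].isin(skip)].sample(limit, random_state=42)

# Filter after sampling: `limit` tasks are drawn from all 10 first, then skipped ids are
# removed, so fewer than `limit` tasks may survive the filter.
after = df.sample(limit, random_state=42)
after = after[~after['instance_id'].isin(skip)]

print(len(before), len(after))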
@@ -810,7 +833,9 @@ if __name__ == '__main__':
 
     output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
     print(f'### OUTPUT FILE: {output_file} ###')
-    instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
+    instances = prepare_dataset(
+        swe_bench_tests, output_file, args.eval_n_limit, filter_func=filter_func
+    )
 
     if len(instances) > 0 and not isinstance(
         instances['FAIL_TO_PASS'][instances['FAIL_TO_PASS'].index[0]], str
@@ -8,8 +8,14 @@
 MODEL=$1 # eg your llm config name in config.toml (eg: "llm.claude-3-5-sonnet-20241022-t05")
 EXP_NAME=$2 # "train-t05"
 EVAL_DATASET=$3 # path to original dataset (jsonl file)
-N_WORKERS=${4:-64}
-N_RUNS=${5:-1}
+MAX_ITER=$4
+N_WORKERS=${5:-64}
+N_RUNS=${6:-1}
+EVAL_LIMIT=${7:-}
+SKIP_IDS_THRESHOLD=$8
+SKIP_IDS_PATTERN=$9
+INPUT_SKIP_IDS=${10}
+FILTER_DATASET_AFTER_SAMPLING=${11:-}
 
 export EXP_NAME=$EXP_NAME
 # use 2x resources for rollout since some codebases are pretty resource-intensive
@@ -17,6 +23,7 @@ export DEFAULT_RUNTIME_RESOURCE_FACTOR=2
 echo "MODEL: $MODEL"
 echo "EXP_NAME: $EXP_NAME"
 echo "EVAL_DATASET: $EVAL_DATASET"
+echo "INPUT_SKIP_IDS: $INPUT_SKIP_IDS"
 # Generate DATASET path by adding _with_runtime_ before .jsonl extension
 DATASET="${EVAL_DATASET%.jsonl}_with_runtime_.jsonl" # path to converted dataset
 
@@ -35,9 +42,6 @@ else
   export SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.eval.all-hands.dev"
 fi
 
-#EVAL_LIMIT=3000
-MAX_ITER=100
-
 
 # ===== Run inference =====
 source "evaluation/utils/version_control.sh"
@@ -69,17 +73,52 @@ function run_eval() {
     --dataset $DATASET \
     --split $SPLIT"
 
+  # Conditionally add filter flag
+  if [ "$FILTER_DATASET_AFTER_SAMPLING" = "true" ]; then
+    COMMAND="$COMMAND --filter_dataset_after_sampling"
+  fi
 
   echo "Running command: $COMMAND"
   if [ -n "$EVAL_LIMIT" ]; then
     echo "EVAL_LIMIT: $EVAL_LIMIT"
     COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
   fi
 
-  # Run the command
   eval $COMMAND
 }
 
 for run_idx in $(seq 1 $N_RUNS); do
+  if [ -n "$SKIP_IDS_THRESHOLD" ]; then
+    echo "Computing SKIP_IDS for run $run_idx..."
+    SKIP_CMD="poetry run python evaluation/benchmarks/multi_swe_bench/compute_skip_ids.py $SKIP_IDS_THRESHOLD"
+    if [ -n "$SKIP_IDS_PATTERN" ]; then
+      SKIP_CMD="$SKIP_CMD --pattern \"$SKIP_IDS_PATTERN\""
+    fi
+    COMPUTED_SKIP_IDS=$(eval $SKIP_CMD)
+    SKIP_STATUS=$?
+    if [ $SKIP_STATUS -ne 0 ]; then
+      echo "ERROR: Skip IDs computation failed with exit code $SKIP_STATUS"
+      exit $SKIP_STATUS
+    fi
+    echo "COMPUTED_SKIP_IDS: $COMPUTED_SKIP_IDS"
+  else
+    echo "SKIP_IDS_THRESHOLD not provided, skipping SKIP_IDS computation"
+    COMPUTED_SKIP_IDS=""
+  fi
+
+  # Concatenate COMPUTED_SKIP_IDS and INPUT_SKIP_IDS
+  if [ -n "$COMPUTED_SKIP_IDS" ] && [ -n "$INPUT_SKIP_IDS" ]; then
+    export SKIP_IDS="${COMPUTED_SKIP_IDS},${INPUT_SKIP_IDS}"
+  elif [ -n "$COMPUTED_SKIP_IDS" ]; then
+    export SKIP_IDS="$COMPUTED_SKIP_IDS"
+  elif [ -n "$INPUT_SKIP_IDS" ]; then
+    export SKIP_IDS="$INPUT_SKIP_IDS"
+  else
+    unset SKIP_IDS
+  fi
+
+  echo "FINAL SKIP_IDS: $SKIP_IDS"
+  echo ""
+
   while true; do
     echo "### Running inference... ###"
@@ -9,7 +9,7 @@ import time
 import traceback
 from contextlib import contextmanager
 from inspect import signature
-from typing import Any, Awaitable, Callable, TextIO
+from typing import Any, Awaitable, Callable, Optional, TextIO
 
 import pandas as pd
 from pydantic import BaseModel
@@ -222,6 +222,7 @@ def prepare_dataset(
     eval_n_limit: int,
     eval_ids: list[str] | None = None,
     skip_num: int | None = None,
+    filter_func: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
 ):
     assert 'instance_id' in dataset.columns, (
         "Expected 'instance_id' column in the dataset. You should define your own unique identifier for each instance and use it as the 'instance_id' column."
@@ -265,6 +266,12 @@ def prepare_dataset(
             f'Randomly sampling {eval_n_limit} unique instances with random seed 42.'
         )
 
+    if filter_func is not None:
+        dataset = filter_func(dataset)
+        logger.info(
+            f'Applied filter after sampling: {len(dataset)} instances remaining'
+        )
+
     def make_serializable(instance_dict: dict) -> dict:
         import numpy as np
 
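Note (illustrative, not repository code): filter_func is an optional callback with the signature Callable[[pd.DataFrame], pd.DataFrame], and prepare_dataset applies it to the already-sampled frame. A small sketch of that contract; the helper name below is hypothetical:

from typing import Callable, Optional

import pandas as pd

def sample_then_filter(
    dataset: pd.DataFrame,
    eval_n_limit: int,
    filter_func: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
) -> pd.DataFrame:
    # Sample first (mirroring the random_state=42 sampling above), then let the
    # caller-supplied callback drop instances, e.g. those named in SKIP_IDS.
    dataset = dataset.sample(min(eval_n_limit, len(dataset)), random_state=42)
    if filter_func is not None:
        dataset = filter_func(dataset)
    return dataset

toy = pd.DataFrame({'instance_id': ['x-1', 'y-2', 'z-3']})
subset = sample_then_filter(toy, 2, filter_func=lambda d: d[d['instance_id'] != 'y-2'])
print(subset['instance_id'].tolist())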