Evaluation harness: Add agent config option (#6662)

This commit is contained in:
Boxuan Li
2025-02-13 12:05:03 -08:00
committed by GitHub
parent b197e0af47
commit ef12bc5381
9 changed files with 149 additions and 12 deletions

View File

@@ -25,6 +25,7 @@ from openhands.core.config import (
get_llm_config_arg,
get_parser,
)
from openhands.core.config.utils import get_agent_config_arg
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
@@ -63,8 +64,12 @@ def get_config(
workspace_mount_path=None,
)
config.set_llm_config(metadata.llm_config)
agent_config = config.get_agent_config(metadata.agent_class)
agent_config.enable_prompt_extensions = False
if metadata.agent_config:
config.set_agent_config(metadata.agent_config, metadata.agent_class)
else:
logger.info('Agent config not provided, using default settings')
agent_config = config.get_agent_config(metadata.agent_class)
agent_config.enable_prompt_extensions = False
return config
@@ -238,6 +243,10 @@ if __name__ == '__main__':
)
args, _ = parser.parse_known_args()
agent_config = None
if args.agent_config:
agent_config = get_agent_config_arg(args.agent_config)
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
@@ -256,6 +265,7 @@ if __name__ == '__main__':
eval_output_dir=args.eval_output_dir,
data_split=args.data_split,
details={'gaia-level': args.level},
agent_config=agent_config,
)
dataset = load_dataset('gaia-benchmark/GAIA', args.level)

View File

@@ -9,6 +9,7 @@ AGENT=$3
EVAL_LIMIT=$4
LEVELS=$5
NUM_WORKERS=$6
AGENT_CONFIG=$7
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
@@ -49,5 +50,9 @@ if [ -n "$EVAL_LIMIT" ]; then
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
if [ -n "$AGENT_CONFIG" ]; then
echo "AGENT_CONFIG: $AGENT_CONFIG"
COMMAND="$COMMAND --agent-config $AGENT_CONFIG"
# Run the command
eval $COMMAND

View File

@@ -18,9 +18,11 @@ from openhands.core.config import (
AppConfig,
LLMConfig,
SandboxConfig,
get_agent_config_arg,
get_llm_config_arg,
get_parser,
)
from openhands.core.config.agent_config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
@@ -34,6 +36,7 @@ def get_config(
task_short_name: str,
mount_path_on_host: str,
llm_config: LLMConfig,
agent_config: AgentConfig,
) -> AppConfig:
config = AppConfig(
run_as_openhands=False,
@@ -58,6 +61,14 @@ def get_config(
workspace_mount_path_in_sandbox='/outputs',
)
config.set_llm_config(llm_config)
if agent_config:
config.set_agent_config(agent_config)
else:
logger.info('Agent config not provided, using default settings')
agent_config = AgentConfig(
enable_prompt_extensions=False,
)
config.set_agent_config(agent_config)
return config
@@ -215,6 +226,10 @@ if __name__ == '__main__':
)
args, _ = parser.parse_known_args()
agent_config: AgentConfig | None = None
if args.agent_config:
agent_config = get_agent_config_arg(args.agent_config)
agent_llm_config: LLMConfig | None = None
if args.agent_llm_config:
agent_llm_config = get_llm_config_arg(args.agent_llm_config)
@@ -255,7 +270,7 @@ if __name__ == '__main__':
else:
temp_dir = tempfile.mkdtemp()
config: AppConfig = get_config(
args.task_image_name, task_short_name, temp_dir, agent_llm_config
args.task_image_name, task_short_name, temp_dir, agent_llm_config, agent_config
)
runtime: Runtime = create_runtime(config)
call_async_from_sync(runtime.connect)

View File

@@ -44,6 +44,10 @@ while [[ $# -gt 0 ]]; do
ENV_LLM_CONFIG="$2"
shift 2
;;
--agent-config)
AGENT_CONFIG="$2"
shift 2
;;
--outputs-path)
OUTPUTS_PATH="$2"
shift 2
@@ -140,13 +144,21 @@ while IFS= read -r task_image; do
continue
fi
export PYTHONPATH=evaluation/benchmarks/the_agent_company:\$PYTHONPATH && \
poetry run python run_infer.py \
--agent-llm-config "$AGENT_LLM_CONFIG" \
--env-llm-config "$ENV_LLM_CONFIG" \
--outputs-path "$OUTPUTS_PATH" \
--server-hostname "$SERVER_HOSTNAME" \
--task-image-name "$task_image"
# Build the Python command
COMMAND="poetry run python run_infer.py \
--agent-llm-config \"$AGENT_LLM_CONFIG\" \
--env-llm-config \"$ENV_LLM_CONFIG\" \
--outputs-path \"$OUTPUTS_PATH\" \
--server-hostname \"$SERVER_HOSTNAME\" \
--task-image-name \"$task_image\""
# Add agent-config if it's defined
if [ -n "$AGENT_CONFIG" ]; then
COMMAND="$COMMAND --agent-config $AGENT_CONFIG"
fi
export PYTHONPATH=evaluation/benchmarks/the_agent_company:$PYTHONPATH && \
eval "$COMMAND"
# Prune unused images and volumes
docker image rm "$task_image"

View File

@@ -17,6 +17,7 @@ from tqdm import tqdm
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.condenser_config import (
CondenserConfig,
NoOpCondenserConfig,
@@ -43,6 +44,7 @@ from openhands.memory.condenser import get_condensation_metadata
class EvalMetadata(BaseModel):
agent_class: str
llm_config: LLMConfig
agent_config: AgentConfig | None = None
max_iterations: int
eval_output_dir: str
start_time: str
@@ -167,6 +169,7 @@ def make_metadata(
eval_output_dir: str,
data_split: str | None = None,
details: dict[str, Any] | None = None,
agent_config: AgentConfig | None = None,
condenser_config: CondenserConfig | None = None,
) -> EvalMetadata:
model_name = llm_config.model.split('/')[-1]
@@ -189,6 +192,7 @@ def make_metadata(
metadata = EvalMetadata(
agent_class=agent_class,
llm_config=llm_config,
agent_config=agent_config,
max_iterations=max_iterations,
eval_output_dir=eval_output_path,
start_time=time.strftime('%Y-%m-%d %H:%M:%S'),

View File

@@ -10,6 +10,7 @@ from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.core.config.utils import (
finalize_config,
get_agent_config_arg,
get_llm_config_arg,
get_parser,
load_app_config,
@@ -31,6 +32,7 @@ __all__ = [
'load_from_env',
'load_from_toml',
'finalize_config',
'get_agent_config_arg',
'get_llm_config_arg',
'get_field_info',
'get_parser',

View File

@@ -298,7 +298,59 @@ def finalize_config(cfg: AppConfig):
)
# Utility function for command line --group argument
def get_agent_config_arg(
    agent_config_arg: str, toml_file: str = 'config.toml'
) -> AgentConfig | None:
    """Load a named group of agent settings from a TOML config file.

    A group in config.toml can look like this:

    ```
    [agent.default]
    enable_prompt_extensions = false
    ```

    The user-defined group name, like "default", is the argument to this
    function. The function loads the AgentConfig object with the settings of
    this group from the config file. Note that the group must live under the
    "agent" table, i.e. its fully-qualified name starts with "agent.".

    Args:
        agent_config_arg: The group of agent settings to get from the config.toml file.
        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.

    Returns:
        AgentConfig: The AgentConfig object with the settings from the config file,
        or None if the file is missing, unparsable, or the group is not found.
    """
    # Normalize the user-supplied name: drop stray surrounding brackets and
    # any leading "agent." prefix so only the bare group name remains.
    group = agent_config_arg.strip('[]').removeprefix('agent.')

    logger.openhands_logger.debug(f'Loading agent config from {group}')

    # Parse the TOML file; treat a missing or malformed file as "not found".
    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            parsed = toml.load(toml_contents)
    except FileNotFoundError as e:
        logger.openhands_logger.error(f'Config file not found: {e}')
        return None
    except toml.TomlDecodeError as e:
        logger.openhands_logger.error(
            f'Cannot parse agent group from {group}. Exception: {e}'
        )
        return None

    # Build the AgentConfig from the requested section, if it exists.
    if 'agent' in parsed and group in parsed['agent']:
        return AgentConfig(**parsed['agent'][group])

    logger.openhands_logger.debug(f'Loading from toml failed for {group}')
    return None
def get_llm_config_arg(
llm_config_arg: str, toml_file: str = 'config.toml'
) -> LLMConfig | None:
@@ -443,6 +495,12 @@ def get_parser() -> argparse.ArgumentParser:
type=str,
help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
)
parser.add_argument(
'--agent-config',
default=None,
type=str,
help='Replace default Agent ([agent] section in config.toml) config with the specified Agent config, e.g. "CodeAct" for [agent.CodeAct] section in config.toml',
)
parser.add_argument(
'-n',
'--name',

View File

@@ -128,6 +128,7 @@ def test_help_message(capsys):
'--eval-note EVAL_NOTE',
'--eval-ids EVAL_IDS',
'-l LLM_CONFIG, --llm-config LLM_CONFIG',
'--agent-config AGENT_CONFIG',
'-n NAME, --name NAME',
'--config-file CONFIG_FILE',
'--no-auto-continue',
@@ -137,4 +138,4 @@ def test_help_message(capsys):
assert element in help_output, f"Expected '{element}' to be in the help message"
option_count = help_output.count(' -')
assert option_count == 17, f'Expected 17 options, found {option_count}'
assert option_count == 18, f'Expected 18 options, found {option_count}'

View File

@@ -9,6 +9,7 @@ from openhands.core.config import (
AppConfig,
LLMConfig,
finalize_config,
get_agent_config_arg,
get_llm_config_arg,
load_from_env,
load_from_toml,
@@ -781,3 +782,32 @@ memory_max_threads = 10
assert codeact_config.memory_enabled is True
browsing_config = default_config.get_agent_configs().get('BrowsingAgent')
assert browsing_config.memory_max_threads == 10
def test_get_agent_config_arg(temp_toml_file):
    """get_agent_config_arg loads the named [agent.*] group from a TOML file."""
    config_text = """
[core]
max_iterations = 100
max_budget_per_task = 4.0
[agent.CodeActAgent]
memory_enabled = true
enable_prompt_extensions = false
[agent.BrowsingAgent]
memory_enabled = false
enable_prompt_extensions = true
memory_max_threads = 10
"""
    with open(temp_toml_file, 'w') as toml_out:
        toml_out.write(config_text)

    # Each named group should come back with exactly its own settings.
    codeact_config = get_agent_config_arg('CodeActAgent', temp_toml_file)
    assert codeact_config.memory_enabled
    assert not codeact_config.enable_prompt_extensions

    browsing_config = get_agent_config_arg('BrowsingAgent', temp_toml_file)
    assert not browsing_config.memory_enabled
    assert browsing_config.enable_prompt_extensions
    assert browsing_config.memory_max_threads == 10