mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Evaluation harness: Add agent config option (#6662)
This commit is contained in:
@@ -25,6 +25,7 @@ from openhands.core.config import (
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
)
|
||||
from openhands.core.config.utils import get_agent_config_arg
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
|
||||
@@ -63,8 +64,12 @@ def get_config(
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
config.set_llm_config(metadata.llm_config)
|
||||
agent_config = config.get_agent_config(metadata.agent_class)
|
||||
agent_config.enable_prompt_extensions = False
|
||||
if metadata.agent_config:
|
||||
config.set_agent_config(metadata.agent_config, metadata.agent_class)
|
||||
else:
|
||||
logger.info('Agent config not provided, using default settings')
|
||||
agent_config = config.get_agent_config(metadata.agent_class)
|
||||
agent_config.enable_prompt_extensions = False
|
||||
return config
|
||||
|
||||
|
||||
@@ -238,6 +243,10 @@ if __name__ == '__main__':
|
||||
)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
agent_config = None
|
||||
if args.agent_config:
|
||||
agent_config = get_agent_config_arg(args.agent_config)
|
||||
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
@@ -256,6 +265,7 @@ if __name__ == '__main__':
|
||||
eval_output_dir=args.eval_output_dir,
|
||||
data_split=args.data_split,
|
||||
details={'gaia-level': args.level},
|
||||
agent_config=agent_config,
|
||||
)
|
||||
|
||||
dataset = load_dataset('gaia-benchmark/GAIA', args.level)
|
||||
|
||||
@@ -9,6 +9,7 @@ AGENT=$3
|
||||
EVAL_LIMIT=$4
|
||||
LEVELS=$5
|
||||
NUM_WORKERS=$6
|
||||
AGENT_CONFIG=$7
|
||||
|
||||
if [ -z "$NUM_WORKERS" ]; then
|
||||
NUM_WORKERS=1
|
||||
@@ -49,5 +50,9 @@ if [ -n "$EVAL_LIMIT" ]; then
|
||||
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
|
||||
fi
|
||||
|
||||
# Append the optional agent config group to the command only when one was supplied.
if [ -n "$AGENT_CONFIG" ]; then
  echo "AGENT_CONFIG: $AGENT_CONFIG"
  COMMAND="$COMMAND --agent-config $AGENT_CONFIG"
fi  # close the conditional; without this the script aborts with a syntax error at EOF

# Run the command
eval $COMMAND
|
||||
|
||||
@@ -18,9 +18,11 @@ from openhands.core.config import (
|
||||
AppConfig,
|
||||
LLMConfig,
|
||||
SandboxConfig,
|
||||
get_agent_config_arg,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
)
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
@@ -34,6 +36,7 @@ def get_config(
|
||||
task_short_name: str,
|
||||
mount_path_on_host: str,
|
||||
llm_config: LLMConfig,
|
||||
agent_config: AgentConfig,
|
||||
) -> AppConfig:
|
||||
config = AppConfig(
|
||||
run_as_openhands=False,
|
||||
@@ -58,6 +61,14 @@ def get_config(
|
||||
workspace_mount_path_in_sandbox='/outputs',
|
||||
)
|
||||
config.set_llm_config(llm_config)
|
||||
if agent_config:
|
||||
config.set_agent_config(agent_config)
|
||||
else:
|
||||
logger.info('Agent config not provided, using default settings')
|
||||
agent_config = AgentConfig(
|
||||
enable_prompt_extensions=False,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
return config
|
||||
|
||||
|
||||
@@ -215,6 +226,10 @@ if __name__ == '__main__':
|
||||
)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
agent_config: AgentConfig | None = None
|
||||
if args.agent_config:
|
||||
agent_config = get_agent_config_arg(args.agent_config)
|
||||
|
||||
agent_llm_config: LLMConfig | None = None
|
||||
if args.agent_llm_config:
|
||||
agent_llm_config = get_llm_config_arg(args.agent_llm_config)
|
||||
@@ -255,7 +270,7 @@ if __name__ == '__main__':
|
||||
else:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
config: AppConfig = get_config(
|
||||
args.task_image_name, task_short_name, temp_dir, agent_llm_config
|
||||
args.task_image_name, task_short_name, temp_dir, agent_llm_config, agent_config
|
||||
)
|
||||
runtime: Runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
@@ -44,6 +44,10 @@ while [[ $# -gt 0 ]]; do
|
||||
ENV_LLM_CONFIG="$2"
|
||||
shift 2
|
||||
;;
|
||||
--agent-config)
|
||||
AGENT_CONFIG="$2"
|
||||
shift 2
|
||||
;;
|
||||
--outputs-path)
|
||||
OUTPUTS_PATH="$2"
|
||||
shift 2
|
||||
@@ -140,13 +144,21 @@ while IFS= read -r task_image; do
|
||||
continue
|
||||
fi
|
||||
|
||||
export PYTHONPATH=evaluation/benchmarks/the_agent_company:\$PYTHONPATH && \
|
||||
poetry run python run_infer.py \
|
||||
--agent-llm-config "$AGENT_LLM_CONFIG" \
|
||||
--env-llm-config "$ENV_LLM_CONFIG" \
|
||||
--outputs-path "$OUTPUTS_PATH" \
|
||||
--server-hostname "$SERVER_HOSTNAME" \
|
||||
--task-image-name "$task_image"
|
||||
# Build the Python command
|
||||
COMMAND="poetry run python run_infer.py \
|
||||
--agent-llm-config \"$AGENT_LLM_CONFIG\" \
|
||||
--env-llm-config \"$ENV_LLM_CONFIG\" \
|
||||
--outputs-path \"$OUTPUTS_PATH\" \
|
||||
--server-hostname \"$SERVER_HOSTNAME\" \
|
||||
--task-image-name \"$task_image\""
|
||||
|
||||
# Add agent-config if it's defined
|
||||
if [ -n "$AGENT_CONFIG" ]; then
|
||||
COMMAND="$COMMAND --agent-config $AGENT_CONFIG"
|
||||
fi
|
||||
|
||||
export PYTHONPATH=evaluation/benchmarks/the_agent_company:$PYTHONPATH && \
|
||||
eval "$COMMAND"
|
||||
|
||||
# Prune unused images and volumes
|
||||
docker image rm "$task_image"
|
||||
|
||||
@@ -17,6 +17,7 @@ from tqdm import tqdm
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.config.condenser_config import (
|
||||
CondenserConfig,
|
||||
NoOpCondenserConfig,
|
||||
@@ -43,6 +44,7 @@ from openhands.memory.condenser import get_condensation_metadata
|
||||
class EvalMetadata(BaseModel):
|
||||
agent_class: str
|
||||
llm_config: LLMConfig
|
||||
agent_config: AgentConfig | None = None
|
||||
max_iterations: int
|
||||
eval_output_dir: str
|
||||
start_time: str
|
||||
@@ -167,6 +169,7 @@ def make_metadata(
|
||||
eval_output_dir: str,
|
||||
data_split: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
agent_config: AgentConfig | None = None,
|
||||
condenser_config: CondenserConfig | None = None,
|
||||
) -> EvalMetadata:
|
||||
model_name = llm_config.model.split('/')[-1]
|
||||
@@ -189,6 +192,7 @@ def make_metadata(
|
||||
metadata = EvalMetadata(
|
||||
agent_class=agent_class,
|
||||
llm_config=llm_config,
|
||||
agent_config=agent_config,
|
||||
max_iterations=max_iterations,
|
||||
eval_output_dir=eval_output_path,
|
||||
start_time=time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
|
||||
@@ -10,6 +10,7 @@ from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.core.config.utils import (
|
||||
finalize_config,
|
||||
get_agent_config_arg,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
load_app_config,
|
||||
@@ -31,6 +32,7 @@ __all__ = [
|
||||
'load_from_env',
|
||||
'load_from_toml',
|
||||
'finalize_config',
|
||||
'get_agent_config_arg',
|
||||
'get_llm_config_arg',
|
||||
'get_field_info',
|
||||
'get_parser',
|
||||
|
||||
@@ -298,7 +298,59 @@ def finalize_config(cfg: AppConfig):
|
||||
)
|
||||
|
||||
|
||||
# Utility function for the command line --agent-config argument
|
||||
def get_agent_config_arg(
    agent_config_arg: str, toml_file: str = 'config.toml'
) -> AgentConfig | None:
    """Load a named group of agent settings from a TOML config file.

    A group in config.toml can look like this:

    ```
    [agent.default]
    enable_prompt_extensions = false
    ```

    The user-defined group name, like "default", is the argument to this
    function. The matching section is used to build a new AgentConfig, which is
    returned to the caller; this function does not register it anywhere.

    The group must be a table under the "agent" table. The name may be given
    bare ("default"), with the prefix ("agent.default"), or wrapped in square
    brackets ("[agent.default]").

    Args:
        agent_config_arg: The group of agent settings to get from the config file.
        toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.

    Returns:
        The AgentConfig built from the group's settings, or None when the file
        does not exist, cannot be parsed as TOML, or has no such agent table.
    """
    # keep only the name, just in case: drop surrounding brackets first, then
    # the optional "agent." prefix
    agent_config_arg = agent_config_arg.strip('[]').removeprefix('agent.')

    logger.openhands_logger.debug(f'Loading agent config from {agent_config_arg}')

    # load the toml file
    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            toml_config = toml.load(toml_contents)
    except FileNotFoundError as e:
        logger.openhands_logger.error(f'Config file not found: {e}')
        return None
    except toml.TomlDecodeError as e:
        logger.openhands_logger.error(
            f'Cannot parse agent group from {agent_config_arg}. Exception: {e}'
        )
        return None

    # Only a TOML table can be expanded into AgentConfig keyword arguments.
    # A scalar `agent = ...` or a scalar entry under [agent] that happens to
    # share the requested name is treated as "group not found" rather than
    # crashing with a TypeError on the ** expansion.
    agent_tables = toml_config.get('agent')
    group_settings = (
        agent_tables.get(agent_config_arg) if isinstance(agent_tables, dict) else None
    )
    if isinstance(group_settings, dict):
        return AgentConfig(**group_settings)

    logger.openhands_logger.debug(f'Loading from toml failed for {agent_config_arg}')
    return None
|
||||
|
||||
|
||||
def get_llm_config_arg(
|
||||
llm_config_arg: str, toml_file: str = 'config.toml'
|
||||
) -> LLMConfig | None:
|
||||
@@ -443,6 +495,12 @@ def get_parser() -> argparse.ArgumentParser:
|
||||
type=str,
|
||||
help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--agent-config',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Replace default Agent ([agent] section in config.toml) config with the specified Agent config, e.g. "CodeAct" for [agent.CodeAct] section in config.toml',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--name',
|
||||
|
||||
@@ -128,6 +128,7 @@ def test_help_message(capsys):
|
||||
'--eval-note EVAL_NOTE',
|
||||
'--eval-ids EVAL_IDS',
|
||||
'-l LLM_CONFIG, --llm-config LLM_CONFIG',
|
||||
'--agent-config AGENT_CONFIG',
|
||||
'-n NAME, --name NAME',
|
||||
'--config-file CONFIG_FILE',
|
||||
'--no-auto-continue',
|
||||
@@ -137,4 +138,4 @@ def test_help_message(capsys):
|
||||
assert element in help_output, f"Expected '{element}' to be in the help message"
|
||||
|
||||
option_count = help_output.count(' -')
|
||||
assert option_count == 17, f'Expected 17 options, found {option_count}'
|
||||
assert option_count == 18, f'Expected 18 options, found {option_count}'
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.core.config import (
|
||||
AppConfig,
|
||||
LLMConfig,
|
||||
finalize_config,
|
||||
get_agent_config_arg,
|
||||
get_llm_config_arg,
|
||||
load_from_env,
|
||||
load_from_toml,
|
||||
@@ -781,3 +782,32 @@ memory_max_threads = 10
|
||||
assert codeact_config.memory_enabled is True
|
||||
browsing_config = default_config.get_agent_configs().get('BrowsingAgent')
|
||||
assert browsing_config.memory_max_threads == 10
|
||||
|
||||
|
||||
def test_get_agent_config_arg(temp_toml_file):
    """Each named [agent.<name>] group is loaded into its own AgentConfig."""
    toml_text = """
[core]
max_iterations = 100
max_budget_per_task = 4.0

[agent.CodeActAgent]
memory_enabled = true
enable_prompt_extensions = false

[agent.BrowsingAgent]
memory_enabled = false
enable_prompt_extensions = true
memory_max_threads = 10
"""

    with open(temp_toml_file, 'w') as toml_out:
        toml_out.write(toml_text)

    # The CodeActAgent group carries its own two flags.
    codeact_cfg = get_agent_config_arg('CodeActAgent', temp_toml_file)
    assert codeact_cfg.memory_enabled
    assert not codeact_cfg.enable_prompt_extensions

    # The BrowsingAgent group is read independently of the first one.
    browsing_cfg = get_agent_config_arg('BrowsingAgent', temp_toml_file)
    assert not browsing_cfg.memory_enabled
    assert browsing_cfg.enable_prompt_extensions
    assert browsing_cfg.memory_max_threads == 10
|
||||
|
||||
Reference in New Issue
Block a user