mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[GAIA] Add prompt improvement to alleviate solution parsing issue & support Tavily search tools (#9057)
This commit is contained in:
parent
e6e0f4673f
commit
ddaa186971
1
evaluation/benchmarks/gaia/.gitignore
vendored
Normal file
1
evaluation/benchmarks/gaia/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
data/
|
||||
@ -6,6 +6,13 @@ This folder contains evaluation harness for evaluating agents on the [GAIA bench
|
||||
|
||||
Please follow instruction [here](../../README.md#setup) to setup your local development environment and LLM.
|
||||
|
||||
To enable the Tavily MCP Server, you can add the Tavily API key under the `core` section of your `config.toml` file, like below:
|
||||
|
||||
```toml
|
||||
[core]
|
||||
search_api_key = "tvly-******"
|
||||
```
|
||||
|
||||
## Run the evaluation
|
||||
|
||||
We are using the GAIA dataset hosted on [Hugging Face](https://huggingface.co/datasets/gaia-benchmark/GAIA).
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
@ -6,6 +7,7 @@ import re
|
||||
import huggingface_hub
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from pydantic import SecretStr
|
||||
|
||||
from evaluation.benchmarks.gaia.scorer import question_scorer
|
||||
from evaluation.utils.shared import (
|
||||
@ -24,6 +26,7 @@ from openhands.core.config import (
|
||||
OpenHandsConfig,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
load_from_toml,
|
||||
)
|
||||
from openhands.core.config.utils import get_agent_config_arg
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -41,7 +44,7 @@ AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
}
|
||||
|
||||
AGENT_CLS_TO_INST_SUFFIX = {
|
||||
'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
|
||||
'CodeActAgent': 'When you think you have solved the question, please use the finish tool and include your final answer in the message parameter of the finish tool. Your final answer MUST be encapsulated within <solution> and </solution>.\n'
|
||||
}
|
||||
|
||||
|
||||
@ -49,7 +52,7 @@ def get_config(
|
||||
metadata: EvalMetadata,
|
||||
) -> OpenHandsConfig:
|
||||
sandbox_config = get_default_sandbox_config_for_eval()
|
||||
sandbox_config.base_container_image = 'python:3.12-bookworm'
|
||||
sandbox_config.base_container_image = 'nikolaik/python-nodejs:python3.12-nodejs22'
|
||||
config = OpenHandsConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
@ -67,6 +70,11 @@ def get_config(
|
||||
logger.info('Agent config not provided, using default settings')
|
||||
agent_config = config.get_agent_config(metadata.agent_class)
|
||||
agent_config.enable_prompt_extensions = False
|
||||
|
||||
config_copy = copy.deepcopy(config)
|
||||
load_from_toml(config_copy)
|
||||
if config_copy.search_api_key:
|
||||
config.search_api_key = SecretStr(config_copy.search_api_key)
|
||||
return config
|
||||
|
||||
|
||||
@ -134,16 +142,26 @@ def process_instance(
|
||||
dest_file = None
|
||||
|
||||
# Prepare instruction
|
||||
instruction = f'{instance["Question"]}\n'
|
||||
instruction = """You have one question to answer. It is paramount that you provide a correct answer.
|
||||
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist). Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
|
||||
You must make sure you find the correct answer! You MUST strictly follow the task-specific formatting instructions for your final answer.
|
||||
Here is the task:
|
||||
{task_question}
|
||||
""".format(
|
||||
task_question=instance['Question'],
|
||||
)
|
||||
logger.info(f'Instruction: {instruction}')
|
||||
if dest_file:
|
||||
instruction += f'\n\nThe mentioned file is provided in the workspace at: {dest_file.split("/")[-1]}'
|
||||
|
||||
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
|
||||
instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
|
||||
instruction += """IMPORTANT: When seeking information from a website, REFRAIN from arbitrary URL navigation. You should utilize the designated search engine tool with precise keywords to obtain relevant URLs or use the specific website's search interface. DO NOT navigate directly to specific URLs as they may not exist.\n\nFor example: if you want to search for a research paper on Arxiv, either use the search engine tool with specific keywords or navigate to arxiv.org and then use its interface.\n"""
|
||||
instruction += 'IMPORTANT: You should NEVER ask for Human Help.\n'
|
||||
instruction += 'IMPORTANT: Please encapsulate your final answer (answer ONLY) within <solution> and </solution>. Your answer will be evaluated using string matching approaches so it important that you STRICTLY adhere to the output formatting instructions specified in the task (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)\n'
|
||||
instruction += (
|
||||
'For example: The answer to the question is <solution> 42 </solution>.\n'
|
||||
)
|
||||
instruction += "IMPORTANT: Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, express it numerically (i.e., with digits rather than words), do not use commas, and do not include units such as $ or percent signs unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities). If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
|
||||
|
||||
# NOTE: You can actually set slightly different instruction for different agents
|
||||
instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
|
||||
logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
|
||||
@ -175,7 +193,7 @@ def process_instance(
|
||||
for event in reversed(state.history):
|
||||
if event.source == 'agent':
|
||||
if isinstance(event, AgentFinishAction):
|
||||
model_answer_raw = event.thought
|
||||
model_answer_raw = event.final_thought
|
||||
break
|
||||
elif isinstance(event, CmdRunAction):
|
||||
model_answer_raw = event.thought
|
||||
@ -222,6 +240,7 @@ def process_instance(
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
test_result=test_result,
|
||||
)
|
||||
runtime.close()
|
||||
return output
|
||||
|
||||
|
||||
@ -253,6 +272,8 @@ if __name__ == '__main__':
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
toml_config = OpenHandsConfig()
|
||||
load_from_toml(toml_config)
|
||||
metadata = make_metadata(
|
||||
llm_config=llm_config,
|
||||
dataset_name='gaia',
|
||||
@ -261,7 +282,10 @@ if __name__ == '__main__':
|
||||
eval_note=args.eval_note,
|
||||
eval_output_dir=args.eval_output_dir,
|
||||
data_split=args.data_split,
|
||||
details={'gaia-level': args.level},
|
||||
details={
|
||||
'gaia-level': args.level,
|
||||
'mcp-servers': ['tavily'] if toml_config.search_api_key else [],
|
||||
},
|
||||
agent_config=agent_config,
|
||||
)
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ echo "LEVELS: $LEVELS"
|
||||
COMMAND="poetry run python ./evaluation/benchmarks/gaia/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--max-iterations 30 \
|
||||
--max-iterations 60 \
|
||||
--level $LEVELS \
|
||||
--data-split validation \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
|
||||
@ -273,9 +273,9 @@ async def run_session(
|
||||
)
|
||||
)
|
||||
|
||||
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
|
||||
runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
|
||||
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory, config)
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory)
|
||||
|
||||
# Clear loading animation
|
||||
is_loaded.set()
|
||||
|
||||
@ -139,9 +139,9 @@ async def run_controller(
|
||||
config.mcp_host, config, None
|
||||
)
|
||||
)
|
||||
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
|
||||
runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
|
||||
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory, config)
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory)
|
||||
|
||||
replay_events: list[Event] | None = None
|
||||
if config.replay_trajectory_path:
|
||||
|
||||
@ -10,7 +10,6 @@ from openhands.core.config.mcp_config import (
|
||||
MCPSHTTPServerConfig,
|
||||
MCPSSEServerConfig,
|
||||
)
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
@ -187,9 +186,7 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
|
||||
)
|
||||
|
||||
|
||||
async def add_mcp_tools_to_agent(
|
||||
agent: 'Agent', runtime: Runtime, memory: 'Memory', app_config: OpenHandsConfig
|
||||
):
|
||||
async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memory'):
|
||||
"""
|
||||
Add MCP tools to an agent.
|
||||
"""
|
||||
@ -208,7 +205,6 @@ async def add_mcp_tools_to_agent(
|
||||
extra_stdio_servers = []
|
||||
|
||||
# Add microagent MCP tools if available
|
||||
mcp_config: MCPConfig = app_config.mcp
|
||||
microagent_mcp_configs = memory.get_microagent_mcp_tools()
|
||||
for mcp_config in microagent_mcp_configs:
|
||||
if mcp_config.sse_servers:
|
||||
|
||||
@ -158,7 +158,7 @@ class AgentSession:
|
||||
# NOTE: this needs to happen before controller is created
|
||||
# so MCP tools can be included into the SystemMessageAction
|
||||
if self.runtime and runtime_connected and agent.config.enable_mcp:
|
||||
await add_mcp_tools_to_agent(agent, self.runtime, self.memory, config)
|
||||
await add_mcp_tools_to_agent(agent, self.runtime, self.memory)
|
||||
|
||||
if replay_json:
|
||||
initial_message = self._run_replay(
|
||||
|
||||
@ -385,7 +385,6 @@ async def test_add_mcp_tools_from_microagents():
|
||||
"""Test that add_mcp_tools_to_agent adds tools from microagents."""
|
||||
# Import ActionExecutionClient for mocking
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
@ -394,10 +393,6 @@ async def test_add_mcp_tools_from_microagents():
|
||||
mock_agent = MagicMock()
|
||||
mock_runtime = MagicMock(spec=ActionExecutionClient)
|
||||
mock_memory = MagicMock()
|
||||
mock_mcp_config = MCPConfig()
|
||||
|
||||
# Create a mock OpenHandsConfig with the MCP config
|
||||
mock_app_config = OpenHandsConfig(mcp=mock_mcp_config, search_api_key=None)
|
||||
|
||||
# Configure the mock memory to return a microagent MCP config
|
||||
mock_stdio_server = MCPStdioServerConfig(
|
||||
@ -425,9 +420,7 @@ async def test_add_mcp_tools_from_microagents():
|
||||
new=AsyncMock(return_value=[mock_tool]),
|
||||
):
|
||||
# Call the function with the OpenHandsConfig instead of MCPConfig
|
||||
await add_mcp_tools_to_agent(
|
||||
mock_agent, mock_runtime, mock_memory, mock_app_config
|
||||
)
|
||||
await add_mcp_tools_to_agent(mock_agent, mock_runtime, mock_memory)
|
||||
|
||||
# Verify that the memory's get_microagent_mcp_tools was called
|
||||
mock_memory.get_microagent_mcp_tools.assert_called_once()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -259,6 +260,7 @@ async def test_run_controller_with_fatal_error(
|
||||
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
runtime.config = copy.deepcopy(config)
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
@ -326,6 +328,7 @@ async def test_run_controller_stop_with_stuck(
|
||||
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
runtime.config = copy.deepcopy(config)
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
@ -762,6 +765,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
runtime.config = copy.deepcopy(config)
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
@ -883,7 +887,9 @@ async def test_context_window_exceeded_error_handling(
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
config = OpenHandsConfig(max_iterations=max_iterations)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
|
||||
# Now we can run the controller for a fixed number of steps. Since the step
|
||||
# state is set to error out before then, if this terminates and we have a
|
||||
@ -891,7 +897,7 @@ async def test_context_window_exceeded_error_handling(
|
||||
# handles the truncation correctly.
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=OpenHandsConfig(max_iterations=max_iterations),
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
@ -1027,11 +1033,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
config = OpenHandsConfig(max_iterations=5)
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=OpenHandsConfig(max_iterations=5),
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
@ -1104,10 +1112,12 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
config = OpenHandsConfig(max_iterations=3)
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=OpenHandsConfig(max_iterations=3),
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
@ -1167,6 +1177,7 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
runtime.event_stream = event_stream
|
||||
runtime.config = copy.deepcopy(config)
|
||||
|
||||
# Create a real Memory instance
|
||||
memory = Memory(event_stream=event_stream, sid='test-memory')
|
||||
|
||||
@ -208,9 +208,7 @@ async def test_run_session_without_initial_action(
|
||||
mock_display_runtime_init.assert_called_once_with('local')
|
||||
mock_display_animation.assert_called_once()
|
||||
mock_create_agent.assert_called_once_with(mock_config)
|
||||
mock_add_mcp_tools.assert_called_once_with(
|
||||
mock_agent, mock_runtime, mock_memory, mock_config
|
||||
)
|
||||
mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_create_controller.assert_called_once()
|
||||
mock_create_memory.assert_called_once()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user