mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix: Allow evaluation benchmarks to pass image urls in run_controller() instead of simply passing strings (#4100)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
committed by
GitHub
parent
9c07370559
commit
0809d26f4d
@@ -22,6 +22,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
game = None
|
||||
|
||||
@@ -122,7 +123,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
|
||||
@@ -217,7 +217,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
@@ -211,7 +211,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
@@ -285,7 +285,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
|
||||
@@ -409,7 +409,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
|
||||
@@ -23,6 +23,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
# Only CodeActAgent can delegate to BrowsingAgent
|
||||
SUPPORTED_AGENT_CLS = {'CodeActAgent'}
|
||||
@@ -76,7 +77,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -148,7 +148,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
|
||||
@@ -24,6 +24,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
@@ -83,7 +84,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
|
||||
@@ -219,7 +219,7 @@ Ok now its time to start solving the question. Good luck!
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
|
||||
@@ -35,7 +35,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
@@ -237,7 +237,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
|
||||
@@ -211,7 +211,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
|
||||
@@ -128,11 +128,12 @@ def process_instance(
|
||||
|
||||
runtime = create_runtime(config, sid=env_id)
|
||||
task_str = initialize_runtime(runtime)
|
||||
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=task_str, # take output from initialize_runtime
|
||||
initial_user_action=MessageAction(
|
||||
content=task_str
|
||||
), # take output from initialize_runtime
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
@@ -180,7 +181,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=fake_user_response_fn,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
@@ -242,7 +242,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
|
||||
metadata.agent_class
|
||||
|
||||
@@ -29,7 +29,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation, ErrorObservation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.runtime.runtime import Runtime
|
||||
@@ -365,6 +365,7 @@ def process_instance(
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
|
||||
try:
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
@@ -374,7 +375,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
|
||||
@@ -23,7 +23,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
@@ -109,7 +109,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
|
||||
@@ -148,7 +148,7 @@ def process_instance(
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=task_str,
|
||||
initial_user_action=MessageAction(content=task_str),
|
||||
runtime=runtime,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -83,7 +83,7 @@ def create_runtime(
|
||||
|
||||
async def run_controller(
|
||||
config: AppConfig,
|
||||
task_str: str,
|
||||
initial_user_action: Action,
|
||||
sid: str | None = None,
|
||||
runtime: Runtime | None = None,
|
||||
agent: Agent | None = None,
|
||||
@@ -96,7 +96,7 @@ async def run_controller(
|
||||
|
||||
Args:
|
||||
config: The app config.
|
||||
task_str: The task to run. It can be a string.
|
||||
initial_user_action: An Action object containing initial user input
|
||||
runtime: (optional) A runtime for the agent to run on.
|
||||
agent: (optional) A agent to run.
|
||||
exit_on_message: quit if agent asks for a message from user (optional)
|
||||
@@ -146,11 +146,13 @@ async def run_controller(
|
||||
if controller is not None:
|
||||
controller.agent_task = asyncio.create_task(controller.start_step_loop())
|
||||
|
||||
assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
|
||||
assert isinstance(
|
||||
initial_user_action, Action
|
||||
), f'initial user actions must be an Action, got {type(initial_user_action)}'
|
||||
# Logging
|
||||
logger.info(
|
||||
f'Agent Controller Initialized: Running agent {agent.name}, model '
|
||||
f'{agent.llm.config.model}, with task: "{task_str}"'
|
||||
f'{agent.llm.config.model}, with actions: {initial_user_action}'
|
||||
)
|
||||
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
@@ -166,8 +168,8 @@ async def run_controller(
|
||||
EventSource.USER,
|
||||
)
|
||||
elif initial_state is None:
|
||||
# init with the provided task
|
||||
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
|
||||
# init with the provided actions
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
@@ -224,7 +226,7 @@ if __name__ == '__main__':
|
||||
task_str = read_task_from_stdin()
|
||||
else:
|
||||
raise ValueError('No task provided. Please specify a task through -t, -f.')
|
||||
|
||||
initial_user_action: MessageAction = MessageAction(content=task_str)
|
||||
# Load the app config
|
||||
# this will load config from config.toml in the current directory
|
||||
# as well as from the environment variables
|
||||
@@ -253,7 +255,7 @@ if __name__ == '__main__':
|
||||
asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=task_str,
|
||||
initial_user_action=initial_user_action,
|
||||
sid=sid,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -9,10 +9,7 @@ from openhands.controller.state.state import State
|
||||
from openhands.core.config import load_app_config
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
)
|
||||
from openhands.events.action import AgentFinishAction, AgentRejectAction, MessageAction
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@@ -90,7 +87,7 @@ def test_write_simple_script(current_test_name: str) -> None:
|
||||
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
|
||||
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
@@ -136,7 +133,7 @@ def test_edits(current_test_name: str):
|
||||
# Execute the task
|
||||
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
@@ -160,7 +157,7 @@ def test_ipython(current_test_name: str):
|
||||
# Execute the task
|
||||
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
@@ -185,7 +182,7 @@ def test_simple_task_rejection(current_test_name: str):
|
||||
# the workspace is not a git repo
|
||||
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
|
||||
@@ -200,7 +197,7 @@ def test_ipython_module(current_test_name: str):
|
||||
# Execute the task
|
||||
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
@@ -226,7 +223,7 @@ def test_browse_internet(current_test_name: str):
|
||||
# Execute the task
|
||||
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
|
||||
final_state: State | None = asyncio.run(
|
||||
run_controller(CONFIG, task, exit_on_message=True)
|
||||
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user