fix: Allow evaluation benchmarks to pass image URLs in run_controller() instead of simply passing strings (#4100)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
Aditya Bharat Soni
2024-10-07 15:37:08 -04:00
committed by GitHub
parent 9c07370559
commit 0809d26f4d
19 changed files with 47 additions and 42 deletions

View File

@@ -22,6 +22,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
game = None
@@ -122,7 +123,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class

View File

@@ -217,7 +217,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
)

View File

@@ -30,7 +30,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -211,7 +211,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
)

View File

@@ -27,7 +27,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -285,7 +285,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class

View File

@@ -409,7 +409,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],

View File

@@ -23,6 +23,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
# Only CodeActAgent can delegate to BrowsingAgent
SUPPORTED_AGENT_CLS = {'CodeActAgent'}
@@ -76,7 +77,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
)
)

View File

@@ -148,7 +148,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class

View File

@@ -24,6 +24,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
@@ -83,7 +84,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class

View File

@@ -219,7 +219,7 @@ Ok now its time to start solving the question. Good luck!
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class

View File

@@ -35,7 +35,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -237,7 +237,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class

View File

@@ -211,7 +211,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class

View File

@@ -128,11 +128,12 @@ def process_instance(
runtime = create_runtime(config, sid=env_id)
task_str = initialize_runtime(runtime)
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=task_str, # take output from initialize_runtime
initial_user_action=MessageAction(
content=task_str
), # take output from initialize_runtime
runtime=runtime,
)
)

View File

@@ -29,6 +29,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import (
CmdRunAction,
MessageAction,
)
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -180,7 +181,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=fake_user_response_fn,
)

View File

@@ -39,7 +39,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -242,7 +242,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class

View File

@@ -29,7 +29,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.runtime import Runtime
@@ -365,6 +365,7 @@ def process_instance(
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
runtime = create_runtime(config, sid=instance.instance_id)
try:
initialize_runtime(runtime, instance)
@@ -374,7 +375,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class

View File

@@ -23,7 +23,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
@@ -109,7 +109,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class

View File

@@ -148,7 +148,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=task_str,
initial_user_action=MessageAction(content=task_str),
runtime=runtime,
)
)

View File

@@ -83,7 +83,7 @@ def create_runtime(
async def run_controller(
config: AppConfig,
task_str: str,
initial_user_action: Action,
sid: str | None = None,
runtime: Runtime | None = None,
agent: Agent | None = None,
@@ -96,7 +96,7 @@ async def run_controller(
Args:
config: The app config.
task_str: The task to run. It can be a string.
initial_user_action: An Action object containing initial user input
runtime: (optional) A runtime for the agent to run on.
agent: (optional) An agent to run.
exit_on_message: quit if agent asks for a message from user (optional)
@@ -146,11 +146,13 @@ async def run_controller(
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())
assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
# Logging
logger.info(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with task: "{task_str}"'
f'{agent.llm.config.model}, with actions: {initial_user_action}'
)
# start event is a MessageAction with the task, either resumed or new
@@ -166,8 +168,8 @@ async def run_controller(
EventSource.USER,
)
elif initial_state is None:
# init with the provided task
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
# init with the provided initial user action
event_stream.add_event(initial_user_action, EventSource.USER)
async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
@@ -224,7 +226,7 @@ if __name__ == '__main__':
task_str = read_task_from_stdin()
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')
initial_user_action: MessageAction = MessageAction(content=task_str)
# Load the app config
# this will load config from config.toml in the current directory
# as well as from the environment variables
@@ -253,7 +255,7 @@ if __name__ == '__main__':
asyncio.run(
run_controller(
config=config,
task_str=task_str,
initial_user_action=initial_user_action,
sid=sid,
)
)

View File

@@ -9,10 +9,7 @@ from openhands.controller.state.state import State
from openhands.core.config import load_app_config
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events.action import (
AgentFinishAction,
AgentRejectAction,
)
from openhands.events.action import AgentFinishAction, AgentRejectAction, MessageAction
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.runtime import get_runtime_cls
@@ -90,7 +87,7 @@ def test_write_simple_script(current_test_name: str) -> None:
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
@@ -136,7 +133,7 @@ def test_edits(current_test_name: str):
# Execute the task
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
@@ -160,7 +157,7 @@ def test_ipython(current_test_name: str):
# Execute the task
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
@@ -185,7 +182,7 @@ def test_simple_task_rejection(current_test_name: str):
# the workspace is not a git repo
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
@@ -200,7 +197,7 @@ def test_ipython_module(current_test_name: str):
# Execute the task
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
@@ -226,7 +223,7 @@ def test_browse_internet(current_test_name: str):
# Execute the task
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)