Refactor I/O utils; allow 'task' command line parameter in cli.py (#6187)

Co-authored-by: OpenHands Bot <openhands@all-hands.dev>
This commit is contained in:
Engel Nyst 2025-02-19 22:10:14 +01:00 committed by GitHub
parent 663e36109c
commit eed7e2dd6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 88 additions and 70 deletions

View File

@ -6,11 +6,11 @@ from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.utils import json
from openhands.events.action import Action
from openhands.events.event import Event
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.event import event_to_memory
from openhands.io import json
from openhands.llm.llm import LLM

View File

@ -30,6 +30,7 @@ from openhands.events.observation import (
CmdOutputObservation,
FileEditObservation,
)
from openhands.io import read_input, read_task
def display_message(message: str):
@ -82,21 +83,6 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)
def read_input(config: AppConfig) -> str:
"""Read input from user based on config settings."""
if config.cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()
async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode."""
@ -104,7 +90,14 @@ async def main(loop: asyncio.AbstractEventLoop):
logger.setLevel(logging.WARNING)
config = setup_config_from_args(args)
# Load config from toml and override with command line arguments
config: AppConfig = setup_config_from_args(args)
# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)
# If we have a task, create initial user action
initial_user_action = MessageAction(content=task_str) if task_str else None
sid = str(uuid4())
@ -117,7 +110,9 @@ async def main(loop: asyncio.AbstractEventLoop):
async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(None, read_input, config)
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
@ -162,7 +157,12 @@ async def main(loop: asyncio.AbstractEventLoop):
await runtime.connect()
asyncio.create_task(prompt_for_next_task())
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task())
await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]

View File

@ -1,7 +1,6 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol
@ -29,6 +28,7 @@ from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.io import read_input, read_task
from openhands.runtime.base import Runtime
@ -41,32 +41,6 @@ class FakeUserResponseFunc(Protocol):
) -> str: ...
def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
def read_task_from_stdin() -> str:
"""Read task from stdin."""
return sys.stdin.read()
def read_input(config: AppConfig) -> str:
"""Read input from user based on config settings."""
if config.cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()
async def run_controller(
config: AppConfig,
initial_user_action: Action,
@ -139,7 +113,6 @@ async def run_controller(
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
# Logging
logger.debug(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with actions: {initial_user_action}'
@ -167,7 +140,7 @@ async def run_controller(
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = read_input(config)
message = read_input(config.cli_multiline_input)
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
@ -268,28 +241,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
if __name__ == '__main__':
args = parse_arguments()
config = setup_config_from_args(args)
config: AppConfig = setup_config_from_args(args)
# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)
initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:
if not task_str:
raise ValueError('No task provided. Please specify a task through -t, -f.')
# Create initial user action
initial_user_action: MessageAction = MessageAction(content=task_str)
# Set session name
session_name = args.name
sid = generate_sid(config, session_name)

View File

@ -8,9 +8,9 @@ from functools import partial
from typing import Callable, Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.io import json
from openhands.storage import FileStore
from openhands.storage.locations import (
get_conversation_dir,

10
openhands/io/__init__.py Normal file
View File

@ -0,0 +1,10 @@
from openhands.io.io import read_input, read_task, read_task_from_file
from openhands.io.json import dumps, loads
__all__ = [
'read_input',
'read_task_from_file',
'read_task',
'dumps',
'loads',
]

40
openhands/io/io.py Normal file
View File

@ -0,0 +1,40 @@
import argparse
import sys
def read_input(cli_multiline_input: bool = False) -> str:
"""Read input from user based on config settings."""
if cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()
def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
"""
Read the task from the CLI args, file, or stdin.
"""
# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_input(cli_multiline_input)
return task_str

View File

@ -172,7 +172,7 @@ class LLM(RetryMixin, DebugMixin):
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.core.utils import json
from openhands.io import json
messages: list[dict[str, Any]] | dict[str, Any] = []
mock_function_calling = not self.is_function_calling_active()
@ -369,7 +369,7 @@ class LLM(RetryMixin, DebugMixin):
# noinspection PyBroadException
except Exception:
pass
from openhands.core.utils import json
from openhands.io import json
logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')

View File

@ -1,7 +1,7 @@
from unittest.mock import patch
from openhands.core.cli import read_input
from openhands.core.config import AppConfig
from openhands.io import read_input
def test_single_line_input():
@ -10,7 +10,7 @@ def test_single_line_input():
config.cli_multiline_input = False
with patch('builtins.input', return_value='hello world'):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'hello world'
@ -23,5 +23,5 @@ def test_multiline_input():
mock_inputs = ['line 1', 'line 2', 'line 3', '/exit']
with patch('builtins.input', side_effect=mock_inputs):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'line 1\nline 2\nline 3'

View File

@ -1,7 +1,7 @@
from datetime import datetime
from openhands.core.utils import json
from openhands.events.action import MessageAction
from openhands.io import json
def test_event_serialization_deserialization():

View File

@ -3,7 +3,7 @@ from datetime import datetime
import psutil
from openhands.core.utils.json import dumps
from openhands.io.json import dumps
def get_memory_usage():

View File

@ -2,11 +2,11 @@ import pytest
from openhands.agenthub.micro.agent import parse_response as parse_response_micro
from openhands.core.exceptions import LLMResponseError
from openhands.core.utils.json import loads as custom_loads
from openhands.events.action import (
FileWriteAction,
MessageAction,
)
from openhands.io import loads as custom_loads
@pytest.mark.parametrize(