mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor I/O utils; allow 'task' command line parameter in cli.py (#6187)
Co-authored-by: OpenHands Bot <openhands@all-hands.dev>
This commit is contained in:
parent
663e36109c
commit
eed7e2dd6e
@ -6,11 +6,11 @@ from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.utils import json
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.action import action_from_dict
|
||||
from openhands.events.serialization.event import event_to_memory
|
||||
from openhands.io import json
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
)
|
||||
from openhands.io import read_input, read_task
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
@ -82,21 +83,6 @@ def display_event(event: Event, config: AppConfig):
|
||||
display_confirmation(event.confirmation_state)
|
||||
|
||||
|
||||
def read_input(config: AppConfig) -> str:
|
||||
"""Read input from user based on config settings."""
|
||||
if config.cli_multiline_input:
|
||||
print('Enter your message (enter "/exit" on a new line to finish):')
|
||||
lines = []
|
||||
while True:
|
||||
line = input('>> ').rstrip()
|
||||
if line == '/exit': # finish input
|
||||
break
|
||||
lines.append(line)
|
||||
return '\n'.join(lines)
|
||||
else:
|
||||
return input('>> ').rstrip()
|
||||
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop):
|
||||
"""Runs the agent in CLI mode."""
|
||||
|
||||
@ -104,7 +90,14 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
config = setup_config_from_args(args)
|
||||
# Load config from toml and override with command line arguments
|
||||
config: AppConfig = setup_config_from_args(args)
|
||||
|
||||
# Read task from file, CLI args, or stdin
|
||||
task_str = read_task(args, config.cli_multiline_input)
|
||||
|
||||
# If we have a task, create initial user action
|
||||
initial_user_action = MessageAction(content=task_str) if task_str else None
|
||||
|
||||
sid = str(uuid4())
|
||||
|
||||
@ -117,7 +110,9 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
async def prompt_for_next_task():
|
||||
# Run input() in a thread pool to avoid blocking the event loop
|
||||
next_message = await loop.run_in_executor(None, read_input, config)
|
||||
next_message = await loop.run_in_executor(
|
||||
None, read_input, config.cli_multiline_input
|
||||
)
|
||||
if not next_message.strip():
|
||||
await prompt_for_next_task()
|
||||
if next_message == 'exit':
|
||||
@ -162,7 +157,12 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
await runtime.connect()
|
||||
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
if initial_user_action:
|
||||
# If there's an initial user action, enqueue it and do not prompt again
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
else:
|
||||
# Otherwise prompt for the user's first message right away
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Callable, Protocol
|
||||
|
||||
@ -29,6 +28,7 @@ from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
@ -41,32 +41,6 @@ class FakeUserResponseFunc(Protocol):
|
||||
) -> str: ...
|
||||
|
||||
|
||||
def read_task_from_file(file_path: str) -> str:
|
||||
"""Read task from the specified file."""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
def read_task_from_stdin() -> str:
|
||||
"""Read task from stdin."""
|
||||
return sys.stdin.read()
|
||||
|
||||
|
||||
def read_input(config: AppConfig) -> str:
|
||||
"""Read input from user based on config settings."""
|
||||
if config.cli_multiline_input:
|
||||
print('Enter your message (enter "/exit" on a new line to finish):')
|
||||
lines = []
|
||||
while True:
|
||||
line = input('>> ').rstrip()
|
||||
if line == '/exit': # finish input
|
||||
break
|
||||
lines.append(line)
|
||||
return '\n'.join(lines)
|
||||
else:
|
||||
return input('>> ').rstrip()
|
||||
|
||||
|
||||
async def run_controller(
|
||||
config: AppConfig,
|
||||
initial_user_action: Action,
|
||||
@ -139,7 +113,6 @@ async def run_controller(
|
||||
assert isinstance(
|
||||
initial_user_action, Action
|
||||
), f'initial user actions must be an Action, got {type(initial_user_action)}'
|
||||
# Logging
|
||||
logger.debug(
|
||||
f'Agent Controller Initialized: Running agent {agent.name}, model '
|
||||
f'{agent.llm.config.model}, with actions: {initial_user_action}'
|
||||
@ -167,7 +140,7 @@ async def run_controller(
|
||||
if exit_on_message:
|
||||
message = '/exit'
|
||||
elif fake_user_response_fn is None:
|
||||
message = read_input(config)
|
||||
message = read_input(config.cli_multiline_input)
|
||||
else:
|
||||
message = fake_user_response_fn(controller.get_state())
|
||||
action = MessageAction(content=message)
|
||||
@ -268,28 +241,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
config = setup_config_from_args(args)
|
||||
config: AppConfig = setup_config_from_args(args)
|
||||
|
||||
# Determine the task
|
||||
task_str = ''
|
||||
if args.file:
|
||||
task_str = read_task_from_file(args.file)
|
||||
elif args.task:
|
||||
task_str = args.task
|
||||
elif not sys.stdin.isatty():
|
||||
task_str = read_task_from_stdin()
|
||||
# Read task from file, CLI args, or stdin
|
||||
task_str = read_task(args, config.cli_multiline_input)
|
||||
|
||||
initial_user_action: Action = NullAction()
|
||||
if config.replay_trajectory_path:
|
||||
if task_str:
|
||||
raise ValueError(
|
||||
'User-specified task is not supported under trajectory replay mode'
|
||||
)
|
||||
elif task_str:
|
||||
initial_user_action = MessageAction(content=task_str)
|
||||
else:
|
||||
|
||||
if not task_str:
|
||||
raise ValueError('No task provided. Please specify a task through -t, -f.')
|
||||
|
||||
# Create initial user action
|
||||
initial_user_action: MessageAction = MessageAction(content=task_str)
|
||||
|
||||
# Set session name
|
||||
session_name = args.name
|
||||
sid = generate_sid(config, session_name)
|
||||
|
||||
@ -8,9 +8,9 @@ from functools import partial
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.utils import json
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.io import json
|
||||
from openhands.storage import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_dir,
|
||||
|
||||
10
openhands/io/__init__.py
Normal file
10
openhands/io/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
from openhands.io.io import read_input, read_task, read_task_from_file
|
||||
from openhands.io.json import dumps, loads
|
||||
|
||||
__all__ = [
|
||||
'read_input',
|
||||
'read_task_from_file',
|
||||
'read_task',
|
||||
'dumps',
|
||||
'loads',
|
||||
]
|
||||
40
openhands/io/io.py
Normal file
40
openhands/io/io.py
Normal file
@ -0,0 +1,40 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
def read_input(cli_multiline_input: bool = False) -> str:
|
||||
"""Read input from user based on config settings."""
|
||||
if cli_multiline_input:
|
||||
print('Enter your message (enter "/exit" on a new line to finish):')
|
||||
lines = []
|
||||
while True:
|
||||
line = input('>> ').rstrip()
|
||||
if line == '/exit': # finish input
|
||||
break
|
||||
lines.append(line)
|
||||
return '\n'.join(lines)
|
||||
else:
|
||||
return input('>> ').rstrip()
|
||||
|
||||
|
||||
def read_task_from_file(file_path: str) -> str:
|
||||
"""Read task from the specified file."""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
|
||||
"""
|
||||
Read the task from the CLI args, file, or stdin.
|
||||
"""
|
||||
|
||||
# Determine the task
|
||||
task_str = ''
|
||||
if args.file:
|
||||
task_str = read_task_from_file(args.file)
|
||||
elif args.task:
|
||||
task_str = args.task
|
||||
elif not sys.stdin.isatty():
|
||||
task_str = read_input(cli_multiline_input)
|
||||
|
||||
return task_str
|
||||
@ -172,7 +172,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.core.utils import json
|
||||
from openhands.io import json
|
||||
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
mock_function_calling = not self.is_function_calling_active()
|
||||
@ -369,7 +369,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
from openhands.core.utils import json
|
||||
from openhands.io import json
|
||||
|
||||
logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from openhands.core.cli import read_input
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.io import read_input
|
||||
|
||||
|
||||
def test_single_line_input():
|
||||
@ -10,7 +10,7 @@ def test_single_line_input():
|
||||
config.cli_multiline_input = False
|
||||
|
||||
with patch('builtins.input', return_value='hello world'):
|
||||
result = read_input(config)
|
||||
result = read_input(config.cli_multiline_input)
|
||||
assert result == 'hello world'
|
||||
|
||||
|
||||
@ -23,5 +23,5 @@ def test_multiline_input():
|
||||
mock_inputs = ['line 1', 'line 2', 'line 3', '/exit']
|
||||
|
||||
with patch('builtins.input', side_effect=mock_inputs):
|
||||
result = read_input(config)
|
||||
result = read_input(config.cli_multiline_input)
|
||||
assert result == 'line 1\nline 2\nline 3'
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from openhands.core.utils import json
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.io import json
|
||||
|
||||
|
||||
def test_event_serialization_deserialization():
|
||||
|
||||
@ -3,7 +3,7 @@ from datetime import datetime
|
||||
|
||||
import psutil
|
||||
|
||||
from openhands.core.utils.json import dumps
|
||||
from openhands.io.json import dumps
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
|
||||
@ -2,11 +2,11 @@ import pytest
|
||||
|
||||
from openhands.agenthub.micro.agent import parse_response as parse_response_micro
|
||||
from openhands.core.exceptions import LLMResponseError
|
||||
from openhands.core.utils.json import loads as custom_loads
|
||||
from openhands.events.action import (
|
||||
FileWriteAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.io import loads as custom_loads
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user