mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
feat: Add basic support for prompt-toolkit in the CLI (#7709)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Bashwara Undupitiya <bashwarau@verdentra.com>
This commit is contained in:
@@ -3,7 +3,9 @@ import logging
|
||||
import sys
|
||||
from uuid import uuid4
|
||||
|
||||
from termcolor import colored
|
||||
from prompt_toolkit import PromptSession, print_formatted_text
|
||||
from prompt_toolkit.formatted_text import FormattedText
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.core.config import (
|
||||
@@ -36,24 +38,66 @@ from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
)
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.io import read_task
|
||||
|
||||
prompt_session = PromptSession()
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
print(colored('🤖 ' + message + '\n', 'yellow'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('ansiyellow', '🤖 '),
|
||||
('ansiyellow', message),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def display_command(command: str):
|
||||
print('❯ ' + colored(command + '\n', 'green'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('', '❯ '),
|
||||
('ansigreen', command),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def display_confirmation(confirmation_state: ActionConfirmationStatus):
|
||||
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
|
||||
print(colored('✅ ' + confirmation_state + '\n', 'green'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('ansigreen', '✅ '),
|
||||
('ansigreen', str(confirmation_state)),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
elif confirmation_state == ActionConfirmationStatus.REJECTED:
|
||||
print(colored('❌ ' + confirmation_state + '\n', 'red'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('ansired', '❌ '),
|
||||
('ansired', str(confirmation_state)),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('ansiyellow', '⏳ '),
|
||||
('ansiyellow', str(confirmation_state)),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def display_command_output(output: str):
|
||||
@@ -62,12 +106,19 @@ def display_command_output(output: str):
|
||||
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
|
||||
# TODO: clean this up once we clean up terminal output
|
||||
continue
|
||||
print(colored(line, 'blue'))
|
||||
print('\n')
|
||||
print_formatted_text(FormattedText([('ansiblue', line)]))
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
def display_file_edit(event: FileEditAction | FileEditObservation):
|
||||
print(colored(str(event), 'green'))
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
('ansigreen', str(event)),
|
||||
('', '\n'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def display_event(event: Event, config: AppConfig):
|
||||
@@ -89,6 +140,41 @@ def display_event(event: Event, config: AppConfig):
|
||||
display_confirmation(event.confirmation_state)
|
||||
|
||||
|
||||
async def read_prompt_input(multiline=False):
|
||||
try:
|
||||
if multiline:
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add('c-d')
|
||||
def _(event):
|
||||
event.current_buffer.validate_and_handle()
|
||||
|
||||
message = await prompt_session.prompt_async(
|
||||
'Enter your message and press Ctrl+D to finish:\n',
|
||||
multiline=True,
|
||||
key_bindings=kb,
|
||||
)
|
||||
else:
|
||||
message = await prompt_session.prompt_async(
|
||||
'>> ',
|
||||
)
|
||||
return message
|
||||
except KeyboardInterrupt:
|
||||
return 'exit'
|
||||
except EOFError:
|
||||
return 'exit'
|
||||
|
||||
|
||||
async def read_confirmation_input():
|
||||
try:
|
||||
confirmation = await prompt_session.prompt_async(
|
||||
'Confirm action (possible security risk)? (y/n) >> ',
|
||||
)
|
||||
return confirmation.lower() == 'y'
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop):
|
||||
"""Runs the agent in CLI mode."""
|
||||
|
||||
@@ -122,10 +208,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
async def prompt_for_next_task():
|
||||
# Run input() in a thread pool to avoid blocking the event loop
|
||||
next_message = await loop.run_in_executor(
|
||||
None, read_input, config.cli_multiline_input
|
||||
)
|
||||
next_message = await read_prompt_input(config.cli_multiline_input)
|
||||
if not next_message.strip():
|
||||
await prompt_for_next_task()
|
||||
if next_message == 'exit':
|
||||
@@ -136,12 +219,6 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
action = MessageAction(content=next_message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
async def prompt_for_user_confirmation():
|
||||
user_confirmation = await loop.run_in_executor(
|
||||
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
|
||||
)
|
||||
return user_confirmation.lower() == 'y'
|
||||
|
||||
async def on_event_async(event: Event):
|
||||
display_event(event, config)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
@@ -151,7 +228,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
]:
|
||||
await prompt_for_next_task()
|
||||
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
user_confirmed = await prompt_for_user_confirmation()
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
|
||||
8
poetry.lock
generated
8
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -2661,7 +2661,7 @@ grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_versi
|
||||
grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
|
||||
requests = ">=2.18.0,<3.0.0.dev0"
|
||||
@@ -9212,7 +9212,7 @@ description = "A language and compiler for custom Deep Learning operations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["evaluation"]
|
||||
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""
|
||||
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version == \"3.12\""
|
||||
files = [
|
||||
{file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"},
|
||||
{file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"},
|
||||
@@ -10193,4 +10193,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "2072ef82a79cc7be1652af1b33fb9814d85e1c6ad4ed410e12f2e57e0fadbd58"
|
||||
content-hash = "bd1e164559f6395718bc76482493aa71327e7aee9ebe6131facba65661d9abec"
|
||||
|
||||
@@ -79,6 +79,7 @@ memory-profiler = "^0.61.0"
|
||||
daytona-sdk = "0.12.1"
|
||||
python-json-logger = "^3.2.1"
|
||||
playwright = "^1.51.0"
|
||||
prompt-toolkit = "^3.0.50"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.11.4"
|
||||
|
||||
169
tests/unit/test_cli_basic.py
Normal file
169
tests/unit/test_cli_basic.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.application import create_app_session
|
||||
from prompt_toolkit.input import create_pipe_input
|
||||
from prompt_toolkit.output import create_output
|
||||
|
||||
from openhands.core.cli import main
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
|
||||
|
||||
class MockEventStream:
|
||||
def __init__(self):
|
||||
self._subscribers = {}
|
||||
self.cur_id = 0
|
||||
|
||||
def subscribe(self, subscriber_id, callback, callback_id):
|
||||
if subscriber_id not in self._subscribers:
|
||||
self._subscribers[subscriber_id] = {}
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
|
||||
def unsubscribe(self, subscriber_id, callback_id):
|
||||
if (
|
||||
subscriber_id in self._subscribers
|
||||
and callback_id in self._subscribers[subscriber_id]
|
||||
):
|
||||
del self._subscribers[subscriber_id][callback_id]
|
||||
|
||||
def add_event(self, event, source):
|
||||
event._id = self.cur_id
|
||||
self.cur_id += 1
|
||||
event._source = source
|
||||
event._timestamp = datetime.now().isoformat()
|
||||
|
||||
for subscriber_id in self._subscribers:
|
||||
for callback_id, callback in self._subscribers[subscriber_id].items():
|
||||
callback(event)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
with patch('openhands.core.cli.create_agent') as mock_create_agent:
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_instance.name = 'test-agent'
|
||||
mock_agent_instance.llm = AsyncMock()
|
||||
mock_agent_instance.llm.config = AsyncMock()
|
||||
mock_agent_instance.llm.config.model = 'test-model'
|
||||
mock_agent_instance.llm.config.base_url = 'http://test'
|
||||
mock_agent_instance.llm.config.max_message_chars = 1000
|
||||
mock_agent_instance.config = AsyncMock()
|
||||
mock_agent_instance.config.disabled_microagents = []
|
||||
mock_agent_instance.sandbox_plugins = []
|
||||
mock_agent_instance.prompt_manager = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent_instance
|
||||
yield mock_agent_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_controller():
|
||||
with patch('openhands.core.cli.create_controller') as mock_create_controller:
|
||||
mock_controller_instance = AsyncMock()
|
||||
mock_controller_instance.state.agent_state = None
|
||||
# Mock run_until_done to finish immediately
|
||||
mock_controller_instance.run_until_done = AsyncMock(return_value=None)
|
||||
mock_create_controller.return_value = (mock_controller_instance, None)
|
||||
yield mock_controller_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
with patch('openhands.core.cli.parse_arguments') as mock_parse_args:
|
||||
args = Mock()
|
||||
args.file = None
|
||||
args.task = None
|
||||
args.directory = None
|
||||
mock_parse_args.return_value = args
|
||||
with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config:
|
||||
mock_config = AppConfig()
|
||||
mock_config.cli_multiline_input = False
|
||||
mock_config.security = Mock()
|
||||
mock_config.security.confirmation_mode = False
|
||||
mock_config.sandbox = Mock()
|
||||
mock_config.sandbox.selected_repo = None
|
||||
mock_setup_config.return_value = mock_config
|
||||
yield mock_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory():
|
||||
with patch('openhands.core.cli.create_memory') as mock_create_memory:
|
||||
mock_memory_instance = AsyncMock()
|
||||
mock_create_memory.return_value = mock_memory_instance
|
||||
yield mock_memory_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_read_task():
|
||||
with patch('openhands.core.cli.read_task') as mock_read_task:
|
||||
mock_read_task.return_value = None
|
||||
yield mock_read_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime():
|
||||
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
|
||||
mock_runtime_instance = AsyncMock()
|
||||
|
||||
mock_event_stream = MockEventStream()
|
||||
mock_runtime_instance.event_stream = mock_event_stream
|
||||
|
||||
mock_runtime_instance.connect = AsyncMock()
|
||||
|
||||
# Ensure status_callback is None
|
||||
mock_runtime_instance.status_callback = None
|
||||
# Mock get_microagents_from_selected_repo
|
||||
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
|
||||
mock_create_runtime.return_value = mock_runtime_instance
|
||||
yield mock_runtime_instance
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_greeting(
|
||||
mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task
|
||||
):
|
||||
buffer = StringIO()
|
||||
|
||||
with create_app_session(
|
||||
input=create_pipe_input(), output=create_output(stdout=buffer)
|
||||
):
|
||||
mock_controller.status_callback = None
|
||||
|
||||
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
hello_response = MessageAction(content='Ping')
|
||||
hello_response._source = EventSource.AGENT
|
||||
mock_runtime.event_stream.add_event(hello_response, EventSource.AGENT)
|
||||
|
||||
state_change = AgentStateChangedObservation(
|
||||
content='Awaiting user input', agent_state=AgentState.AWAITING_USER_INPUT
|
||||
)
|
||||
state_change._source = EventSource.AGENT
|
||||
mock_runtime.event_stream.add_event(state_change, EventSource.AGENT)
|
||||
|
||||
stop_event = AgentStateChangedObservation(
|
||||
content='Stop', agent_state=AgentState.STOPPED
|
||||
)
|
||||
stop_event._source = EventSource.AGENT
|
||||
mock_runtime.event_stream.add_event(stop_event, EventSource.AGENT)
|
||||
|
||||
mock_controller.state.agent_state = AgentState.STOPPED
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(main_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
main_task.cancel()
|
||||
|
||||
buffer.seek(0)
|
||||
output = buffer.read()
|
||||
|
||||
assert 'Ping' in output
|
||||
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
from argparse import Namespace
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.application import create_app_session
|
||||
from prompt_toolkit.input import create_pipe_input
|
||||
from prompt_toolkit.output import create_output
|
||||
|
||||
from openhands.core.cli import main
|
||||
from openhands.core.config import AppConfig
|
||||
@@ -83,34 +87,41 @@ def mock_config(task_file: Path):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_session_id_output(
|
||||
mock_runtime, mock_agent, mock_controller, mock_config, capsys
|
||||
mock_runtime, mock_agent, mock_controller, mock_config
|
||||
):
|
||||
# status_callback is set when initializing the runtime
|
||||
mock_controller.status_callback = None
|
||||
|
||||
buffer = StringIO()
|
||||
|
||||
# Use input patch just for the exit command
|
||||
with patch('builtins.input', return_value='exit'):
|
||||
# Create a task for main
|
||||
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
|
||||
with create_app_session(
|
||||
input=create_pipe_input(), output=create_output(stdout=buffer)
|
||||
):
|
||||
# Create a task for main
|
||||
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
|
||||
|
||||
# Give it a moment to display the session ID
|
||||
await asyncio.sleep(0.1)
|
||||
# Give it a moment to display the session ID
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Trigger agent state change to STOPPED to end the main loop
|
||||
event = AgentStateChangedObservation(
|
||||
content='Stop', agent_state=AgentState.STOPPED
|
||||
)
|
||||
event._source = EventSource.AGENT
|
||||
await mock_runtime.event_stream.add_event(event)
|
||||
# Trigger agent state change to STOPPED to end the main loop
|
||||
event = AgentStateChangedObservation(
|
||||
content='Stop', agent_state=AgentState.STOPPED
|
||||
)
|
||||
event._source = EventSource.AGENT
|
||||
await mock_runtime.event_stream.add_event(event)
|
||||
|
||||
# Wait for main to finish with a timeout
|
||||
try:
|
||||
await asyncio.wait_for(main_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
main_task.cancel()
|
||||
# Wait for main to finish with a timeout
|
||||
try:
|
||||
await asyncio.wait_for(main_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
main_task.cancel()
|
||||
|
||||
# Check the output
|
||||
captured = capsys.readouterr()
|
||||
assert 'Session ID:' in captured.out
|
||||
# Also verify that our task message was processed
|
||||
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)
|
||||
buffer.seek(0)
|
||||
output = buffer.read()
|
||||
|
||||
# Check the output
|
||||
assert 'Session ID:' in output
|
||||
# Also verify that our task message was processed
|
||||
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)
|
||||
|
||||
Reference in New Issue
Block a user