feat: Add basic support for prompt-toolkit in the CLI (#7709)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Bashwara Undupitiya <bashwarau@verdentra.com>
This commit is contained in:
Panduka Muditha
2025-04-08 18:17:11 +05:30
committed by GitHub
parent dd03d9adce
commit 60e8b5841c
5 changed files with 304 additions and 46 deletions

View File

@@ -3,7 +3,9 @@ import logging
import sys
from uuid import uuid4
from termcolor import colored
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.key_binding import KeyBindings
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.core.config import (
@@ -36,24 +38,66 @@ from openhands.events.observation import (
CmdOutputObservation,
FileEditObservation,
)
from openhands.io import read_input, read_task
from openhands.io import read_task
prompt_session = PromptSession()
def display_message(message: str):
    """Print an agent message in yellow, prefixed with the robot glyph.

    Output goes through prompt_toolkit's ``print_formatted_text`` only; the
    earlier termcolor ``print`` call is dropped so the message is not
    emitted twice.

    Args:
        message: Text to show to the user.
    """
    print_formatted_text(
        FormattedText(
            [
                ('ansiyellow', '🤖 '),
                ('ansiyellow', message),
                # Explicit newline fragment: leaves a blank line after the message.
                ('', '\n'),
            ]
        )
    )
def display_command(command: str):
    """Print a command in green, preceded by a single unstyled space.

    Output goes through prompt_toolkit's ``print_formatted_text`` only; the
    earlier termcolor ``print`` call is dropped so the command is not
    emitted twice.

    Args:
        command: The command text to show.
    """
    print_formatted_text(
        FormattedText(
            [
                ('', ' '),
                ('ansigreen', command),
                # Explicit newline fragment: leaves a blank line after the command.
                ('', '\n'),
            ]
        )
    )
def display_confirmation(confirmation_state: ActionConfirmationStatus):
    """Print a confirmation status, coloured by its value.

    ``CONFIRMED`` is shown in green, ``REJECTED`` in red and any other state
    in yellow. Output goes through prompt_toolkit's ``print_formatted_text``
    only; the earlier termcolor ``print`` calls are dropped so the status is
    not emitted twice.

    Args:
        confirmation_state: The status to render; it is converted with ``str()``.
    """
    if confirmation_state == ActionConfirmationStatus.CONFIRMED:
        style = 'ansigreen'
    elif confirmation_state == ActionConfirmationStatus.REJECTED:
        style = 'ansired'
    else:
        style = 'ansiyellow'

    print_formatted_text(
        FormattedText(
            [
                # NOTE(review): this leading fragment is empty in the visible
                # source -- possibly a status glyph lost in extraction; confirm.
                (style, ''),
                (style, str(confirmation_state)),
                ('', '\n'),
            ]
        )
    )
def display_command_output(output: str):
@@ -62,12 +106,19 @@ def display_command_output(output: str):
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
# TODO: clean this up once we clean up terminal output
continue
print(colored(line, 'blue'))
print('\n')
print_formatted_text(FormattedText([('ansiblue', line)]))
print_formatted_text('')
def display_file_edit(event: FileEditAction | FileEditObservation):
    """Print the string form of a file-edit event in green.

    Output goes through prompt_toolkit's ``print_formatted_text`` only; the
    earlier termcolor ``print`` call is dropped so the event is not emitted
    twice.

    Args:
        event: The file-edit action or observation; rendered via ``str(event)``.
    """
    print_formatted_text(
        FormattedText(
            [
                ('ansigreen', str(event)),
                # Explicit newline fragment: leaves a blank line after the event.
                ('', '\n'),
            ]
        )
    )
def display_event(event: Event, config: AppConfig):
@@ -89,6 +140,41 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)
async def read_prompt_input(multiline=False):
    """Read one user message through the shared prompt session.

    In multiline mode a key binding on Ctrl+D submits the current buffer; in
    single-line mode a plain ``>> `` prompt is shown.

    Args:
        multiline: When true, prompt in multiline mode.

    Returns:
        The entered text, or the literal ``'exit'`` when the prompt is
        interrupted (``KeyboardInterrupt``) or input ends (``EOFError``).
    """
    try:
        if not multiline:
            return await prompt_session.prompt_async(
                '>> ',
            )

        bindings = KeyBindings()

        @bindings.add('c-d')
        def _submit(event):
            # Ctrl+D finishes the message instead of signalling end-of-input.
            event.current_buffer.validate_and_handle()

        return await prompt_session.prompt_async(
            'Enter your message and press Ctrl+D to finish:\n',
            multiline=True,
            key_bindings=bindings,
        )
    except (KeyboardInterrupt, EOFError):
        return 'exit'
async def read_confirmation_input():
    """Ask the user to confirm an action through the shared prompt session.

    Returns:
        ``True`` only when the lower-cased answer is exactly ``'y'``;
        ``False`` for any other answer, and also when the prompt is
        interrupted (``KeyboardInterrupt``) or input ends (``EOFError``).
    """
    try:
        answer = await prompt_session.prompt_async(
            'Confirm action (possible security risk)? (y/n) >> ',
        )
    except (KeyboardInterrupt, EOFError):
        return False
    else:
        return answer.lower() == 'y'
async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode."""
@@ -122,10 +208,7 @@ async def main(loop: asyncio.AbstractEventLoop):
event_stream = runtime.event_stream
async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
next_message = await read_prompt_input(config.cli_multiline_input)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
@@ -136,12 +219,6 @@ async def main(loop: asyncio.AbstractEventLoop):
action = MessageAction(content=next_message)
event_stream.add_event(action, EventSource.USER)
async def prompt_for_user_confirmation():
user_confirmation = await loop.run_in_executor(
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
)
return user_confirmation.lower() == 'y'
async def on_event_async(event: Event):
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
@@ -151,7 +228,7 @@ async def main(loop: asyncio.AbstractEventLoop):
]:
await prompt_for_next_task()
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
user_confirmed = await prompt_for_user_confirmation()
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),

8
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -2661,7 +2661,7 @@ grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_versi
grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
proto-plus = [
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.22.3,<2.0.0dev"},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
@@ -9212,7 +9212,7 @@ description = "A language and compiler for custom Deep Learning operations"
optional = false
python-versions = "*"
groups = ["evaluation"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version == \"3.12\""
files = [
{file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"},
{file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"},
@@ -10193,4 +10193,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "2072ef82a79cc7be1652af1b33fb9814d85e1c6ad4ed410e12f2e57e0fadbd58"
content-hash = "bd1e164559f6395718bc76482493aa71327e7aee9ebe6131facba65661d9abec"

View File

@@ -79,6 +79,7 @@ memory-profiler = "^0.61.0"
daytona-sdk = "0.12.1"
python-json-logger = "^3.2.1"
playwright = "^1.51.0"
prompt-toolkit = "^3.0.50"
[tool.poetry.group.dev.dependencies]
ruff = "0.11.4"

View File

@@ -0,0 +1,169 @@
import asyncio
from datetime import datetime
from io import StringIO
from unittest.mock import AsyncMock, Mock, patch
import pytest
from prompt_toolkit.application import create_app_session
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.output import create_output
from openhands.core.cli import main
from openhands.core.config import AppConfig
from openhands.core.schema import AgentState
from openhands.events.action import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation import AgentStateChangedObservation
class MockEventStream:
    """In-memory stand-in for an event stream, used by the CLI tests.

    Callbacks are stored per subscriber and invoked synchronously from
    ``add_event``; their return values are discarded.
    """

    def __init__(self):
        # subscriber_id -> {callback_id: callback}
        self._subscribers = {}
        # Id that will be stamped onto the next event.
        self.cur_id = 0

    def subscribe(self, subscriber_id, callback, callback_id):
        """Register ``callback`` under ``(subscriber_id, callback_id)``.

        An existing callback with the same pair of ids is replaced.
        """
        self._subscribers.setdefault(subscriber_id, {})[callback_id] = callback

    def unsubscribe(self, subscriber_id, callback_id):
        """Remove a callback; unknown ids are ignored."""
        self._subscribers.get(subscriber_id, {}).pop(callback_id, None)

    def add_event(self, event, source):
        """Stamp ``event`` with id, source and timestamp, then dispatch it.

        Args:
            event: Object that receives ``_id``, ``_source`` and ``_timestamp``.
            source: Value stored on ``event._source``.
        """
        event._id = self.cur_id
        self.cur_id += 1
        event._source = source
        event._timestamp = datetime.now().isoformat()

        # Dispatch over a snapshot so a callback may subscribe or unsubscribe
        # while the event is being delivered without raising
        # "dictionary changed size during iteration". Callbacks removed
        # mid-dispatch still receive the current event.
        pending = [
            callback
            for registered in self._subscribers.values()
            for callback in registered.values()
        ]
        for callback in pending:
            callback(event)
@pytest.fixture
def mock_agent():
    """Yield a fake agent while ``openhands.core.cli.create_agent`` is patched to return it."""
    with patch('openhands.core.cli.create_agent') as create_agent_patch:
        agent = AsyncMock()
        agent.name = 'test-agent'

        # LLM handle and the config values set for it.
        agent.llm = AsyncMock()
        agent.llm.config = AsyncMock()
        agent.llm.config.model = 'test-model'
        agent.llm.config.base_url = 'http://test'
        agent.llm.config.max_message_chars = 1000

        # Agent-level settings.
        agent.config = AsyncMock()
        agent.config.disabled_microagents = []
        agent.sandbox_plugins = []
        agent.prompt_manager = AsyncMock()

        create_agent_patch.return_value = agent
        yield agent
@pytest.fixture
def mock_controller():
    """Yield a fake controller while ``openhands.core.cli.create_controller`` is patched.

    The patched factory returns the tuple ``(controller, None)``.
    """
    with patch('openhands.core.cli.create_controller') as create_controller_patch:
        controller = AsyncMock()
        controller.state.agent_state = None
        # run_until_done resolves immediately with None.
        controller.run_until_done = AsyncMock(return_value=None)
        create_controller_patch.return_value = (controller, None)
        yield controller
@pytest.fixture
def mock_config():
    """Yield an ``AppConfig`` while argument parsing and config setup are patched.

    ``parse_arguments`` returns a mock whose ``file``, ``task`` and
    ``directory`` are ``None``; ``setup_config_from_args`` returns the
    yielded config.
    """
    with (
        patch('openhands.core.cli.parse_arguments') as parse_arguments_patch,
        patch('openhands.core.cli.setup_config_from_args') as setup_config_patch,
    ):
        parse_arguments_patch.return_value = Mock(
            file=None, task=None, directory=None
        )

        config = AppConfig()
        config.cli_multiline_input = False
        config.security = Mock()
        config.security.confirmation_mode = False
        config.sandbox = Mock()
        config.sandbox.selected_repo = None

        setup_config_patch.return_value = config
        yield config
@pytest.fixture
def mock_memory():
    """Yield a fake memory object while ``openhands.core.cli.create_memory`` is patched to return it."""
    memory = AsyncMock()
    with patch('openhands.core.cli.create_memory', return_value=memory):
        yield memory
@pytest.fixture
def mock_read_task():
    """Patch ``openhands.core.cli.read_task`` to return ``None`` and yield the patch mock."""
    with patch('openhands.core.cli.read_task', return_value=None) as read_task_patch:
        yield read_task_patch
@pytest.fixture
def mock_runtime():
    """Yield a fake runtime while ``openhands.core.cli.create_runtime`` is patched to return it.

    The runtime's ``event_stream`` is a ``MockEventStream``.
    """
    with patch('openhands.core.cli.create_runtime') as create_runtime_patch:
        runtime = AsyncMock()
        runtime.event_stream = MockEventStream()
        runtime.connect = AsyncMock()
        # Explicitly no status callback on the runtime.
        runtime.status_callback = None
        # Synchronous mock returning an empty list of microagents.
        runtime.get_microagents_from_selected_repo = Mock(return_value=[])
        create_runtime_patch.return_value = runtime
        yield runtime
@pytest.mark.asyncio
async def test_cli_greeting(
    mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task
):
    """Check that an agent message pushed onto the event stream reaches the CLI output."""
    captured = StringIO()

    def emit_from_agent(event):
        # Tag the event as agent-originated and push it onto the mock stream.
        event._source = EventSource.AGENT
        mock_runtime.event_stream.add_event(event, EventSource.AGENT)

    # NOTE(review): prompt_toolkit's documentation shows create_pipe_input()
    # used as a context manager (``with create_pipe_input() as inp:``); here
    # its return value is passed straight to ``input=`` -- confirm intended.
    with create_app_session(
        input=create_pipe_input(), output=create_output(stdout=captured)
    ):
        mock_controller.status_callback = None

        cli_task = asyncio.create_task(main(asyncio.get_event_loop()))

        # Yield to the loop briefly so main() can start before events arrive.
        await asyncio.sleep(0.1)

        emit_from_agent(MessageAction(content='Ping'))
        emit_from_agent(
            AgentStateChangedObservation(
                content='Awaiting user input',
                agent_state=AgentState.AWAITING_USER_INPUT,
            )
        )
        emit_from_agent(
            AgentStateChangedObservation(
                content='Stop', agent_state=AgentState.STOPPED
            )
        )

        mock_controller.state.agent_state = AgentState.STOPPED

        try:
            await asyncio.wait_for(cli_task, timeout=1.0)
        except asyncio.TimeoutError:
            cli_task.cancel()

        rendered = captured.getvalue()

        assert 'Ping' in rendered

View File

@@ -1,9 +1,13 @@
import asyncio
from argparse import Namespace
from io import StringIO
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from prompt_toolkit.application import create_app_session
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.output import create_output
from openhands.core.cli import main
from openhands.core.config import AppConfig
@@ -83,34 +87,41 @@ def mock_config(task_file: Path):
@pytest.mark.asyncio
async def test_cli_session_id_output(
mock_runtime, mock_agent, mock_controller, mock_config, capsys
mock_runtime, mock_agent, mock_controller, mock_config
):
# status_callback is set when initializing the runtime
mock_controller.status_callback = None
buffer = StringIO()
# Use input patch just for the exit command
with patch('builtins.input', return_value='exit'):
# Create a task for main
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
with create_app_session(
input=create_pipe_input(), output=create_output(stdout=buffer)
):
# Create a task for main
main_task = asyncio.create_task(main(asyncio.get_event_loop()))
# Give it a moment to display the session ID
await asyncio.sleep(0.1)
# Give it a moment to display the session ID
await asyncio.sleep(0.1)
# Trigger agent state change to STOPPED to end the main loop
event = AgentStateChangedObservation(
content='Stop', agent_state=AgentState.STOPPED
)
event._source = EventSource.AGENT
await mock_runtime.event_stream.add_event(event)
# Trigger agent state change to STOPPED to end the main loop
event = AgentStateChangedObservation(
content='Stop', agent_state=AgentState.STOPPED
)
event._source = EventSource.AGENT
await mock_runtime.event_stream.add_event(event)
# Wait for main to finish with a timeout
try:
await asyncio.wait_for(main_task, timeout=1.0)
except asyncio.TimeoutError:
main_task.cancel()
# Wait for main to finish with a timeout
try:
await asyncio.wait_for(main_task, timeout=1.0)
except asyncio.TimeoutError:
main_task.cancel()
# Check the output
captured = capsys.readouterr()
assert 'Session ID:' in captured.out
# Also verify that our task message was processed
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)
buffer.seek(0)
output = buffer.read()
# Check the output
assert 'Session ID:' in output
# Also verify that our task message was processed
assert 'Ask me what your task is' in str(mock_runtime.mock_calls)