diff --git a/openhands/core/cli.py b/openhands/core/cli.py index aed45ec287..4d381984b3 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -3,7 +3,9 @@ import logging import sys from uuid import uuid4 -from termcolor import colored +from prompt_toolkit import PromptSession, print_formatted_text +from prompt_toolkit.formatted_text import FormattedText +from prompt_toolkit.key_binding import KeyBindings import openhands.agenthub # noqa F401 (we import this to get the agents registered) from openhands.core.config import ( @@ -36,24 +38,66 @@ from openhands.events.observation import ( CmdOutputObservation, FileEditObservation, ) -from openhands.io import read_input, read_task +from openhands.io import read_task + +prompt_session = PromptSession() def display_message(message: str): - print(colored('🤖 ' + message + '\n', 'yellow')) + print_formatted_text( + FormattedText( + [ + ('ansiyellow', '🤖 '), + ('ansiyellow', message), + ('', '\n'), + ] + ) + ) def display_command(command: str): - print('❯ ' + colored(command + '\n', 'green')) + print_formatted_text( + FormattedText( + [ + ('', '❯ '), + ('ansigreen', command), + ('', '\n'), + ] + ) + ) def display_confirmation(confirmation_state: ActionConfirmationStatus): if confirmation_state == ActionConfirmationStatus.CONFIRMED: - print(colored('✅ ' + confirmation_state + '\n', 'green')) + print_formatted_text( + FormattedText( + [ + ('ansigreen', '✅ '), + ('ansigreen', str(confirmation_state)), + ('', '\n'), + ] + ) + ) elif confirmation_state == ActionConfirmationStatus.REJECTED: - print(colored('❌ ' + confirmation_state + '\n', 'red')) + print_formatted_text( + FormattedText( + [ + ('ansired', '❌ '), + ('ansired', str(confirmation_state)), + ('', '\n'), + ] + ) + ) else: - print(colored('⏳ ' + confirmation_state + '\n', 'yellow')) + print_formatted_text( + FormattedText( + [ + ('ansiyellow', '⏳ '), + ('ansiyellow', str(confirmation_state)), + ('', '\n'), + ] + ) + ) def display_command_output(output: str): @@ -62,12 +106,19 @@ def display_command_output(output: str): if line.startswith('[Python Interpreter') or line.startswith('openhands@'): # TODO: clean this up once we clean up terminal output continue - print(colored(line, 'blue')) - print('\n') + print_formatted_text(FormattedText([('ansiblue', line)])) + print_formatted_text('') def display_file_edit(event: FileEditAction | FileEditObservation): - print(colored(str(event), 'green')) + print_formatted_text( + FormattedText( + [ + ('ansigreen', str(event)), + ('', '\n'), + ] + ) + ) def display_event(event: Event, config: AppConfig): @@ -89,6 +140,41 @@ def display_event(event: Event, config: AppConfig): display_confirmation(event.confirmation_state) +async def read_prompt_input(multiline=False): + try: + if multiline: + kb = KeyBindings() + + @kb.add('c-d') + def _(event): + event.current_buffer.validate_and_handle() + + message = await prompt_session.prompt_async( + 'Enter your message and press Ctrl+D to finish:\n', + multiline=True, + key_bindings=kb, + ) + else: + message = await prompt_session.prompt_async( + '>> ', + ) + return message + except KeyboardInterrupt: + return 'exit' + except EOFError: + return 'exit' + + +async def read_confirmation_input(): + try: + confirmation = await prompt_session.prompt_async( + 'Confirm action (possible security risk)? (y/n) >> ', + ) + return confirmation.lower() == 'y' + except (KeyboardInterrupt, EOFError): + return False + + async def main(loop: asyncio.AbstractEventLoop): """Runs the agent in CLI mode.""" @@ -122,10 +208,7 @@ async def main(loop: asyncio.AbstractEventLoop): event_stream = runtime.event_stream async def prompt_for_next_task(): - # Run input() in a thread pool to avoid blocking the event loop - next_message = await loop.run_in_executor( - None, read_input, config.cli_multiline_input - ) + next_message = await read_prompt_input(config.cli_multiline_input) if not next_message.strip(): await prompt_for_next_task() if next_message == 'exit': @@ -136,12 +219,6 @@ async def main(loop: asyncio.AbstractEventLoop): action = MessageAction(content=next_message) event_stream.add_event(action, EventSource.USER) - async def prompt_for_user_confirmation(): - user_confirmation = await loop.run_in_executor( - None, lambda: input('Confirm action (possible security risk)? (y/n) >> ') - ) - return user_confirmation.lower() == 'y' - async def on_event_async(event: Event): display_event(event, config) if isinstance(event, AgentStateChangedObservation): @@ -151,7 +228,7 @@ async def main(loop: asyncio.AbstractEventLoop): ]: await prompt_for_next_task() if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION: - user_confirmed = await prompt_for_user_confirmation() + user_confirmed = await read_confirmation_input() if user_confirmed: event_stream.add_event( ChangeAgentStateAction(AgentState.USER_CONFIRMED), diff --git a/poetry.lock b/poetry.lock index fdc3c6f54d..0cac82f1ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -2661,7 +2661,7 @@ grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_versi grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = [ {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, - {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.22.3,<2.0.0dev"}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -9212,7 +9212,7 @@ description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" groups = ["evaluation"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version == \"3.12\"" files = [ {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, @@ -10193,4 +10193,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.1" python-versions = "^3.12" -content-hash = "2072ef82a79cc7be1652af1b33fb9814d85e1c6ad4ed410e12f2e57e0fadbd58" +content-hash = "bd1e164559f6395718bc76482493aa71327e7aee9ebe6131facba65661d9abec" diff --git a/pyproject.toml b/pyproject.toml index 0338bedb73..e7427c8ec9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ memory-profiler = "^0.61.0" daytona-sdk = "0.12.1" python-json-logger = "^3.2.1" playwright = "^1.51.0" +prompt-toolkit = "^3.0.50" [tool.poetry.group.dev.dependencies] ruff = "0.11.4" diff --git a/tests/unit/test_cli_basic.py b/tests/unit/test_cli_basic.py new file mode 100644 index 0000000000..0251bbf647 --- /dev/null +++ b/tests/unit/test_cli_basic.py @@ -0,0 +1,169 @@ +import asyncio +from datetime import datetime +from io import StringIO +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from prompt_toolkit.application import create_app_session +from prompt_toolkit.input import create_pipe_input +from prompt_toolkit.output import create_output + +from openhands.core.cli import main +from openhands.core.config import AppConfig +from openhands.core.schema import AgentState +from openhands.events.action import MessageAction +from openhands.events.event import EventSource +from openhands.events.observation import AgentStateChangedObservation + + +class MockEventStream: + def __init__(self): + self._subscribers = {} + self.cur_id = 0 + + def subscribe(self, subscriber_id, callback, callback_id): + if subscriber_id not in self._subscribers: + self._subscribers[subscriber_id] = {} + self._subscribers[subscriber_id][callback_id] = callback + + def unsubscribe(self, subscriber_id, callback_id): + if ( + subscriber_id in self._subscribers + and callback_id in self._subscribers[subscriber_id] + ): + del self._subscribers[subscriber_id][callback_id] + + def add_event(self, event, source): + event._id = self.cur_id + self.cur_id += 1 + event._source = source + event._timestamp = datetime.now().isoformat() + + for subscriber_id in self._subscribers: + for callback_id, callback in self._subscribers[subscriber_id].items(): + callback(event) + + +@pytest.fixture +def mock_agent(): + with patch('openhands.core.cli.create_agent') as mock_create_agent: + mock_agent_instance = AsyncMock() + mock_agent_instance.name = 'test-agent' + mock_agent_instance.llm = AsyncMock() + mock_agent_instance.llm.config = AsyncMock() + mock_agent_instance.llm.config.model = 'test-model' + mock_agent_instance.llm.config.base_url = 'http://test' + mock_agent_instance.llm.config.max_message_chars = 1000 + mock_agent_instance.config = AsyncMock() + mock_agent_instance.config.disabled_microagents = [] + mock_agent_instance.sandbox_plugins = [] + mock_agent_instance.prompt_manager = AsyncMock() + mock_create_agent.return_value = mock_agent_instance + yield mock_agent_instance + + +@pytest.fixture +def mock_controller(): + with patch('openhands.core.cli.create_controller') as mock_create_controller: + mock_controller_instance = AsyncMock() + mock_controller_instance.state.agent_state = None + # Mock run_until_done to finish immediately + mock_controller_instance.run_until_done = AsyncMock(return_value=None) + mock_create_controller.return_value = (mock_controller_instance, None) + yield mock_controller_instance + + +@pytest.fixture +def mock_config(): + with patch('openhands.core.cli.parse_arguments') as mock_parse_args: + args = Mock() + args.file = None + args.task = None + args.directory = None + mock_parse_args.return_value = args + with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config: + mock_config = AppConfig() + mock_config.cli_multiline_input = False + mock_config.security = Mock() + mock_config.security.confirmation_mode = False + mock_config.sandbox = Mock() + mock_config.sandbox.selected_repo = None + mock_setup_config.return_value = mock_config + yield mock_config + + +@pytest.fixture +def mock_memory(): + with patch('openhands.core.cli.create_memory') as mock_create_memory: + mock_memory_instance = AsyncMock() + mock_create_memory.return_value = mock_memory_instance + yield mock_memory_instance + + +@pytest.fixture +def mock_read_task(): + with patch('openhands.core.cli.read_task') as mock_read_task: + mock_read_task.return_value = None + yield mock_read_task + + +@pytest.fixture +def mock_runtime(): + with patch('openhands.core.cli.create_runtime') as mock_create_runtime: + mock_runtime_instance = AsyncMock() + + mock_event_stream = MockEventStream() + mock_runtime_instance.event_stream = mock_event_stream + + mock_runtime_instance.connect = AsyncMock() + + # Ensure status_callback is None + mock_runtime_instance.status_callback = None + # Mock get_microagents_from_selected_repo + mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[]) + mock_create_runtime.return_value = mock_runtime_instance + yield mock_runtime_instance + + +@pytest.mark.asyncio +async def test_cli_greeting( + mock_runtime, mock_controller, mock_config, mock_agent, mock_memory, mock_read_task +): + buffer = StringIO() + + with create_app_session( + input=create_pipe_input(), output=create_output(stdout=buffer) + ): + mock_controller.status_callback = None + + main_task = asyncio.create_task(main(asyncio.get_event_loop())) + + await asyncio.sleep(0.1) + + hello_response = MessageAction(content='Ping') + hello_response._source = EventSource.AGENT + mock_runtime.event_stream.add_event(hello_response, EventSource.AGENT) + + state_change = AgentStateChangedObservation( + content='Awaiting user input', agent_state=AgentState.AWAITING_USER_INPUT + ) + state_change._source = EventSource.AGENT + mock_runtime.event_stream.add_event(state_change, EventSource.AGENT) + + stop_event = AgentStateChangedObservation( + content='Stop', agent_state=AgentState.STOPPED + ) + stop_event._source = EventSource.AGENT + mock_runtime.event_stream.add_event(stop_event, EventSource.AGENT) + + mock_controller.state.agent_state = AgentState.STOPPED + + try: + await asyncio.wait_for(main_task, timeout=1.0) + except asyncio.TimeoutError: + main_task.cancel() + + buffer.seek(0) + output = buffer.read() + + assert 'Ping' in output diff --git a/tests/unit/test_cli_sid.py b/tests/unit/test_cli_sid.py index 67db79b63e..c24f9e6233 100644 --- a/tests/unit/test_cli_sid.py +++ b/tests/unit/test_cli_sid.py @@ -1,9 +1,13 @@ import asyncio from argparse import Namespace +from io import StringIO from pathlib import Path from unittest.mock import AsyncMock, Mock, patch import pytest +from prompt_toolkit.application import create_app_session +from prompt_toolkit.input import create_pipe_input +from prompt_toolkit.output import create_output from openhands.core.cli import main from openhands.core.config import AppConfig @@ -83,34 +87,41 @@ def mock_config(task_file: Path): @pytest.mark.asyncio async def test_cli_session_id_output( - mock_runtime, mock_agent, mock_controller, mock_config, capsys + mock_runtime, mock_agent, mock_controller, mock_config ): # status_callback is set when initializing the runtime mock_controller.status_callback = None + buffer = StringIO() + # Use input patch just for the exit command with patch('builtins.input', return_value='exit'): - # Create a task for main - main_task = asyncio.create_task(main(asyncio.get_event_loop())) + with create_app_session( + input=create_pipe_input(), output=create_output(stdout=buffer) + ): + # Create a task for main + main_task = asyncio.create_task(main(asyncio.get_event_loop())) - # Give it a moment to display the session ID - await asyncio.sleep(0.1) + # Give it a moment to display the session ID + await asyncio.sleep(0.1) - # Trigger agent state change to STOPPED to end the main loop - event = AgentStateChangedObservation( - content='Stop', agent_state=AgentState.STOPPED - ) - event._source = EventSource.AGENT - await mock_runtime.event_stream.add_event(event) + # Trigger agent state change to STOPPED to end the main loop + event = AgentStateChangedObservation( + content='Stop', agent_state=AgentState.STOPPED + ) + event._source = EventSource.AGENT + await mock_runtime.event_stream.add_event(event) - # Wait for main to finish with a timeout - try: - await asyncio.wait_for(main_task, timeout=1.0) - except asyncio.TimeoutError: - main_task.cancel() + # Wait for main to finish with a timeout + try: + await asyncio.wait_for(main_task, timeout=1.0) + except asyncio.TimeoutError: + main_task.cancel() - # Check the output - captured = capsys.readouterr() - assert 'Session ID:' in captured.out - # Also verify that our task message was processed - assert 'Ask me what your task is' in str(mock_runtime.mock_calls) + buffer.seek(0) + output = buffer.read() + + # Check the output + assert 'Session ID:' in output + # Also verify that our task message was processed + assert 'Ask me what your task is' in str(mock_runtime.mock_calls)