feat: Add CLI support for agent pause and resume (#8129)

Co-authored-by: Bashwara Undupitiya <bashwarau@verdentra.com>
This commit is contained in:
Panduka Muditha 2025-04-29 01:56:18 +05:30 committed by GitHub
parent 06ce12eff4
commit 998de564cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 490 additions and 97 deletions

View File

@ -14,12 +14,14 @@ from openhands.core.cli_commands import (
)
from openhands.core.cli_tui import (
UsageMetrics,
display_agent_running_message,
display_banner,
display_event,
display_initial_user_prompt,
display_initialization_animation,
display_runtime_initialization_message,
display_welcome_message,
process_agent_pause,
read_confirmation_input,
read_prompt_input,
)
@ -99,6 +101,7 @@ async def run_session(
sid = str(uuid4())
is_loaded = asyncio.Event()
is_paused = asyncio.Event()
# Show runtime initialization message
display_runtime_initialization_message(config.runtime)
@ -124,10 +127,12 @@ async def run_session(
usage_metrics = UsageMetrics()
async def prompt_for_next_task():
async def prompt_for_next_task(agent_state: str):
nonlocal reload_microagents, new_session_requested
while True:
next_message = await read_prompt_input(config.cli_multiline_input)
next_message = await read_prompt_input(
agent_state, multiline=config.cli_multiline_input
)
if not next_message.strip():
continue
@ -150,14 +155,23 @@ async def run_session(
return
async def on_event_async(event: Event) -> None:
nonlocal reload_microagents
nonlocal reload_microagents, is_paused
display_event(event, config)
update_usage_metrics(event, usage_metrics)
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# Reload microagents after initialization of repo.md
if reload_microagents:
@ -166,20 +180,28 @@ async def run_session(
)
memory.load_user_workspace_microagents(microagents)
reload_microagents = False
await prompt_for_next_task()
await prompt_for_next_task(event.agent_state)
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
# Only display the confirmation prompt if the agent is not paused
if not is_paused.is_set():
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
if event.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
display_agent_running_message()
loop.create_task(process_agent_pause(is_paused))
def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))
@ -212,7 +234,7 @@ async def run_session(
clear()
# Show OpenHands banner and session ID
display_banner(session_id=sid, is_loaded=is_loaded)
display_banner(session_id=sid)
# Show OpenHands welcome
display_welcome_message()
@ -225,7 +247,7 @@ async def run_session(
)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task())
asyncio.create_task(prompt_for_next_task(''))
await run_agent_until_done(
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]

View File

@ -70,6 +70,8 @@ async def handle_commands(
)
elif command == '/settings':
await handle_settings_command(config, settings_store)
elif command == '/resume':
close_repl, new_session_requested = await handle_resume_command(event_stream)
else:
close_repl = True
action = MessageAction(content=command)
@ -183,6 +185,28 @@ async def handle_settings_command(
await modify_llm_settings_advanced(config, settings_store)
# FIXME: Currently there's an issue with the actual 'resume' behavior.
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
async def handle_resume_command(
    event_stream: EventStream,
) -> tuple[bool, bool]:
    """Resume a paused agent by posting a user 'continue' message.

    This is a workaround: switching the agent state straight to RUNNING
    currently freezes the agent instead of letting it carry on with the task,
    so a plain user message is sent to nudge it forward.

    Args:
        event_stream: Stream that receives the 'continue' message.

    Returns:
        A ``(close_repl, new_session_requested)`` pair, always
        ``(True, False)``.
    """
    resume_message = MessageAction(content='continue')
    event_stream.add_event(resume_message, EventSource.USER)

    # TODO: once the freeze is fixed, replace the message above with
    # ChangeAgentStateAction(AgentState.RUNNING) sent as EventSource.ENVIRONMENT.
    return True, False
async def init_repository(current_dir: str) -> bool:
repo_file_path = Path(current_dir) / '.openhands' / 'microagents' / 'repo.md'
init_repo = False

View File

@ -10,7 +10,9 @@ from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import Application
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
from prompt_toolkit.input import create_input
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.keys import Keys
from prompt_toolkit.layout.containers import HSplit, Window
from prompt_toolkit.layout.controls import FormattedTextControl
from prompt_toolkit.layout.layout import Layout
@ -22,6 +24,7 @@ from prompt_toolkit.widgets import Frame, TextArea
from openhands import __version__
from openhands.core.config import AppConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import (
Action,
@ -32,6 +35,7 @@ from openhands.events.action import (
)
from openhands.events.event import Event
from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
FileEditObservation,
FileReadObservation,
@ -56,6 +60,7 @@ COMMANDS = {
'/status': 'Display session details and usage metrics',
'/new': 'Create a new session',
'/settings': 'Display and modify current settings',
'/resume': 'Resume the agent',
}
@ -114,7 +119,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
sys.stdout.flush()
def display_banner(session_id: str, is_loaded: asyncio.Event):
def display_banner(session_id: str):
print_formatted_text(
HTML(r"""<gold>
___ _ _ _
@ -129,11 +134,8 @@ def display_banner(session_id: str, is_loaded: asyncio.Event):
print_formatted_text(HTML(f'<grey>OpenHands CLI v{__version__}</grey>'))
banner_text = (
'Initialized session' if is_loaded.is_set() else 'Initializing session'
)
print_formatted_text('')
print_formatted_text(HTML(f'<grey>{banner_text} {session_id}</grey>'))
print_formatted_text(HTML(f'<grey>Initialized session {session_id}</grey>'))
print_formatted_text('')
@ -177,6 +179,8 @@ def display_event(event: Event, config: AppConfig) -> None:
display_file_edit(event)
if isinstance(event, FileReadObservation):
display_file_read(event)
if isinstance(event, AgentStateChangedObservation):
display_agent_paused_message(event.agent_state)
def display_message(message: str):
@ -389,77 +393,58 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
display_usage_metrics(usage_metrics)
def display_agent_running_message():
    """Print a blank line, then the 'Agent running...' status with the Ctrl-P hint."""
    status_markup = '<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>'
    print_formatted_text('')
    print_formatted_text(HTML(status_markup))
def display_agent_paused_message(agent_state: str):
    """Print the 'Agent paused' notice, but only for the PAUSED state.

    Any other ``agent_state`` produces no output.
    """
    if agent_state == AgentState.PAUSED:
        paused_markup = (
            '<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>'
        )
        print_formatted_text('')
        print_formatted_text(HTML(paused_markup))
# Common input functions
class CommandCompleter(Completer):
"""Custom completer for commands."""
def __init__(self, agent_state: str):
super().__init__()
self.agent_state = agent_state
def get_completions(self, document, complete_event):
text = document.text
# Only show completions if the user has typed '/'
text = document.text_before_cursor.lstrip()
if text.startswith('/'):
# If just '/' is typed, show all commands
if text == '/':
for command, description in COMMANDS.items():
available_commands = dict(COMMANDS)
if self.agent_state != AgentState.PAUSED:
available_commands.pop('/resume', None)
for command, description in available_commands.items():
if command.startswith(text):
yield Completion(
command[1:], # Remove the leading '/' as it's already typed
start_position=0,
display=f'{command} - {description}',
command,
start_position=-len(text),
display_meta=description,
style='bg:ansidarkgray fg:ansiwhite',
)
# Otherwise show matching commands
else:
for command, description in COMMANDS.items():
if command.startswith(text):
yield Completion(
command[len(text) :], # Complete the remaining part
start_position=0,
display=f'{command} - {description}',
)
prompt_session = PromptSession(style=DEFAULT_STYLE)
# RPrompt animation related variables
SPINNER_FRAMES = [
'[ ■□□□ ]',
'[ □■□□ ]',
'[ □□■□ ]',
'[ □□□■ ]',
'[ □□■□ ]',
'[ □■□□ ]',
]
ANIMATION_INTERVAL = 0.2 # seconds
current_frame_index = 0
last_update_time = time.monotonic()
def create_prompt_session():
    """Build a new PromptSession styled with DEFAULT_STYLE."""
    session = PromptSession(style=DEFAULT_STYLE)
    return session
# RPrompt function for the user confirmation
def get_rprompt() -> FormattedText:
"""
Returns the current animation frame for the rprompt.
This function is called by prompt_toolkit during rendering.
"""
global current_frame_index, last_update_time
# Only update the frame if enough time has passed
# This prevents excessive recalculation during rendering
now = time.monotonic()
if now - last_update_time > ANIMATION_INTERVAL:
current_frame_index = (current_frame_index + 1) % len(SPINNER_FRAMES)
last_update_time = now
# Return the frame wrapped in FormattedText
return FormattedText(
[
('', ' '), # Add a space before the spinner
(COLOR_GOLD, SPINNER_FRAMES[current_frame_index]),
]
)
async def read_prompt_input(multiline=False):
async def read_prompt_input(agent_state: str, multiline=False):
try:
prompt_session = create_prompt_session()
prompt_session.completer = (
CommandCompleter(agent_state) if not multiline else None
)
if multiline:
kb = KeyBindings()
@ -470,38 +455,54 @@ async def read_prompt_input(multiline=False):
with patch_stdout():
print_formatted_text('')
message = await prompt_session.prompt_async(
'Enter your message and press Ctrl+D to finish:\n',
HTML(
'<gold>Enter your message and press Ctrl-D to finish:</gold>\n'
),
multiline=True,
key_bindings=kb,
)
else:
with patch_stdout():
print_formatted_text('')
prompt_session.completer = CommandCompleter()
message = await prompt_session.prompt_async(
'> ',
HTML('<gold>> </gold>'),
)
return message
return message if message is not None else ''
except (KeyboardInterrupt, EOFError):
return '/exit'
async def read_confirmation_input():
async def read_confirmation_input() -> bool:
try:
prompt_session = create_prompt_session()
with patch_stdout():
prompt_session.completer = None
confirmation = await prompt_session.prompt_async(
'Proceed with action? (y)es/(n)o > ',
rprompt=get_rprompt,
refresh_interval=ANIMATION_INTERVAL / 2,
print_formatted_text('')
confirmation: str = await prompt_session.prompt_async(
HTML('<gold>Proceed with action? (y)es/(n)o > </gold>'),
)
prompt_session.rprompt = None
confirmation = confirmation.strip().lower()
confirmation = '' if confirmation is None else confirmation.strip().lower()
return confirmation in ['y', 'yes']
except (KeyboardInterrupt, EOFError):
return False
async def process_agent_pause(done: asyncio.Event) -> None:
    """Watch the terminal for Ctrl-P and set ``done`` when it is pressed.

    The coroutine creates a prompt_toolkit input, enters its raw mode,
    attaches a key-ready callback, and then waits until ``done`` is set.
    Both context managers are exited once the wait finishes or the task is
    cancelled.

    Args:
        done: Event set by the callback when Ctrl-P is seen. Setting it from
            elsewhere also ends the wait.
    """
    # Named ``key_input`` so the builtin ``input`` is not shadowed.
    key_input = create_input()

    def keys_ready():
        # Called whenever keys are available; look through the pending batch.
        for key_press in key_input.read_keys():
            if key_press.key == Keys.ControlP:
                print_formatted_text('')
                print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
                done.set()
                # One pause request is enough: stop here so repeated Ctrl-P
                # presses in the same batch do not print the notice again.
                break

    with key_input.raw_mode(), key_input.attach(keys_ready):
        await done.wait()
def cli_confirm(
question: str = 'Are you sure?', choices: list[str] | None = None
) -> int:

View File

@ -8,6 +8,7 @@ from openhands.core.cli_commands import (
handle_help_command,
handle_init_command,
handle_new_command,
handle_resume_command,
handle_settings_command,
handle_status_command,
)
@ -461,3 +462,27 @@ class TestHandleSettingsCommand:
# Verify correct behavior
mock_display_settings.assert_called_once_with(config)
mock_cli_confirm.assert_called_once()
class TestHandleResumeCommand:
    @pytest.mark.asyncio
    async def test_handle_resume_command(self):
        """handle_resume_command posts a user 'continue' message and closes the REPL."""
        stream = MagicMock(spec=EventStream)

        close_repl, new_session_requested = await handle_resume_command(stream)

        # Exactly one event must have been pushed, as positional (action, source).
        stream.add_event.assert_called_once()
        positional, _keyword = stream.add_event.call_args
        posted_action, posted_source = positional
        assert isinstance(posted_action, MessageAction)
        assert posted_action.content == 'continue'
        assert posted_source == EventSource.USER

        # The REPL closes and no new session is requested.
        assert close_repl is True
        assert new_session_requested is False

View File

@ -0,0 +1,325 @@
import asyncio
from unittest.mock import MagicMock, call, patch
import pytest
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.keys import Keys
from openhands.core.cli_tui import process_agent_pause
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import ChangeAgentStateAction
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.stream import EventStream
class TestProcessAgentPause:
    @pytest.mark.asyncio
    @patch('openhands.core.cli_tui.create_input')
    @patch('openhands.core.cli_tui.print_formatted_text')
    async def test_process_agent_pause_ctrl_p(self, mock_print, mock_create_input):
        """process_agent_pause sets the done event and prints a notice on Ctrl+P."""
        done = asyncio.Event()

        # MagicMock already implements the context-manager protocol, and its
        # default __exit__ returns False. Assigning a bare MagicMock() to
        # __exit__ would return a truthy mock and make the `with` blocks
        # swallow every exception, CancelledError included.
        mock_input = MagicMock()
        mock_create_input.return_value = mock_input
        mock_attach = MagicMock()

        # Capture the callback that process_agent_pause registers via attach().
        keys_ready_func = None

        def fake_attach(callback):
            nonlocal keys_ready_func
            keys_ready_func = callback
            return mock_attach

        mock_input.attach.side_effect = fake_attach

        task = asyncio.create_task(process_agent_pause(done))

        # One yield to the loop is enough: the task runs until it suspends on
        # `done.wait()`, by which point attach() has been called.
        await asyncio.sleep(0)
        assert keys_ready_func is not None

        # Simulate a pending Ctrl+P key press and fire the callback by hand.
        key_press = MagicMock()
        key_press.key = Keys.ControlP
        mock_input.read_keys.return_value = [key_press]
        keys_ready_func()

        assert done.is_set()

        # A blank line followed by the HTML pause notice must have been printed.
        assert mock_print.call_count == 2
        assert mock_print.call_args_list[0] == call('')
        second_call = mock_print.call_args_list[1][0][0]
        assert isinstance(second_call, HTML)
        assert 'Pausing the agent' in str(second_call)

        # Clean up the background task.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
class TestCliPauseResumeInRunSession:
@pytest.mark.asyncio
async def test_on_event_async_pause_processing(self):
"""Test that on_event_async processes the pause event when is_paused is set."""
# Create a mock event
event = MagicMock()
# Create mock dependencies
event_stream = MagicMock()
is_paused = asyncio.Event()
reload_microagents = False
config = MagicMock()
# Patch the display_event function
with patch('openhands.core.cli.display_event') as mock_display_event, patch(
'openhands.core.cli.update_usage_metrics'
) as mock_update_metrics:
# Create a closure to capture the current context
async def test_func():
# Set the pause event
is_paused.set()
# Create a context similar to run_session to call on_event_async
# We're creating a function that mimics the environment of on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, is_paused
mock_display_event(event, config)
mock_update_metrics(event, usage_metrics=MagicMock())
# Pause the agent if the pause event is set (through Ctrl-P)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
# Call our test function
await on_event_async_test(event)
# Check that the event_stream.add_event was called with the correct action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
action, source = args
assert isinstance(action, ChangeAgentStateAction)
assert action.agent_state == AgentState.PAUSED
assert source == EventSource.USER
# Check that is_paused was cleared
assert not is_paused.is_set()
# Run the test function
await test_func()
class TestCliCommandsPauseResume:
    @pytest.mark.asyncio
    @patch('openhands.core.cli_commands.handle_resume_command')
    async def test_handle_commands_resume(self, mock_handle_resume):
        """handle_commands routes '/resume' to handle_resume_command."""
        # Imported here so the patch above is active when the name is resolved.
        from openhands.core.cli_commands import handle_commands

        # The patched handler reports: keep the REPL open, no new session.
        mock_handle_resume.return_value = (False, False)

        event_stream = MagicMock(spec=EventStream)
        outcome = await handle_commands(
            '/resume',
            event_stream,
            MagicMock(),  # usage_metrics
            'test-session-id',  # sid
            MagicMock(),  # config
            '/test/dir',  # current_dir
            MagicMock(),  # settings_store
        )
        close_repl, reload_microagents, new_session_requested = outcome

        # Only the event stream is forwarded to the resume handler.
        mock_handle_resume.assert_called_once_with(event_stream)

        # The handler's result is passed through; microagents are not reloaded.
        assert close_repl is False
        assert reload_microagents is False
        assert new_session_requested is False
class TestAgentStatePauseResume:
@pytest.mark.asyncio
@patch('openhands.core.cli.display_agent_running_message')
@patch('openhands.core.cli.process_agent_pause')
async def test_agent_running_enables_pause(
self, mock_process_agent_pause, mock_display_message
):
"""Test that when the agent is running, pause functionality is enabled."""
# Create mock dependencies
event = MagicMock()
# AgentStateChangedObservation requires a content parameter
event.observation = AgentStateChangedObservation(
agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
)
# Create a context similar to run_session to call on_event_async
loop = MagicMock()
is_paused = asyncio.Event()
config = MagicMock()
config.security.confirmation_mode = False
# Create a closure to capture the current context
async def test_func():
# Call our simplified on_event_async
async def on_event_async_test(event):
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
mock_display_message()
loop.create_task(mock_process_agent_pause(is_paused))
# Call the function
await on_event_async_test(event)
# Check that the message was displayed
mock_display_message.assert_called_once()
# Check that process_agent_pause was called with the right arguments
mock_process_agent_pause.assert_called_once_with(is_paused)
# Check that loop.create_task was called
loop.create_task.assert_called_once()
# Run the test function
await test_func()
@pytest.mark.asyncio
@patch('openhands.core.cli.display_event')
@patch('openhands.core.cli.update_usage_metrics')
async def test_pause_event_changes_agent_state(
self, mock_update_metrics, mock_display_event
):
"""Test that when is_paused is set, a PAUSED state change event is added to the stream."""
# Create mock dependencies
event = MagicMock()
event_stream = MagicMock()
is_paused = asyncio.Event()
config = MagicMock()
reload_microagents = False
# Set the pause event
is_paused.set()
# Create a closure to capture the current context
async def test_func():
# Create a context similar to run_session to call on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents
mock_display_event(event, config)
mock_update_metrics(event, MagicMock())
# Pause the agent if the pause event is set (through Ctrl-P)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
# Call the function
await on_event_async_test(event)
# Check that the event_stream.add_event was called with the correct action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
action, source = args
assert isinstance(action, ChangeAgentStateAction)
assert action.agent_state == AgentState.PAUSED
assert source == EventSource.USER
# Check that is_paused was cleared
assert not is_paused.is_set()
# Run the test
await test_func()
@pytest.mark.asyncio
async def test_paused_agent_awaits_input(self):
"""Test that when the agent is paused, it awaits user input."""
# Create mock dependencies
event = MagicMock()
# AgentStateChangedObservation requires a content parameter
event.observation = AgentStateChangedObservation(
agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
)
reload_microagents = False
memory = MagicMock()
runtime = MagicMock()
prompt_task = MagicMock()
# Create a closure to capture the current context
async def test_func():
# Create a simplified version of on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, prompt_task
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# Reload microagents after initialization of repo.md
if reload_microagents:
microagents = runtime.get_microagents_from_selected_repo(
None
)
memory.load_user_workspace_microagents(microagents)
reload_microagents = False
# Since prompt_for_next_task is a nested function in cli.py,
# we'll just check that we've reached this code path
prompt_task = 'Prompt for next task would be called here'
# Call the function
await on_event_async_test(event)
# Check that we reached the code path where prompt_for_next_task would be called
assert prompt_task == 'Prompt for next task would be called here'
# Run the test
await test_func()

View File

@ -1,4 +1,3 @@
import asyncio
from unittest.mock import MagicMock, Mock, patch
from openhands.core.cli_tui import (
@ -52,12 +51,9 @@ class TestDisplayFunctions:
@patch('openhands.core.cli_tui.print_formatted_text')
def test_display_banner(self, mock_print):
# Create a mock loaded event
is_loaded = asyncio.Event()
is_loaded.set()
session_id = 'test-session-id'
display_banner(session_id, is_loaded)
display_banner(session_id)
# Verify banner calls
assert mock_print.call_count >= 3