mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
feat: Add CLI support for agent pause and resume (#8129)
Co-authored-by: Bashwara Undupitiya <bashwarau@verdentra.com>
This commit is contained in:
parent
06ce12eff4
commit
998de564cd
@ -14,12 +14,14 @@ from openhands.core.cli_commands import (
|
||||
)
|
||||
from openhands.core.cli_tui import (
|
||||
UsageMetrics,
|
||||
display_agent_running_message,
|
||||
display_banner,
|
||||
display_event,
|
||||
display_initial_user_prompt,
|
||||
display_initialization_animation,
|
||||
display_runtime_initialization_message,
|
||||
display_welcome_message,
|
||||
process_agent_pause,
|
||||
read_confirmation_input,
|
||||
read_prompt_input,
|
||||
)
|
||||
@ -99,6 +101,7 @@ async def run_session(
|
||||
|
||||
sid = str(uuid4())
|
||||
is_loaded = asyncio.Event()
|
||||
is_paused = asyncio.Event()
|
||||
|
||||
# Show runtime initialization message
|
||||
display_runtime_initialization_message(config.runtime)
|
||||
@ -124,10 +127,12 @@ async def run_session(
|
||||
|
||||
usage_metrics = UsageMetrics()
|
||||
|
||||
async def prompt_for_next_task():
|
||||
async def prompt_for_next_task(agent_state: str):
|
||||
nonlocal reload_microagents, new_session_requested
|
||||
while True:
|
||||
next_message = await read_prompt_input(config.cli_multiline_input)
|
||||
next_message = await read_prompt_input(
|
||||
agent_state, multiline=config.cli_multiline_input
|
||||
)
|
||||
|
||||
if not next_message.strip():
|
||||
continue
|
||||
@ -150,14 +155,23 @@ async def run_session(
|
||||
return
|
||||
|
||||
async def on_event_async(event: Event) -> None:
|
||||
nonlocal reload_microagents
|
||||
nonlocal reload_microagents, is_paused
|
||||
display_event(event, config)
|
||||
update_usage_metrics(event, usage_metrics)
|
||||
|
||||
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
|
||||
if is_paused.is_set():
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.PAUSED,
|
||||
]:
|
||||
# Reload microagents after initialization of repo.md
|
||||
if reload_microagents:
|
||||
@ -166,20 +180,28 @@ async def run_session(
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
reload_microagents = False
|
||||
await prompt_for_next_task()
|
||||
await prompt_for_next_task(event.agent_state)
|
||||
|
||||
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
# Only display the confirmation prompt if the agent is not paused
|
||||
if not is_paused.is_set():
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
if event.agent_state == AgentState.RUNNING:
|
||||
# Enable pause/resume functionality only if the confirmation mode is disabled
|
||||
if not config.security.confirmation_mode:
|
||||
display_agent_running_message()
|
||||
loop.create_task(process_agent_pause(is_paused))
|
||||
|
||||
def on_event(event: Event) -> None:
|
||||
loop.create_task(on_event_async(event))
|
||||
@ -212,7 +234,7 @@ async def run_session(
|
||||
clear()
|
||||
|
||||
# Show OpenHands banner and session ID
|
||||
display_banner(session_id=sid, is_loaded=is_loaded)
|
||||
display_banner(session_id=sid)
|
||||
|
||||
# Show OpenHands welcome
|
||||
display_welcome_message()
|
||||
@ -225,7 +247,7 @@ async def run_session(
|
||||
)
|
||||
else:
|
||||
# Otherwise prompt for the user's first message right away
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
asyncio.create_task(prompt_for_next_task(''))
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||
|
||||
@ -70,6 +70,8 @@ async def handle_commands(
|
||||
)
|
||||
elif command == '/settings':
|
||||
await handle_settings_command(config, settings_store)
|
||||
elif command == '/resume':
|
||||
close_repl, new_session_requested = await handle_resume_command(event_stream)
|
||||
else:
|
||||
close_repl = True
|
||||
action = MessageAction(content=command)
|
||||
@ -183,6 +185,28 @@ async def handle_settings_command(
|
||||
await modify_llm_settings_advanced(config, settings_store)
|
||||
|
||||
|
||||
# FIXME: Currently there's an issue with the actual 'resume' behavior.
|
||||
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
|
||||
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
|
||||
async def handle_resume_command(
|
||||
event_stream: EventStream,
|
||||
) -> tuple[bool, bool]:
|
||||
close_repl = True
|
||||
new_session_requested = False
|
||||
|
||||
event_stream.add_event(
|
||||
MessageAction(content='continue'),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
# event_stream.add_event(
|
||||
# ChangeAgentStateAction(AgentState.RUNNING),
|
||||
# EventSource.ENVIRONMENT,
|
||||
# )
|
||||
|
||||
return close_repl, new_session_requested
|
||||
|
||||
|
||||
async def init_repository(current_dir: str) -> bool:
|
||||
repo_file_path = Path(current_dir) / '.openhands' / 'microagents' / 'repo.md'
|
||||
init_repo = False
|
||||
|
||||
@ -10,7 +10,9 @@ from prompt_toolkit import PromptSession, print_formatted_text
|
||||
from prompt_toolkit.application import Application
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
|
||||
from prompt_toolkit.input import create_input
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.keys import Keys
|
||||
from prompt_toolkit.layout.containers import HSplit, Window
|
||||
from prompt_toolkit.layout.controls import FormattedTextControl
|
||||
from prompt_toolkit.layout.layout import Layout
|
||||
@ -22,6 +24,7 @@ from prompt_toolkit.widgets import Frame, TextArea
|
||||
|
||||
from openhands import __version__
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
@ -32,6 +35,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
@ -56,6 +60,7 @@ COMMANDS = {
|
||||
'/status': 'Display session details and usage metrics',
|
||||
'/new': 'Create a new session',
|
||||
'/settings': 'Display and modify current settings',
|
||||
'/resume': 'Resume the agent',
|
||||
}
|
||||
|
||||
|
||||
@ -114,7 +119,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def display_banner(session_id: str, is_loaded: asyncio.Event):
|
||||
def display_banner(session_id: str):
|
||||
print_formatted_text(
|
||||
HTML(r"""<gold>
|
||||
___ _ _ _
|
||||
@ -129,11 +134,8 @@ def display_banner(session_id: str, is_loaded: asyncio.Event):
|
||||
|
||||
print_formatted_text(HTML(f'<grey>OpenHands CLI v{__version__}</grey>'))
|
||||
|
||||
banner_text = (
|
||||
'Initialized session' if is_loaded.is_set() else 'Initializing session'
|
||||
)
|
||||
print_formatted_text('')
|
||||
print_formatted_text(HTML(f'<grey>{banner_text} {session_id}</grey>'))
|
||||
print_formatted_text(HTML(f'<grey>Initialized session {session_id}</grey>'))
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
@ -177,6 +179,8 @@ def display_event(event: Event, config: AppConfig) -> None:
|
||||
display_file_edit(event)
|
||||
if isinstance(event, FileReadObservation):
|
||||
display_file_read(event)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
display_agent_paused_message(event.agent_state)
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
@ -389,77 +393,58 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
|
||||
display_usage_metrics(usage_metrics)
|
||||
|
||||
|
||||
def display_agent_running_message():
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>')
|
||||
)
|
||||
|
||||
|
||||
def display_agent_paused_message(agent_state: str):
|
||||
if agent_state != AgentState.PAUSED:
|
||||
return
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>')
|
||||
)
|
||||
|
||||
|
||||
# Common input functions
|
||||
class CommandCompleter(Completer):
|
||||
"""Custom completer for commands."""
|
||||
|
||||
def __init__(self, agent_state: str):
|
||||
super().__init__()
|
||||
self.agent_state = agent_state
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text
|
||||
|
||||
# Only show completions if the user has typed '/'
|
||||
text = document.text_before_cursor.lstrip()
|
||||
if text.startswith('/'):
|
||||
# If just '/' is typed, show all commands
|
||||
if text == '/':
|
||||
for command, description in COMMANDS.items():
|
||||
available_commands = dict(COMMANDS)
|
||||
if self.agent_state != AgentState.PAUSED:
|
||||
available_commands.pop('/resume', None)
|
||||
|
||||
for command, description in available_commands.items():
|
||||
if command.startswith(text):
|
||||
yield Completion(
|
||||
command[1:], # Remove the leading '/' as it's already typed
|
||||
start_position=0,
|
||||
display=f'{command} - {description}',
|
||||
command,
|
||||
start_position=-len(text),
|
||||
display_meta=description,
|
||||
style='bg:ansidarkgray fg:ansiwhite',
|
||||
)
|
||||
# Otherwise show matching commands
|
||||
else:
|
||||
for command, description in COMMANDS.items():
|
||||
if command.startswith(text):
|
||||
yield Completion(
|
||||
command[len(text) :], # Complete the remaining part
|
||||
start_position=0,
|
||||
display=f'{command} - {description}',
|
||||
)
|
||||
|
||||
|
||||
prompt_session = PromptSession(style=DEFAULT_STYLE)
|
||||
|
||||
# RPrompt animation related variables
|
||||
SPINNER_FRAMES = [
|
||||
'[ ■□□□ ]',
|
||||
'[ □■□□ ]',
|
||||
'[ □□■□ ]',
|
||||
'[ □□□■ ]',
|
||||
'[ □□■□ ]',
|
||||
'[ □■□□ ]',
|
||||
]
|
||||
ANIMATION_INTERVAL = 0.2 # seconds
|
||||
|
||||
current_frame_index = 0
|
||||
last_update_time = time.monotonic()
|
||||
def create_prompt_session():
|
||||
return PromptSession(style=DEFAULT_STYLE)
|
||||
|
||||
|
||||
# RPrompt function for the user confirmation
|
||||
def get_rprompt() -> FormattedText:
|
||||
"""
|
||||
Returns the current animation frame for the rprompt.
|
||||
This function is called by prompt_toolkit during rendering.
|
||||
"""
|
||||
global current_frame_index, last_update_time
|
||||
|
||||
# Only update the frame if enough time has passed
|
||||
# This prevents excessive recalculation during rendering
|
||||
now = time.monotonic()
|
||||
if now - last_update_time > ANIMATION_INTERVAL:
|
||||
current_frame_index = (current_frame_index + 1) % len(SPINNER_FRAMES)
|
||||
last_update_time = now
|
||||
|
||||
# Return the frame wrapped in FormattedText
|
||||
return FormattedText(
|
||||
[
|
||||
('', ' '), # Add a space before the spinner
|
||||
(COLOR_GOLD, SPINNER_FRAMES[current_frame_index]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def read_prompt_input(multiline=False):
|
||||
async def read_prompt_input(agent_state: str, multiline=False):
|
||||
try:
|
||||
prompt_session = create_prompt_session()
|
||||
prompt_session.completer = (
|
||||
CommandCompleter(agent_state) if not multiline else None
|
||||
)
|
||||
|
||||
if multiline:
|
||||
kb = KeyBindings()
|
||||
|
||||
@ -470,38 +455,54 @@ async def read_prompt_input(multiline=False):
|
||||
with patch_stdout():
|
||||
print_formatted_text('')
|
||||
message = await prompt_session.prompt_async(
|
||||
'Enter your message and press Ctrl+D to finish:\n',
|
||||
HTML(
|
||||
'<gold>Enter your message and press Ctrl-D to finish:</gold>\n'
|
||||
),
|
||||
multiline=True,
|
||||
key_bindings=kb,
|
||||
)
|
||||
else:
|
||||
with patch_stdout():
|
||||
print_formatted_text('')
|
||||
prompt_session.completer = CommandCompleter()
|
||||
message = await prompt_session.prompt_async(
|
||||
'> ',
|
||||
HTML('<gold>> </gold>'),
|
||||
)
|
||||
return message
|
||||
return message if message is not None else ''
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return '/exit'
|
||||
|
||||
|
||||
async def read_confirmation_input():
|
||||
async def read_confirmation_input() -> bool:
|
||||
try:
|
||||
prompt_session = create_prompt_session()
|
||||
|
||||
with patch_stdout():
|
||||
prompt_session.completer = None
|
||||
confirmation = await prompt_session.prompt_async(
|
||||
'Proceed with action? (y)es/(n)o > ',
|
||||
rprompt=get_rprompt,
|
||||
refresh_interval=ANIMATION_INTERVAL / 2,
|
||||
print_formatted_text('')
|
||||
confirmation: str = await prompt_session.prompt_async(
|
||||
HTML('<gold>Proceed with action? (y)es/(n)o > </gold>'),
|
||||
)
|
||||
prompt_session.rprompt = None
|
||||
confirmation = confirmation.strip().lower()
|
||||
|
||||
confirmation = '' if confirmation is None else confirmation.strip().lower()
|
||||
return confirmation in ['y', 'yes']
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
async def process_agent_pause(done: asyncio.Event) -> None:
|
||||
input = create_input()
|
||||
|
||||
def keys_ready():
|
||||
for key_press in input.read_keys():
|
||||
if key_press.key == Keys.ControlP:
|
||||
print_formatted_text('')
|
||||
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
|
||||
done.set()
|
||||
|
||||
with input.raw_mode():
|
||||
with input.attach(keys_ready):
|
||||
await done.wait()
|
||||
|
||||
|
||||
def cli_confirm(
|
||||
question: str = 'Are you sure?', choices: list[str] | None = None
|
||||
) -> int:
|
||||
|
||||
@ -8,6 +8,7 @@ from openhands.core.cli_commands import (
|
||||
handle_help_command,
|
||||
handle_init_command,
|
||||
handle_new_command,
|
||||
handle_resume_command,
|
||||
handle_settings_command,
|
||||
handle_status_command,
|
||||
)
|
||||
@ -461,3 +462,27 @@ class TestHandleSettingsCommand:
|
||||
# Verify correct behavior
|
||||
mock_display_settings.assert_called_once_with(config)
|
||||
mock_cli_confirm.assert_called_once()
|
||||
|
||||
|
||||
class TestHandleResumeCommand:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_resume_command(self):
|
||||
"""Test that handle_resume_command adds the 'continue' message to the event stream."""
|
||||
# Create a mock event stream
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
|
||||
# Call the function
|
||||
close_repl, new_session_requested = await handle_resume_command(event_stream)
|
||||
|
||||
# Check that the event stream add_event was called with the correct message action
|
||||
event_stream.add_event.assert_called_once()
|
||||
args, kwargs = event_stream.add_event.call_args
|
||||
message_action, source = args
|
||||
|
||||
assert isinstance(message_action, MessageAction)
|
||||
assert message_action.content == 'continue'
|
||||
assert source == EventSource.USER
|
||||
|
||||
# Check the return values
|
||||
assert close_repl is True
|
||||
assert new_session_requested is False
|
||||
|
||||
325
tests/unit/test_cli_pause_resume.py
Normal file
325
tests/unit/test_cli_pause_resume.py
Normal file
@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.keys import Keys
|
||||
|
||||
from openhands.core.cli_tui import process_agent_pause
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
|
||||
class TestProcessAgentPause:
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.core.cli_tui.create_input')
|
||||
@patch('openhands.core.cli_tui.print_formatted_text')
|
||||
async def test_process_agent_pause_ctrl_p(self, mock_print, mock_create_input):
|
||||
"""Test that process_agent_pause sets the done event when Ctrl+P is pressed."""
|
||||
# Create the done event
|
||||
done = asyncio.Event()
|
||||
|
||||
# Set up the mock input
|
||||
mock_input = MagicMock()
|
||||
mock_create_input.return_value = mock_input
|
||||
|
||||
# Mock the context managers
|
||||
mock_raw_mode = MagicMock()
|
||||
mock_input.raw_mode.return_value = mock_raw_mode
|
||||
mock_raw_mode.__enter__ = MagicMock()
|
||||
mock_raw_mode.__exit__ = MagicMock()
|
||||
|
||||
mock_attach = MagicMock()
|
||||
mock_input.attach.return_value = mock_attach
|
||||
mock_attach.__enter__ = MagicMock()
|
||||
mock_attach.__exit__ = MagicMock()
|
||||
|
||||
# Capture the keys_ready function
|
||||
keys_ready_func = None
|
||||
|
||||
def fake_attach(callback):
|
||||
nonlocal keys_ready_func
|
||||
keys_ready_func = callback
|
||||
return mock_attach
|
||||
|
||||
mock_input.attach.side_effect = fake_attach
|
||||
|
||||
# Create a task to run process_agent_pause
|
||||
task = asyncio.create_task(process_agent_pause(done))
|
||||
|
||||
# Give it a moment to start and capture the callback
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Make sure we captured the callback
|
||||
assert keys_ready_func is not None
|
||||
|
||||
# Create a key press that simulates Ctrl+P
|
||||
key_press = MagicMock()
|
||||
key_press.key = Keys.ControlP
|
||||
mock_input.read_keys.return_value = [key_press]
|
||||
|
||||
# Manually call the callback to simulate key press
|
||||
keys_ready_func()
|
||||
|
||||
# Verify done was set
|
||||
assert done.is_set()
|
||||
|
||||
# Verify print was called with the pause message
|
||||
assert mock_print.call_count == 2
|
||||
assert mock_print.call_args_list[0] == call('')
|
||||
|
||||
# Check that the second call contains the pause message HTML
|
||||
second_call = mock_print.call_args_list[1][0][0]
|
||||
assert isinstance(second_call, HTML)
|
||||
assert 'Pausing the agent' in str(second_call)
|
||||
|
||||
# Cancel the task
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestCliPauseResumeInRunSession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_async_pause_processing(self):
|
||||
"""Test that on_event_async processes the pause event when is_paused is set."""
|
||||
# Create a mock event
|
||||
event = MagicMock()
|
||||
|
||||
# Create mock dependencies
|
||||
event_stream = MagicMock()
|
||||
is_paused = asyncio.Event()
|
||||
reload_microagents = False
|
||||
config = MagicMock()
|
||||
|
||||
# Patch the display_event function
|
||||
with patch('openhands.core.cli.display_event') as mock_display_event, patch(
|
||||
'openhands.core.cli.update_usage_metrics'
|
||||
) as mock_update_metrics:
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Set the pause event
|
||||
is_paused.set()
|
||||
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
# We're creating a function that mimics the environment of on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents, is_paused
|
||||
mock_display_event(event, config)
|
||||
mock_update_metrics(event, usage_metrics=MagicMock())
|
||||
|
||||
# Pause the agent if the pause event is set (through Ctrl-P)
|
||||
if is_paused.is_set():
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
|
||||
# Call our test function
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that the event_stream.add_event was called with the correct action
|
||||
event_stream.add_event.assert_called_once()
|
||||
args, kwargs = event_stream.add_event.call_args
|
||||
action, source = args
|
||||
|
||||
assert isinstance(action, ChangeAgentStateAction)
|
||||
assert action.agent_state == AgentState.PAUSED
|
||||
assert source == EventSource.USER
|
||||
|
||||
# Check that is_paused was cleared
|
||||
assert not is_paused.is_set()
|
||||
|
||||
# Run the test function
|
||||
await test_func()
|
||||
|
||||
|
||||
class TestCliCommandsPauseResume:
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.core.cli_commands.handle_resume_command')
|
||||
async def test_handle_commands_resume(self, mock_handle_resume):
|
||||
"""Test that the handle_commands function properly calls handle_resume_command."""
|
||||
# Import the handle_commands function
|
||||
from openhands.core.cli_commands import handle_commands
|
||||
|
||||
# Set up mocks
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
usage_metrics = MagicMock()
|
||||
sid = 'test-session-id'
|
||||
config = MagicMock()
|
||||
current_dir = '/test/dir'
|
||||
settings_store = MagicMock()
|
||||
|
||||
# Set the return value for handle_resume_command
|
||||
mock_handle_resume.return_value = (False, False)
|
||||
|
||||
# Call handle_commands with the resume command
|
||||
close_repl, reload_microagents, new_session_requested = await handle_commands(
|
||||
'/resume',
|
||||
event_stream,
|
||||
usage_metrics,
|
||||
sid,
|
||||
config,
|
||||
current_dir,
|
||||
settings_store,
|
||||
)
|
||||
|
||||
# Check that handle_resume_command was called with the correct arguments
|
||||
mock_handle_resume.assert_called_once_with(event_stream)
|
||||
|
||||
# Check the return values
|
||||
assert close_repl is False
|
||||
assert reload_microagents is False
|
||||
assert new_session_requested is False
|
||||
|
||||
|
||||
class TestAgentStatePauseResume:
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.core.cli.display_agent_running_message')
|
||||
@patch('openhands.core.cli.process_agent_pause')
|
||||
async def test_agent_running_enables_pause(
|
||||
self, mock_process_agent_pause, mock_display_message
|
||||
):
|
||||
"""Test that when the agent is running, pause functionality is enabled."""
|
||||
# Create mock dependencies
|
||||
event = MagicMock()
|
||||
# AgentStateChangedObservation requires a content parameter
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
|
||||
)
|
||||
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
loop = MagicMock()
|
||||
is_paused = asyncio.Event()
|
||||
config = MagicMock()
|
||||
config.security.confirmation_mode = False
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Call our simplified on_event_async
|
||||
async def on_event_async_test(event):
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if event.observation.agent_state == AgentState.RUNNING:
|
||||
# Enable pause/resume functionality only if the confirmation mode is disabled
|
||||
if not config.security.confirmation_mode:
|
||||
mock_display_message()
|
||||
loop.create_task(mock_process_agent_pause(is_paused))
|
||||
|
||||
# Call the function
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that the message was displayed
|
||||
mock_display_message.assert_called_once()
|
||||
|
||||
# Check that process_agent_pause was called with the right arguments
|
||||
mock_process_agent_pause.assert_called_once_with(is_paused)
|
||||
|
||||
# Check that loop.create_task was called
|
||||
loop.create_task.assert_called_once()
|
||||
|
||||
# Run the test function
|
||||
await test_func()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.core.cli.display_event')
|
||||
@patch('openhands.core.cli.update_usage_metrics')
|
||||
async def test_pause_event_changes_agent_state(
|
||||
self, mock_update_metrics, mock_display_event
|
||||
):
|
||||
"""Test that when is_paused is set, a PAUSED state change event is added to the stream."""
|
||||
# Create mock dependencies
|
||||
event = MagicMock()
|
||||
event_stream = MagicMock()
|
||||
is_paused = asyncio.Event()
|
||||
config = MagicMock()
|
||||
reload_microagents = False
|
||||
|
||||
# Set the pause event
|
||||
is_paused.set()
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents
|
||||
mock_display_event(event, config)
|
||||
mock_update_metrics(event, MagicMock())
|
||||
|
||||
# Pause the agent if the pause event is set (through Ctrl-P)
|
||||
if is_paused.is_set():
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
|
||||
# Call the function
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that the event_stream.add_event was called with the correct action
|
||||
event_stream.add_event.assert_called_once()
|
||||
args, kwargs = event_stream.add_event.call_args
|
||||
action, source = args
|
||||
|
||||
assert isinstance(action, ChangeAgentStateAction)
|
||||
assert action.agent_state == AgentState.PAUSED
|
||||
assert source == EventSource.USER
|
||||
|
||||
# Check that is_paused was cleared
|
||||
assert not is_paused.is_set()
|
||||
|
||||
# Run the test
|
||||
await test_func()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paused_agent_awaits_input(self):
|
||||
"""Test that when the agent is paused, it awaits user input."""
|
||||
# Create mock dependencies
|
||||
event = MagicMock()
|
||||
# AgentStateChangedObservation requires a content parameter
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
|
||||
)
|
||||
reload_microagents = False
|
||||
memory = MagicMock()
|
||||
runtime = MagicMock()
|
||||
prompt_task = MagicMock()
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Create a simplified version of on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents, prompt_task
|
||||
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if event.observation.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.PAUSED,
|
||||
]:
|
||||
# Reload microagents after initialization of repo.md
|
||||
if reload_microagents:
|
||||
microagents = runtime.get_microagents_from_selected_repo(
|
||||
None
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
reload_microagents = False
|
||||
|
||||
# Since prompt_for_next_task is a nested function in cli.py,
|
||||
# we'll just check that we've reached this code path
|
||||
prompt_task = 'Prompt for next task would be called here'
|
||||
|
||||
# Call the function
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that we reached the code path where prompt_for_next_task would be called
|
||||
assert prompt_task == 'Prompt for next task would be called here'
|
||||
|
||||
# Run the test
|
||||
await test_func()
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from openhands.core.cli_tui import (
|
||||
@ -52,12 +51,9 @@ class TestDisplayFunctions:
|
||||
|
||||
@patch('openhands.core.cli_tui.print_formatted_text')
|
||||
def test_display_banner(self, mock_print):
|
||||
# Create a mock loaded event
|
||||
is_loaded = asyncio.Event()
|
||||
is_loaded.set()
|
||||
session_id = 'test-session-id'
|
||||
|
||||
display_banner(session_id, is_loaded)
|
||||
display_banner(session_id)
|
||||
|
||||
# Verify banner calls
|
||||
assert mock_print.call_count >= 3
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user