refactor: Refactor pause/resume functionality and improve state handling in CLI (#8152)

This commit is contained in:
Bashwara Undupitiya
2025-05-02 03:04:35 -07:00
committed by GitHub
parent 03aa5d7456
commit 6e0fbfeeda
3 changed files with 176 additions and 86 deletions

View File

@@ -101,7 +101,7 @@ async def run_session(
sid = str(uuid4())
is_loaded = asyncio.Event()
is_paused = asyncio.Event()
is_paused = asyncio.Event() # Event to track agent pause requests
# Show runtime initialization message
display_runtime_initialization_message(config.runtime)
@@ -157,20 +157,15 @@ async def run_session(
display_event(event, config)
update_usage_metrics(event, usage_metrics)
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# If the agent is paused, do not prompt for input as it's already handled by PAUSED state change
if is_paused.is_set():
return
# Reload microagents after initialization of repo.md
if reload_microagents:
microagents: list[BaseMicroagent] = (
@@ -181,25 +176,32 @@ async def run_session(
await prompt_for_next_task(event.agent_state)
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
# Only display the confirmation prompt if the agent is not paused
if not is_paused.is_set():
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
# If the agent is paused, do not prompt for confirmation
# The confirmation step will re-run after the agent has been resumed
if is_paused.is_set():
return
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
if event.agent_state == AgentState.PAUSED:
is_paused.clear() # Revert the event state before prompting for user input
await prompt_for_next_task(event.agent_state)
if event.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
display_agent_running_message()
loop.create_task(process_agent_pause(is_paused))
display_agent_running_message()
loop.create_task(
process_agent_pause(is_paused, event_stream)
) # Create a task to track agent pause requests from the user
def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))

View File

@@ -25,10 +25,11 @@ from prompt_toolkit.widgets import Frame, TextArea
from openhands import __version__
from openhands.core.config import AppConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events import EventSource, EventStream
from openhands.events.action import (
Action,
ActionConfirmationStatus,
ChangeAgentStateAction,
CmdRunAction,
FileEditAction,
MessageAction,
@@ -60,7 +61,7 @@ COMMANDS = {
'/status': 'Display session details and usage metrics',
'/new': 'Create a new session',
'/settings': 'Display and modify current settings',
'/resume': 'Resume the agent',
'/resume': 'Resume the agent when paused',
}
@@ -396,7 +397,7 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
def display_agent_running_message():
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>')
HTML('<gold>Agent running...</gold> <grey>(Press Ctrl-P to pause)</grey>')
)
@@ -405,7 +406,7 @@ def display_agent_paused_message(agent_state: str):
return
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>')
HTML('<gold>Agent paused...</gold> <grey>(Enter /resume to continue)</grey>')
)
@@ -430,7 +431,7 @@ class CommandCompleter(Completer):
command,
start_position=-len(text),
display_meta=description,
style='bg:ansidarkgray fg:ansiwhite',
style='bg:ansidarkgray fg:gold',
)
@@ -488,7 +489,7 @@ async def read_confirmation_input() -> bool:
return False
async def process_agent_pause(done: asyncio.Event) -> None:
async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None:
input = create_input()
def keys_ready():
@@ -496,6 +497,10 @@ async def process_agent_pause(done: asyncio.Event) -> None:
if key_press.key == Keys.ControlP:
print_formatted_text('')
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
done.set()
with input.raw_mode():

View File

@@ -10,7 +10,6 @@ from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import ChangeAgentStateAction
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.stream import EventStream
class TestProcessAgentPause:
@@ -48,7 +47,7 @@ class TestProcessAgentPause:
mock_input.attach.side_effect = fake_attach
# Create a task to run process_agent_pause
task = asyncio.create_task(process_agent_pause(done))
task = asyncio.create_task(process_agent_pause(done, event_stream=MagicMock()))
# Give it a moment to start and capture the callback
await asyncio.sleep(0.1)
@@ -119,12 +118,13 @@ class TestCliPauseResumeInRunSession:
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
# The pause event is not cleared here because we want to simulate
# the PAUSED event processing in a future event
# Call our test function
# Call on_event_async_test
await on_event_async_test(event)
# Check that the event_stream.add_event was called with the correct action
# Check that event_stream.add_event was called with the correct action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
action, source = args
@@ -133,35 +133,127 @@ class TestCliPauseResumeInRunSession:
assert action.agent_state == AgentState.PAUSED
assert source == EventSource.USER
# Check that is_paused was cleared
assert not is_paused.is_set()
# Check that is_paused is still set (will be cleared when PAUSED state is processed)
assert is_paused.is_set()
# Run the test function
await test_func()
@pytest.mark.asyncio
async def test_awaiting_user_input_paused_skip(self):
"""Test that when is_paused is set, awaiting user input events do not trigger prompting."""
# Create a mock event with AgentStateChangedObservation
event = MagicMock()
event.observation = AgentStateChangedObservation(
agent_state=AgentState.AWAITING_USER_INPUT, content='Agent awaiting input'
)
# Create mock dependencies
is_paused = asyncio.Event()
reload_microagents = False
# Mock function that would be called if code reaches that point
mock_prompt_task = MagicMock()
# Create a closure to capture the current context
async def test_func():
# Set the pause event
is_paused.set()
# Create a context similar to run_session to call on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, is_paused
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
]:
# If the agent is paused, do not prompt for input
if is_paused.is_set():
return
# This code should not be reached if is_paused is set
mock_prompt_task()
# Call on_event_async_test
await on_event_async_test(event)
# Verify that mock_prompt_task was not called
mock_prompt_task.assert_not_called()
# Run the test
await test_func()
@pytest.mark.asyncio
async def test_awaiting_confirmation_paused_skip(self):
"""Test that when is_paused is set, awaiting confirmation events do not trigger prompting."""
# Create a mock event with AgentStateChangedObservation
event = MagicMock()
event.observation = AgentStateChangedObservation(
agent_state=AgentState.AWAITING_USER_CONFIRMATION,
content='Agent awaiting confirmation',
)
# Create mock dependencies
is_paused = asyncio.Event()
# Mock function that would be called if code reaches that point
mock_confirmation = MagicMock()
# Create a closure to capture the current context
async def test_func():
# Set the pause event
is_paused.set()
# Create a context similar to run_session to call on_event_async
async def on_event_async_test(event):
nonlocal is_paused
if isinstance(event.observation, AgentStateChangedObservation):
if (
event.observation.agent_state
== AgentState.AWAITING_USER_CONFIRMATION
):
if is_paused.is_set():
return
# This code should not be reached if is_paused is set
mock_confirmation()
# Call on_event_async_test
await on_event_async_test(event)
# Verify that confirmation function was not called
mock_confirmation.assert_not_called()
# Run the test
await test_func()
class TestCliCommandsPauseResume:
@pytest.mark.asyncio
@patch('openhands.core.cli_commands.handle_resume_command')
async def test_handle_commands_resume(self, mock_handle_resume):
"""Test that the handle_commands function properly calls handle_resume_command."""
# Import the handle_commands function
# Import here to avoid circular imports in test
from openhands.core.cli_commands import handle_commands
# Set up mocks
event_stream = MagicMock(spec=EventStream)
# Create mocks
message = '/resume'
event_stream = MagicMock()
usage_metrics = MagicMock()
sid = 'test-session-id'
config = MagicMock()
current_dir = '/test/dir'
settings_store = MagicMock()
# Set the return value for handle_resume_command
# Mock return value
mock_handle_resume.return_value = (False, False)
# Call handle_commands with the resume command
# Call handle_commands
close_repl, reload_microagents, new_session_requested = await handle_commands(
'/resume',
message,
event_stream,
usage_metrics,
sid,
@@ -170,7 +262,7 @@ class TestCliCommandsPauseResume:
settings_store,
)
# Check that handle_resume_command was called with the correct arguments
# Check that handle_resume_command was called with correct args
mock_handle_resume.assert_called_once_with(event_stream)
# Check the return values
@@ -187,39 +279,37 @@ class TestAgentStatePauseResume:
self, mock_process_agent_pause, mock_display_message
):
"""Test that when the agent is running, pause functionality is enabled."""
# Create mock dependencies
# Create a mock event and event stream
event = MagicMock()
# AgentStateChangedObservation requires a content parameter
event.observation = AgentStateChangedObservation(
agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
agent_state=AgentState.RUNNING, content='Agent is running'
)
event_stream = MagicMock()
# Create a context similar to run_session to call on_event_async
loop = MagicMock()
# Create mock dependencies
is_paused = asyncio.Event()
config = MagicMock()
config.security.confirmation_mode = False
loop = MagicMock()
reload_microagents = False
# Create a closure to capture the current context
async def test_func():
# Call our simplified on_event_async
# Create a context similar to run_session to call on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
mock_display_message()
loop.create_task(mock_process_agent_pause(is_paused))
mock_display_message()
loop.create_task(
mock_process_agent_pause(is_paused, event_stream)
)
# Call the function
# Call on_event_async_test
await on_event_async_test(event)
# Check that the message was displayed
# Check that display_agent_running_message was called
mock_display_message.assert_called_once()
# Check that process_agent_pause was called with the right arguments
mock_process_agent_pause.assert_called_once_with(is_paused)
# Check that loop.create_task was called
loop.create_task.assert_called_once()
@@ -286,40 +376,33 @@ class TestAgentStatePauseResume:
event.observation = AgentStateChangedObservation(
agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
)
reload_microagents = False
memory = MagicMock()
runtime = MagicMock()
prompt_task = MagicMock()
is_paused = asyncio.Event()
# Mock function that would be called for prompting
mock_prompt_task = MagicMock()
# Create a closure to capture the current context
async def test_func():
# Create a simplified version of on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, prompt_task
nonlocal is_paused
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# Reload microagents after initialization of repo.md
if reload_microagents:
microagents = runtime.get_microagents_from_selected_repo(
None
)
memory.load_user_workspace_microagents(microagents)
reload_microagents = False
if event.observation.agent_state == AgentState.PAUSED:
is_paused.clear() # Revert the event state before prompting for user input
mock_prompt_task(event.observation.agent_state)
# Since prompt_for_next_task is a nested function in cli.py,
# we'll just check that we've reached this code path
prompt_task = 'Prompt for next task would be called here'
# Set is_paused to test that it gets cleared
is_paused.set()
# Call the function
await on_event_async_test(event)
# Check that we reached the code path where prompt_for_next_task would be called
assert prompt_task == 'Prompt for next task would be called here'
# Check that is_paused was cleared
assert not is_paused.is_set()
# Check that prompt task was called with the correct state
mock_prompt_task.assert_called_once_with(AgentState.PAUSED)
# Run the test
await test_func()