mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
refactor: Refactor pause/resume functionality and improve state handling in CLI (#8152)
This commit is contained in:
committed by
GitHub
parent
03aa5d7456
commit
6e0fbfeeda
@@ -101,7 +101,7 @@ async def run_session(
|
||||
|
||||
sid = str(uuid4())
|
||||
is_loaded = asyncio.Event()
|
||||
is_paused = asyncio.Event()
|
||||
is_paused = asyncio.Event() # Event to track agent pause requests
|
||||
|
||||
# Show runtime initialization message
|
||||
display_runtime_initialization_message(config.runtime)
|
||||
@@ -157,20 +157,15 @@ async def run_session(
|
||||
display_event(event, config)
|
||||
update_usage_metrics(event, usage_metrics)
|
||||
|
||||
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
|
||||
if is_paused.is_set():
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.PAUSED,
|
||||
]:
|
||||
# If the agent is paused, do not prompt for input as it's already handled by PAUSED state change
|
||||
if is_paused.is_set():
|
||||
return
|
||||
|
||||
# Reload microagents after initialization of repo.md
|
||||
if reload_microagents:
|
||||
microagents: list[BaseMicroagent] = (
|
||||
@@ -181,25 +176,32 @@ async def run_session(
|
||||
await prompt_for_next_task(event.agent_state)
|
||||
|
||||
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
# Only display the confirmation prompt if the agent is not paused
|
||||
if not is_paused.is_set():
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
# If the agent is paused, do not prompt for confirmation
|
||||
# The confirmation step will re-run after the agent has been resumed
|
||||
if is_paused.is_set():
|
||||
return
|
||||
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
if event.agent_state == AgentState.PAUSED:
|
||||
is_paused.clear() # Revert the event state before prompting for user input
|
||||
await prompt_for_next_task(event.agent_state)
|
||||
|
||||
if event.agent_state == AgentState.RUNNING:
|
||||
# Enable pause/resume functionality only if the confirmation mode is disabled
|
||||
if not config.security.confirmation_mode:
|
||||
display_agent_running_message()
|
||||
loop.create_task(process_agent_pause(is_paused))
|
||||
display_agent_running_message()
|
||||
loop.create_task(
|
||||
process_agent_pause(is_paused, event_stream)
|
||||
) # Create a task to track agent pause requests from the user
|
||||
|
||||
def on_event(event: Event) -> None:
|
||||
loop.create_task(on_event_async(event))
|
||||
|
||||
@@ -25,10 +25,11 @@ from prompt_toolkit.widgets import Frame, TextArea
|
||||
from openhands import __version__
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
MessageAction,
|
||||
@@ -60,7 +61,7 @@ COMMANDS = {
|
||||
'/status': 'Display session details and usage metrics',
|
||||
'/new': 'Create a new session',
|
||||
'/settings': 'Display and modify current settings',
|
||||
'/resume': 'Resume the agent',
|
||||
'/resume': 'Resume the agent when paused',
|
||||
}
|
||||
|
||||
|
||||
@@ -396,7 +397,7 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
|
||||
def display_agent_running_message():
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>')
|
||||
HTML('<gold>Agent running...</gold> <grey>(Press Ctrl-P to pause)</grey>')
|
||||
)
|
||||
|
||||
|
||||
@@ -405,7 +406,7 @@ def display_agent_paused_message(agent_state: str):
|
||||
return
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>')
|
||||
HTML('<gold>Agent paused...</gold> <grey>(Enter /resume to continue)</grey>')
|
||||
)
|
||||
|
||||
|
||||
@@ -430,7 +431,7 @@ class CommandCompleter(Completer):
|
||||
command,
|
||||
start_position=-len(text),
|
||||
display_meta=description,
|
||||
style='bg:ansidarkgray fg:ansiwhite',
|
||||
style='bg:ansidarkgray fg:gold',
|
||||
)
|
||||
|
||||
|
||||
@@ -488,7 +489,7 @@ async def read_confirmation_input() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def process_agent_pause(done: asyncio.Event) -> None:
|
||||
async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None:
|
||||
input = create_input()
|
||||
|
||||
def keys_ready():
|
||||
@@ -496,6 +497,10 @@ async def process_agent_pause(done: asyncio.Event) -> None:
|
||||
if key_press.key == Keys.ControlP:
|
||||
print_formatted_text('')
|
||||
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
done.set()
|
||||
|
||||
with input.raw_mode():
|
||||
|
||||
@@ -10,7 +10,6 @@ from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
|
||||
class TestProcessAgentPause:
|
||||
@@ -48,7 +47,7 @@ class TestProcessAgentPause:
|
||||
mock_input.attach.side_effect = fake_attach
|
||||
|
||||
# Create a task to run process_agent_pause
|
||||
task = asyncio.create_task(process_agent_pause(done))
|
||||
task = asyncio.create_task(process_agent_pause(done, event_stream=MagicMock()))
|
||||
|
||||
# Give it a moment to start and capture the callback
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -119,12 +118,13 @@ class TestCliPauseResumeInRunSession:
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
# The pause event is not cleared here because we want to simulate
|
||||
# the PAUSED event processing in a future event
|
||||
|
||||
# Call our test function
|
||||
# Call on_event_async_test
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that the event_stream.add_event was called with the correct action
|
||||
# Check that event_stream.add_event was called with the correct action
|
||||
event_stream.add_event.assert_called_once()
|
||||
args, kwargs = event_stream.add_event.call_args
|
||||
action, source = args
|
||||
@@ -133,35 +133,127 @@ class TestCliPauseResumeInRunSession:
|
||||
assert action.agent_state == AgentState.PAUSED
|
||||
assert source == EventSource.USER
|
||||
|
||||
# Check that is_paused was cleared
|
||||
assert not is_paused.is_set()
|
||||
# Check that is_paused is still set (will be cleared when PAUSED state is processed)
|
||||
assert is_paused.is_set()
|
||||
|
||||
# Run the test function
|
||||
await test_func()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_awaiting_user_input_paused_skip(self):
|
||||
"""Test that when is_paused is set, awaiting user input events do not trigger prompting."""
|
||||
# Create a mock event with AgentStateChangedObservation
|
||||
event = MagicMock()
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.AWAITING_USER_INPUT, content='Agent awaiting input'
|
||||
)
|
||||
|
||||
# Create mock dependencies
|
||||
is_paused = asyncio.Event()
|
||||
reload_microagents = False
|
||||
|
||||
# Mock function that would be called if code reaches that point
|
||||
mock_prompt_task = MagicMock()
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Set the pause event
|
||||
is_paused.set()
|
||||
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents, is_paused
|
||||
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if event.observation.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]:
|
||||
# If the agent is paused, do not prompt for input
|
||||
if is_paused.is_set():
|
||||
return
|
||||
|
||||
# This code should not be reached if is_paused is set
|
||||
mock_prompt_task()
|
||||
|
||||
# Call on_event_async_test
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Verify that mock_prompt_task was not called
|
||||
mock_prompt_task.assert_not_called()
|
||||
|
||||
# Run the test
|
||||
await test_func()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_awaiting_confirmation_paused_skip(self):
|
||||
"""Test that when is_paused is set, awaiting confirmation events do not trigger prompting."""
|
||||
# Create a mock event with AgentStateChangedObservation
|
||||
event = MagicMock()
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.AWAITING_USER_CONFIRMATION,
|
||||
content='Agent awaiting confirmation',
|
||||
)
|
||||
|
||||
# Create mock dependencies
|
||||
is_paused = asyncio.Event()
|
||||
|
||||
# Mock function that would be called if code reaches that point
|
||||
mock_confirmation = MagicMock()
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Set the pause event
|
||||
is_paused.set()
|
||||
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal is_paused
|
||||
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if (
|
||||
event.observation.agent_state
|
||||
== AgentState.AWAITING_USER_CONFIRMATION
|
||||
):
|
||||
if is_paused.is_set():
|
||||
return
|
||||
|
||||
# This code should not be reached if is_paused is set
|
||||
mock_confirmation()
|
||||
|
||||
# Call on_event_async_test
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Verify that confirmation function was not called
|
||||
mock_confirmation.assert_not_called()
|
||||
|
||||
# Run the test
|
||||
await test_func()
|
||||
|
||||
|
||||
class TestCliCommandsPauseResume:
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.core.cli_commands.handle_resume_command')
|
||||
async def test_handle_commands_resume(self, mock_handle_resume):
|
||||
"""Test that the handle_commands function properly calls handle_resume_command."""
|
||||
# Import the handle_commands function
|
||||
# Import here to avoid circular imports in test
|
||||
from openhands.core.cli_commands import handle_commands
|
||||
|
||||
# Set up mocks
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
# Create mocks
|
||||
message = '/resume'
|
||||
event_stream = MagicMock()
|
||||
usage_metrics = MagicMock()
|
||||
sid = 'test-session-id'
|
||||
config = MagicMock()
|
||||
current_dir = '/test/dir'
|
||||
settings_store = MagicMock()
|
||||
|
||||
# Set the return value for handle_resume_command
|
||||
# Mock return value
|
||||
mock_handle_resume.return_value = (False, False)
|
||||
|
||||
# Call handle_commands with the resume command
|
||||
# Call handle_commands
|
||||
close_repl, reload_microagents, new_session_requested = await handle_commands(
|
||||
'/resume',
|
||||
message,
|
||||
event_stream,
|
||||
usage_metrics,
|
||||
sid,
|
||||
@@ -170,7 +262,7 @@ class TestCliCommandsPauseResume:
|
||||
settings_store,
|
||||
)
|
||||
|
||||
# Check that handle_resume_command was called with the correct arguments
|
||||
# Check that handle_resume_command was called with correct args
|
||||
mock_handle_resume.assert_called_once_with(event_stream)
|
||||
|
||||
# Check the return values
|
||||
@@ -187,39 +279,37 @@ class TestAgentStatePauseResume:
|
||||
self, mock_process_agent_pause, mock_display_message
|
||||
):
|
||||
"""Test that when the agent is running, pause functionality is enabled."""
|
||||
# Create mock dependencies
|
||||
# Create a mock event and event stream
|
||||
event = MagicMock()
|
||||
# AgentStateChangedObservation requires a content parameter
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
|
||||
agent_state=AgentState.RUNNING, content='Agent is running'
|
||||
)
|
||||
event_stream = MagicMock()
|
||||
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
loop = MagicMock()
|
||||
# Create mock dependencies
|
||||
is_paused = asyncio.Event()
|
||||
config = MagicMock()
|
||||
config.security.confirmation_mode = False
|
||||
loop = MagicMock()
|
||||
reload_microagents = False
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Call our simplified on_event_async
|
||||
# Create a context similar to run_session to call on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents
|
||||
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if event.observation.agent_state == AgentState.RUNNING:
|
||||
# Enable pause/resume functionality only if the confirmation mode is disabled
|
||||
if not config.security.confirmation_mode:
|
||||
mock_display_message()
|
||||
loop.create_task(mock_process_agent_pause(is_paused))
|
||||
mock_display_message()
|
||||
loop.create_task(
|
||||
mock_process_agent_pause(is_paused, event_stream)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
# Call on_event_async_test
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that the message was displayed
|
||||
# Check that display_agent_running_message was called
|
||||
mock_display_message.assert_called_once()
|
||||
|
||||
# Check that process_agent_pause was called with the right arguments
|
||||
mock_process_agent_pause.assert_called_once_with(is_paused)
|
||||
|
||||
# Check that loop.create_task was called
|
||||
loop.create_task.assert_called_once()
|
||||
|
||||
@@ -286,40 +376,33 @@ class TestAgentStatePauseResume:
|
||||
event.observation = AgentStateChangedObservation(
|
||||
agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
|
||||
)
|
||||
reload_microagents = False
|
||||
memory = MagicMock()
|
||||
runtime = MagicMock()
|
||||
prompt_task = MagicMock()
|
||||
is_paused = asyncio.Event()
|
||||
|
||||
# Mock function that would be called for prompting
|
||||
mock_prompt_task = MagicMock()
|
||||
|
||||
# Create a closure to capture the current context
|
||||
async def test_func():
|
||||
# Create a simplified version of on_event_async
|
||||
async def on_event_async_test(event):
|
||||
nonlocal reload_microagents, prompt_task
|
||||
nonlocal is_paused
|
||||
|
||||
if isinstance(event.observation, AgentStateChangedObservation):
|
||||
if event.observation.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.PAUSED,
|
||||
]:
|
||||
# Reload microagents after initialization of repo.md
|
||||
if reload_microagents:
|
||||
microagents = runtime.get_microagents_from_selected_repo(
|
||||
None
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
reload_microagents = False
|
||||
if event.observation.agent_state == AgentState.PAUSED:
|
||||
is_paused.clear() # Revert the event state before prompting for user input
|
||||
mock_prompt_task(event.observation.agent_state)
|
||||
|
||||
# Since prompt_for_next_task is a nested function in cli.py,
|
||||
# we'll just check that we've reached this code path
|
||||
prompt_task = 'Prompt for next task would be called here'
|
||||
# Set is_paused to test that it gets cleared
|
||||
is_paused.set()
|
||||
|
||||
# Call the function
|
||||
await on_event_async_test(event)
|
||||
|
||||
# Check that we reached the code path where prompt_for_next_task would be called
|
||||
assert prompt_task == 'Prompt for next task would be called here'
|
||||
# Check that is_paused was cleared
|
||||
assert not is_paused.is_set()
|
||||
|
||||
# Check that prompt task was called with the correct state
|
||||
mock_prompt_task.assert_called_once_with(AgentState.PAUSED)
|
||||
|
||||
# Run the test
|
||||
await test_func()
|
||||
|
||||
Reference in New Issue
Block a user