feature: Add CLI support for always confirm mode (#8302)

This commit is contained in:
Bashwara Undupitiya 2025-05-07 06:04:00 -07:00 committed by GitHub
parent 13ca75c8cb
commit ab4f7e88ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 182 additions and 45 deletions

View File

@ -102,6 +102,7 @@ async def run_session(
sid = str(uuid4())
is_loaded = asyncio.Event()
is_paused = asyncio.Event() # Event to track agent pause requests
always_confirm_mode = False # Flag to enable always confirm mode
# Show runtime initialization message
display_runtime_initialization_message(config.runtime)
@ -153,7 +154,7 @@ async def run_session(
return
async def on_event_async(event: Event) -> None:
nonlocal reload_microagents, is_paused
nonlocal reload_microagents, is_paused, always_confirm_mode
display_event(event, config)
update_usage_metrics(event, usage_metrics)
@ -181,8 +182,15 @@ async def run_session(
if is_paused.is_set():
return
user_confirmed = await read_confirmation_input()
if user_confirmed:
if always_confirm_mode:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
return
confirmation_status = await read_confirmation_input()
if confirmation_status == 'yes' or confirmation_status == 'always':
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
@ -193,6 +201,10 @@ async def run_session(
EventSource.USER,
)
# Set the always_confirm_mode flag if the user wants to always confirm
if confirmation_status == 'always':
always_confirm_mode = True
if event.agent_state == AgentState.PAUSED:
is_paused.clear() # Revert the event state before prompting for user input
await prompt_for_next_task(event.agent_state)

View File

@ -32,7 +32,6 @@ from openhands.events.action import (
ActionConfirmationStatus,
ChangeAgentStateAction,
CmdRunAction,
FileEditAction,
MessageAction,
)
from openhands.events.event import Event
@ -171,6 +170,8 @@ def display_event(event: Event, config: AppConfig) -> None:
if isinstance(event, Action):
if hasattr(event, 'thought'):
display_message(event.thought)
if hasattr(event, 'final_thought'):
display_message(event.final_thought)
if isinstance(event, MessageAction):
if event.source == EventSource.AGENT:
display_message(event.content)
@ -178,14 +179,12 @@ def display_event(event: Event, config: AppConfig) -> None:
display_command(event)
if isinstance(event, CmdOutputObservation):
display_command_output(event.content)
if isinstance(event, FileEditAction):
display_file_edit(event)
if isinstance(event, FileEditObservation):
display_file_edit(event)
if isinstance(event, FileReadObservation):
display_file_read(event)
if isinstance(event, AgentStateChangedObservation):
display_agent_paused_message(event.agent_state)
display_agent_state_change_message(event.agent_state)
def display_message(message: str):
@ -239,26 +238,26 @@ def display_command_output(output: str):
print_container(container)
def display_file_edit(event: FileEditAction | FileEditObservation):
if isinstance(event, FileEditObservation):
container = Frame(
TextArea(
text=event.visualize_diff(n_context_lines=4),
read_only=True,
wrap_lines=True,
lexer=CustomDiffLexer(),
),
title='File Edit',
style=f'fg:{COLOR_GREY}',
)
print_formatted_text('')
print_container(container)
def display_file_edit(event: FileEditObservation):
container = Frame(
TextArea(
text=event.visualize_diff(n_context_lines=4),
read_only=True,
wrap_lines=True,
lexer=CustomDiffLexer(),
),
title='File Edit',
style=f'fg:{COLOR_GREY}',
)
print_formatted_text('')
print_container(container)
def display_file_read(event: FileReadObservation):
content = event.content.replace('\t', ' ')
container = Frame(
TextArea(
text=f'{event}',
text=content,
read_only=True,
style=COLOR_GREY,
wrap_lines=True,
@ -406,13 +405,20 @@ def display_agent_running_message():
)
def display_agent_paused_message(agent_state: str):
if agent_state != AgentState.PAUSED:
return
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent paused...</gold> <grey>(Enter /resume to continue)</grey>')
)
def display_agent_state_change_message(agent_state: str):
if agent_state == AgentState.PAUSED:
print_formatted_text('')
print_formatted_text(
HTML(
'<gold>Agent paused...</gold> <grey>(Enter /resume to continue)</grey>'
)
)
elif agent_state == AgentState.FINISHED:
print_formatted_text('')
print_formatted_text(HTML('<gold>Task completed...</gold>'))
elif agent_state == AgentState.AWAITING_USER_INPUT:
print_formatted_text('')
print_formatted_text(HTML('<gold>Agent is waiting for your input...</gold>'))
# Common input functions
@ -478,20 +484,28 @@ async def read_prompt_input(agent_state: str, multiline=False):
return '/exit'
async def read_confirmation_input() -> bool:
async def read_confirmation_input() -> str:
try:
prompt_session = create_prompt_session()
with patch_stdout():
print_formatted_text('')
confirmation: str = await prompt_session.prompt_async(
HTML('<gold>Proceed with action? (y)es/(n)o > </gold>'),
HTML('<gold>Proceed with action? (y)es/(n)o/(a)lways > </gold>'),
)
confirmation = '' if confirmation is None else confirmation.strip().lower()
return confirmation in ['y', 'yes']
if confirmation in ['y', 'yes']:
return 'yes'
elif confirmation in ['n', 'no']:
return 'no'
elif confirmation in ['a', 'always']:
return 'always'
else:
return 'no'
except (KeyboardInterrupt, EOFError):
return False
return 'no'
async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None:
@ -499,7 +513,11 @@ async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) ->
def keys_ready():
for key_press in input.read_keys():
if key_press.key == Keys.ControlP:
if (
key_press.key == Keys.ControlP
or key_press.key == Keys.ControlC
or key_press.key == Keys.ControlD
):
print_formatted_text('')
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
event_stream.add_event(

View File

@ -1,4 +1,6 @@
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from openhands.cli.tui import (
CustomDiffLexer,
@ -14,6 +16,7 @@ from openhands.cli.tui import (
display_usage_metrics,
display_welcome_message,
get_session_duration,
read_confirmation_input,
)
from openhands.core.config import AppConfig
from openhands.events import EventSource
@ -21,7 +24,6 @@ from openhands.events.action import (
Action,
ActionConfirmationStatus,
CmdRunAction,
FileEditAction,
MessageAction,
)
from openhands.events.observation import (
@ -98,15 +100,6 @@ class TestDisplayFunctions:
mock_display_output.assert_called_once_with('Test output')
@patch('openhands.cli.tui.display_file_edit')
def test_display_event_file_edit_action(self, mock_display_file_edit):
config = MagicMock(spec=AppConfig)
file_edit = FileEditAction(path='test.py', content="print('hello')")
display_event(file_edit, config)
mock_display_file_edit.assert_called_once_with(file_edit)
@patch('openhands.cli.tui.display_file_edit')
def test_display_event_file_edit_observation(self, mock_display_file_edit):
config = MagicMock(spec=AppConfig)
@ -259,3 +252,117 @@ class TestUserCancelledError:
def test_user_cancelled_error(self):
error = UserCancelledError()
assert isinstance(error, Exception)
class TestReadConfirmationInput:
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_yes(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'y'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'yes'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_yes_full(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'yes'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'yes'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_no(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'n'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_no_full(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'no'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_always(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'a'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'always'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_always_full(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'always'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'always'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_invalid(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = 'invalid'
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_empty(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = ''
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_none(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.return_value = None
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_keyboard_interrupt(
self, mock_create_session
):
mock_session = AsyncMock()
mock_session.prompt_async.side_effect = KeyboardInterrupt
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'
@pytest.mark.asyncio
@patch('openhands.cli.tui.create_prompt_session')
async def test_read_confirmation_input_eof_error(self, mock_create_session):
mock_session = AsyncMock()
mock_session.prompt_async.side_effect = EOFError
mock_create_session.return_value = mock_session
result = await read_confirmation_input()
assert result == 'no'