Restore previous conversation in CLI (#8431)

This commit is contained in:
Engel Nyst 2025-05-15 23:47:41 +02:00 committed by GitHub
parent 033788c2d0
commit f7cb2d0f64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 256 additions and 24 deletions

View File

@ -1,7 +1,6 @@
import asyncio
import logging
import sys
from uuid import uuid4
from prompt_toolkit.shortcuts import clear
@ -42,6 +41,7 @@ from openhands.core.setup import (
create_controller,
create_memory,
create_runtime,
generate_sid,
initialize_repository_for_runtime,
)
from openhands.events import EventSource, EventStreamSubscriber
@ -81,6 +81,16 @@ async def cleanup_session(
if pending:
await asyncio.wait(pending, timeout=5.0)
event_stream = runtime.event_stream
# Save the final state
end_state = controller.get_state()
end_state.save_to_session(
event_stream.sid,
event_stream.file_store,
event_stream.user_id,
)
# Reset agent, close runtime and controller
agent.reset()
runtime.close()
@ -94,12 +104,13 @@ async def run_session(
config: AppConfig,
settings_store: FileSettingsStore,
current_dir: str,
initial_user_action: str | None = None,
task_content: str | None = None,
session_name: str | None = None,
) -> bool:
reload_microagents = False
new_session_requested = False
sid = str(uuid4())
sid = generate_sid(config, session_name)
is_loaded = asyncio.Event()
is_paused = asyncio.Event() # Event to track agent pause requests
always_confirm_mode = False # Flag to enable always confirm mode
@ -120,7 +131,7 @@ async def run_session(
agent=agent,
)
controller, _ = create_controller(agent, runtime, config)
controller, initial_state = create_controller(agent, runtime, config)
event_stream = runtime.event_stream
@ -218,7 +229,7 @@ async def run_session(
def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
await runtime.connect()
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
@ -249,17 +260,38 @@ async def run_session(
# Show OpenHands banner and session ID
display_banner(session_id=sid)
# Show OpenHands welcome
display_welcome_message()
welcome_message = 'What do you want to build?' # from the application
initial_message = '' # from the user
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
display_initial_user_prompt(initial_user_action)
event_stream.add_event(
MessageAction(content=initial_user_action), EventSource.USER
)
if task_content:
initial_message = task_content
# If we loaded a state, we are resuming a previous session
if initial_state is not None:
logger.info(f'Resuming session: {sid}')
if initial_state.last_error:
# If the last session ended in an error, provide a message.
# Adjacent string literals are joined with no separator, so the first
# sentence must carry its own trailing space.
initial_message = (
    'NOTE: the last session ended with an error. '
    "Let's get back on track. Do NOT resume your task. Ask me about it."
)
else:
# If we are resuming, we already have a task
initial_message = ''
welcome_message += '\nLoading previous conversation.'
# Show OpenHands welcome
display_welcome_message(welcome_message)
# The prompt_for_next_task will be triggered if the agent enters AWAITING_USER_INPUT.
# If the restored state is already AWAITING_USER_INPUT, on_event_async will handle it.
if initial_message:
display_initial_user_prompt(initial_message)
event_stream.add_event(MessageAction(content=initial_message), EventSource.USER)
else:
# Otherwise prompt for the user's first message right away
# No session restored, no initial action: prompt for the user's first message
asyncio.create_task(prompt_for_next_task(''))
await run_agent_until_done(
@ -334,7 +366,12 @@ async def main(loop: asyncio.AbstractEventLoop) -> None:
# Run the first session
new_session_requested = await run_session(
loop, config, settings_store, current_dir, task_str
loop,
config,
settings_store,
current_dir,
task_str,
session_name=args.name,
)
# If a new session was requested, run it

View File

@ -145,14 +145,20 @@ def display_banner(session_id: str) -> None:
print_formatted_text('')
def display_welcome_message(message: str = '') -> None:
    """Print the start-up greeting followed by the first prompt line.

    Args:
        message: Text shown in front of the ``/help`` hint. When empty, the
            default question 'What do you want to build?' is shown instead.
    """
    print_formatted_text(
        HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
    )
    # Both cases print the same line; only the leading text differs.
    # NOTE(review): `message` is interpolated into the markup unescaped --
    # presumably callers only pass plain text; confirm before passing user input.
    prompt_text = message or 'What do you want to build?'
    print_formatted_text(
        HTML(f'{prompt_text} <grey>Type /help for help</grey>'),
        style=DEFAULT_STYLE,
    )
def display_initial_user_prompt(prompt: str) -> None:

View File

@ -132,7 +132,7 @@ pyarrow = "20.0.0" # transiti
datasets = "*"
[tool.poetry.scripts]
openhands = "openhands.core.cli:main"
openhands = "openhands.cli.main:main"
[tool.poetry.group.testgeneval.dependencies]
fuzzywuzzy = "^0.18.0"

View File

@ -5,6 +5,7 @@ import pytest
import pytest_asyncio
from openhands.cli import main as cli
from openhands.controller.state.state import State
from openhands.events import EventSource
from openhands.events.action import MessageAction
@ -28,6 +29,11 @@ def mock_runtime():
def mock_controller():
    """Build an async controller mock whose ``get_state()`` is synchronous.

    ``get_state`` returns one shared state mock that exposes a synchronous
    ``save_to_session`` mock; ``close`` stays awaitable.
    """
    saved_state = MagicMock(save_to_session=MagicMock())
    return AsyncMock(
        close=AsyncMock(),
        get_state=MagicMock(return_value=saved_state),
    )
@ -247,8 +253,10 @@ async def test_run_session_with_initial_action(
mock_create_runtime.return_value = mock_runtime
mock_controller = AsyncMock()
mock_controller_task = MagicMock()
mock_create_controller.return_value = (mock_controller, mock_controller_task)
mock_create_controller.return_value = (
mock_controller,
None,
) # Ensure initial_state is None for this test
mock_memory = AsyncMock()
mock_create_memory.return_value = mock_memory
@ -326,6 +334,7 @@ async def test_main_without_task(
mock_args = MagicMock()
mock_args.agent_cls = None
mock_args.llm_config = None
mock_args.name = None
mock_parse_args.return_value = mock_args
# Mock config
@ -372,7 +381,7 @@ async def test_main_without_task(
# Check that run_session was called with expected arguments
mock_run_session.assert_called_once_with(
loop, mock_config, mock_settings_store, '/test/dir', None
loop, mock_config, mock_settings_store, '/test/dir', None, session_name=None
)
@ -470,6 +479,186 @@ async def test_main_with_task(
assert second_call_args[4] is None
@pytest.mark.asyncio
@patch('openhands.cli.main.parse_arguments')
@patch('openhands.cli.main.setup_config_from_args')
@patch('openhands.cli.main.FileSettingsStore.get_instance')
@patch('openhands.cli.main.check_folder_security_agreement')
@patch('openhands.cli.main.read_task')
@patch('openhands.cli.main.run_session')
@patch('openhands.cli.main.LLMSummarizingCondenserConfig')
@patch('openhands.cli.main.NoOpCondenserConfig')
async def test_main_with_session_name_passes_name_to_run_session(
    mock_noop_condenser,
    mock_llm_condenser,
    mock_run_session,
    mock_read_task,
    mock_check_security,
    mock_get_settings_store,
    mock_setup_config,
    mock_parse_args,
):
    """main() hands the parsed ``name`` argument to run_session as session_name."""
    running_loop = asyncio.get_running_loop()
    session_label = 'my_named_session'

    # Parsed CLI arguments. `name` is reserved by the Mock constructor, so it
    # is assigned after construction instead of being passed as a keyword.
    cli_args = MagicMock(agent_cls=None, llm_config=None)
    cli_args.name = session_label
    mock_parse_args.return_value = cli_args

    # Application config returned by setup_config_from_args
    app_config = MagicMock(workspace_base='/test/dir', cli_multiline_input=False)
    mock_setup_config.return_value = app_config

    # Settings store whose load() yields a populated settings object
    stored_settings = MagicMock(
        agent='test-agent',
        llm_model='test-model',
        llm_api_key='test-api-key',
        llm_base_url='test-base-url',
        confirmation_mode=True,
        enable_default_condenser=True,
    )
    settings_store = AsyncMock()
    settings_store.load.return_value = stored_settings
    mock_get_settings_store.return_value = settings_store

    # Remaining collaborators: condenser config instance, accepted security
    # prompt, no task read, and no follow-up session requested.
    mock_llm_condenser.return_value = MagicMock()
    mock_check_security.return_value = True
    mock_read_task.return_value = None
    mock_run_session.return_value = False

    await cli.main(running_loop)

    # Start-up steps each ran exactly once
    mock_parse_args.assert_called_once()
    mock_setup_config.assert_called_once_with(cli_args)
    mock_get_settings_store.assert_called_once()
    settings_store.load.assert_called_once()
    mock_check_security.assert_called_once_with(app_config, '/test/dir')
    mock_read_task.assert_called_once()

    # The session name must reach run_session as a keyword argument
    mock_run_session.assert_called_once_with(
        running_loop,
        app_config,
        settings_store,
        '/test/dir',
        None,
        session_name=session_label,
    )
@pytest.mark.asyncio
@patch('openhands.cli.main.generate_sid')
@patch('openhands.cli.main.create_agent')
@patch('openhands.cli.main.create_runtime')
@patch('openhands.cli.main.create_memory')
@patch('openhands.cli.main.add_mcp_tools_to_agent')
@patch('openhands.cli.main.run_agent_until_done')
@patch('openhands.cli.main.cleanup_session')
@patch('openhands.cli.main.read_prompt_input', new_callable=AsyncMock)
@patch('openhands.cli.main.handle_commands', new_callable=AsyncMock)
@patch('openhands.core.setup.State.restore_from_session')
@patch('openhands.controller.AgentController.__init__')
@patch('openhands.cli.main.display_runtime_initialization_message')
@patch('openhands.cli.main.display_initialization_animation')
@patch('openhands.cli.main.initialize_repository_for_runtime')
@patch('openhands.cli.main.display_initial_user_prompt')
async def test_run_session_with_name_attempts_state_restore(
    mock_display_initial_user_prompt,
    mock_initialize_repo,
    mock_display_init_anim,
    mock_display_runtime_init,
    mock_agent_controller_init,
    mock_restore_from_session,
    mock_handle_commands,
    mock_read_prompt_input,
    mock_cleanup_session,
    mock_run_agent_until_done,
    mock_add_mcp_tools,
    mock_create_memory,
    mock_create_runtime,
    mock_create_agent,
    mock_generate_sid,
    mock_config,  # Fixture
    mock_settings_store,  # Fixture
):
    """A named session triggers a state restore whose result reaches AgentController."""
    running_loop = asyncio.get_running_loop()
    session_label = 'my_restore_test_session'
    sid = f'sid_for_{session_label}'
    mock_generate_sid.return_value = sid

    mock_create_agent.return_value = AsyncMock()

    # Runtime whose event stream carries the sid and a file store
    event_stream = MagicMock()
    event_stream.sid = sid
    event_stream.file_store = MagicMock()
    runtime = AsyncMock()
    runtime.event_stream = event_stream
    mock_create_runtime.return_value = runtime

    # Value handed back by the patched State.restore_from_session
    restored_state = MagicMock(spec=State)
    mock_restore_from_session.return_value = restored_state

    # A patched __init__ has to return None like the real one
    mock_agent_controller_init.return_value = None

    # Drive the REPL to exit after a single iteration; the tuple is
    # (close_repl, reload_microagents, new_session_requested).
    mock_read_prompt_input.return_value = '/exit'
    mock_handle_commands.return_value = (True, False, False)

    # Stub out the remaining setup helpers to avoid side effects
    mock_initialize_repo.return_value = '/mocked/repo/dir'
    mock_create_memory.return_value = AsyncMock()

    await cli.run_session(
        running_loop,
        mock_config,
        mock_settings_store,
        '/test/dir',
        task_content=None,
        session_name=session_label,
    )

    mock_generate_sid.assert_called_once_with(mock_config, session_label)

    # The restore is expected to be keyed by the runtime's event-stream sid
    # and file store.
    mock_restore_from_session.assert_called_once_with(
        sid, runtime.event_stream.file_store
    )

    # AgentController must have been constructed with the restored state
    mock_agent_controller_init.assert_called_once()
    _, init_kwargs = mock_agent_controller_init.call_args
    assert init_kwargs.get('initial_state') == restored_state
@pytest.mark.asyncio
@patch('openhands.cli.main.parse_arguments')
@patch('openhands.cli.main.setup_config_from_args')