mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Restore previous conversation in CLI (#8431)
This commit is contained in:
parent
033788c2d0
commit
f7cb2d0f64
@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from uuid import uuid4
|
||||
|
||||
from prompt_toolkit.shortcuts import clear
|
||||
|
||||
@ -42,6 +41,7 @@ from openhands.core.setup import (
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
generate_sid,
|
||||
initialize_repository_for_runtime,
|
||||
)
|
||||
from openhands.events import EventSource, EventStreamSubscriber
|
||||
@ -81,6 +81,16 @@ async def cleanup_session(
|
||||
if pending:
|
||||
await asyncio.wait(pending, timeout=5.0)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
# Save the final state
|
||||
end_state = controller.get_state()
|
||||
end_state.save_to_session(
|
||||
event_stream.sid,
|
||||
event_stream.file_store,
|
||||
event_stream.user_id,
|
||||
)
|
||||
|
||||
# Reset agent, close runtime and controller
|
||||
agent.reset()
|
||||
runtime.close()
|
||||
@ -94,12 +104,13 @@ async def run_session(
|
||||
config: AppConfig,
|
||||
settings_store: FileSettingsStore,
|
||||
current_dir: str,
|
||||
initial_user_action: str | None = None,
|
||||
task_content: str | None = None,
|
||||
session_name: str | None = None,
|
||||
) -> bool:
|
||||
reload_microagents = False
|
||||
new_session_requested = False
|
||||
|
||||
sid = str(uuid4())
|
||||
sid = generate_sid(config, session_name)
|
||||
is_loaded = asyncio.Event()
|
||||
is_paused = asyncio.Event() # Event to track agent pause requests
|
||||
always_confirm_mode = False # Flag to enable always confirm mode
|
||||
@ -120,7 +131,7 @@ async def run_session(
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
controller, _ = create_controller(agent, runtime, config)
|
||||
controller, initial_state = create_controller(agent, runtime, config)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
@ -218,7 +229,7 @@ async def run_session(
|
||||
def on_event(event: Event) -> None:
|
||||
loop.create_task(on_event_async(event))
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
|
||||
|
||||
await runtime.connect()
|
||||
await add_mcp_tools_to_agent(agent, runtime, config.mcp)
|
||||
@ -249,17 +260,38 @@ async def run_session(
|
||||
# Show OpenHands banner and session ID
|
||||
display_banner(session_id=sid)
|
||||
|
||||
# Show OpenHands welcome
|
||||
display_welcome_message()
|
||||
welcome_message = 'What do you want to build?' # from the application
|
||||
initial_message = '' # from the user
|
||||
|
||||
if initial_user_action:
|
||||
# If there's an initial user action, enqueue it and do not prompt again
|
||||
display_initial_user_prompt(initial_user_action)
|
||||
event_stream.add_event(
|
||||
MessageAction(content=initial_user_action), EventSource.USER
|
||||
)
|
||||
if task_content:
|
||||
initial_message = task_content
|
||||
|
||||
# If we loaded a state, we are resuming a previous session
|
||||
if initial_state is not None:
|
||||
logger.info(f'Resuming session: {sid}')
|
||||
|
||||
if initial_state.last_error:
|
||||
# If the last session ended in an error, provide a message.
|
||||
initial_message = (
|
||||
'NOTE: the last session ended with an error.'
|
||||
"Let's get back on track. Do NOT resume your task. Ask me about it."
|
||||
)
|
||||
else:
|
||||
# If we are resuming, we already have a task
|
||||
initial_message = ''
|
||||
welcome_message += '\nLoading previous conversation.'
|
||||
|
||||
# Show OpenHands welcome
|
||||
display_welcome_message(welcome_message)
|
||||
|
||||
# The prompt_for_next_task will be triggered if the agent enters AWAITING_USER_INPUT.
|
||||
# If the restored state is already AWAITING_USER_INPUT, on_event_async will handle it.
|
||||
|
||||
if initial_message:
|
||||
display_initial_user_prompt(initial_message)
|
||||
event_stream.add_event(MessageAction(content=initial_message), EventSource.USER)
|
||||
else:
|
||||
# Otherwise prompt for the user's first message right away
|
||||
# No session restored, no initial action: prompt for the user's first message
|
||||
asyncio.create_task(prompt_for_next_task(''))
|
||||
|
||||
await run_agent_until_done(
|
||||
@ -334,7 +366,12 @@ async def main(loop: asyncio.AbstractEventLoop) -> None:
|
||||
|
||||
# Run the first session
|
||||
new_session_requested = await run_session(
|
||||
loop, config, settings_store, current_dir, task_str
|
||||
loop,
|
||||
config,
|
||||
settings_store,
|
||||
current_dir,
|
||||
task_str,
|
||||
session_name=args.name,
|
||||
)
|
||||
|
||||
# If a new session was requested, run it
|
||||
|
||||
@ -145,14 +145,20 @@ def display_banner(session_id: str) -> None:
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
def display_welcome_message() -> None:
|
||||
def display_welcome_message(message: str = '') -> None:
|
||||
print_formatted_text(
|
||||
HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
|
||||
)
|
||||
print_formatted_text(
|
||||
HTML('What do you want to build? <grey>Type /help for help</grey>'),
|
||||
style=DEFAULT_STYLE,
|
||||
)
|
||||
if message:
|
||||
print_formatted_text(
|
||||
HTML(f'{message} <grey>Type /help for help</grey>'),
|
||||
style=DEFAULT_STYLE,
|
||||
)
|
||||
else:
|
||||
print_formatted_text(
|
||||
HTML('What do you want to build? <grey>Type /help for help</grey>'),
|
||||
style=DEFAULT_STYLE,
|
||||
)
|
||||
|
||||
|
||||
def display_initial_user_prompt(prompt: str) -> None:
|
||||
|
||||
@ -132,7 +132,7 @@ pyarrow = "20.0.0" # transiti
|
||||
datasets = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
openhands = "openhands.core.cli:main"
|
||||
openhands = "openhands.cli.main:main"
|
||||
|
||||
[tool.poetry.group.testgeneval.dependencies]
|
||||
fuzzywuzzy = "^0.18.0"
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from openhands.cli import main as cli
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
@ -28,6 +29,11 @@ def mock_runtime():
|
||||
def mock_controller():
|
||||
controller = AsyncMock()
|
||||
controller.close = AsyncMock()
|
||||
|
||||
# Setup for get_state() and the returned state's save_to_session()
|
||||
mock_state = MagicMock()
|
||||
mock_state.save_to_session = MagicMock()
|
||||
controller.get_state = MagicMock(return_value=mock_state)
|
||||
return controller
|
||||
|
||||
|
||||
@ -247,8 +253,10 @@ async def test_run_session_with_initial_action(
|
||||
mock_create_runtime.return_value = mock_runtime
|
||||
|
||||
mock_controller = AsyncMock()
|
||||
mock_controller_task = MagicMock()
|
||||
mock_create_controller.return_value = (mock_controller, mock_controller_task)
|
||||
mock_create_controller.return_value = (
|
||||
mock_controller,
|
||||
None,
|
||||
) # Ensure initial_state is None for this test
|
||||
|
||||
mock_memory = AsyncMock()
|
||||
mock_create_memory.return_value = mock_memory
|
||||
@ -326,6 +334,7 @@ async def test_main_without_task(
|
||||
mock_args = MagicMock()
|
||||
mock_args.agent_cls = None
|
||||
mock_args.llm_config = None
|
||||
mock_args.name = None
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Mock config
|
||||
@ -372,7 +381,7 @@ async def test_main_without_task(
|
||||
|
||||
# Check that run_session was called with expected arguments
|
||||
mock_run_session.assert_called_once_with(
|
||||
loop, mock_config, mock_settings_store, '/test/dir', None
|
||||
loop, mock_config, mock_settings_store, '/test/dir', None, session_name=None
|
||||
)
|
||||
|
||||
|
||||
@ -470,6 +479,186 @@ async def test_main_with_task(
|
||||
assert second_call_args[4] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.main.parse_arguments')
|
||||
@patch('openhands.cli.main.setup_config_from_args')
|
||||
@patch('openhands.cli.main.FileSettingsStore.get_instance')
|
||||
@patch('openhands.cli.main.check_folder_security_agreement')
|
||||
@patch('openhands.cli.main.read_task')
|
||||
@patch('openhands.cli.main.run_session')
|
||||
@patch('openhands.cli.main.LLMSummarizingCondenserConfig')
|
||||
@patch('openhands.cli.main.NoOpCondenserConfig')
|
||||
async def test_main_with_session_name_passes_name_to_run_session(
|
||||
mock_noop_condenser,
|
||||
mock_llm_condenser,
|
||||
mock_run_session,
|
||||
mock_read_task,
|
||||
mock_check_security,
|
||||
mock_get_settings_store,
|
||||
mock_setup_config,
|
||||
mock_parse_args,
|
||||
):
|
||||
"""Test main function with a session name passes it to run_session."""
|
||||
loop = asyncio.get_running_loop()
|
||||
test_session_name = 'my_named_session'
|
||||
|
||||
# Mock arguments
|
||||
mock_args = MagicMock()
|
||||
mock_args.agent_cls = None
|
||||
mock_args.llm_config = None
|
||||
mock_args.name = test_session_name # Set the session name
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Mock config
|
||||
mock_config = MagicMock()
|
||||
mock_config.workspace_base = '/test/dir'
|
||||
mock_config.cli_multiline_input = False
|
||||
mock_setup_config.return_value = mock_config
|
||||
|
||||
# Mock settings store
|
||||
mock_settings_store = AsyncMock()
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.agent = 'test-agent'
|
||||
mock_settings.llm_model = 'test-model' # Copied from test_main_without_task
|
||||
mock_settings.llm_api_key = 'test-api-key' # Copied from test_main_without_task
|
||||
mock_settings.llm_base_url = 'test-base-url' # Copied from test_main_without_task
|
||||
mock_settings.confirmation_mode = True # Copied from test_main_without_task
|
||||
mock_settings.enable_default_condenser = True # Copied from test_main_without_task
|
||||
mock_settings_store.load.return_value = mock_settings
|
||||
mock_get_settings_store.return_value = mock_settings_store
|
||||
|
||||
# Mock condenser config (as in test_main_without_task)
|
||||
mock_llm_condenser_instance = MagicMock()
|
||||
mock_llm_condenser.return_value = mock_llm_condenser_instance
|
||||
|
||||
# Mock security check
|
||||
mock_check_security.return_value = True
|
||||
|
||||
# Mock read_task to return no task
|
||||
mock_read_task.return_value = None
|
||||
|
||||
# Mock run_session to return False (no new session requested)
|
||||
mock_run_session.return_value = False
|
||||
|
||||
# Run the function
|
||||
await cli.main(loop)
|
||||
|
||||
# Assertions
|
||||
mock_parse_args.assert_called_once()
|
||||
mock_setup_config.assert_called_once_with(mock_args)
|
||||
mock_get_settings_store.assert_called_once()
|
||||
mock_settings_store.load.assert_called_once()
|
||||
mock_check_security.assert_called_once_with(mock_config, '/test/dir')
|
||||
mock_read_task.assert_called_once()
|
||||
|
||||
# Check that run_session was called with the correct session_name
|
||||
mock_run_session.assert_called_once_with(
|
||||
loop,
|
||||
mock_config,
|
||||
mock_settings_store,
|
||||
'/test/dir',
|
||||
None,
|
||||
session_name=test_session_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.main.generate_sid')
|
||||
@patch('openhands.cli.main.create_agent')
|
||||
@patch('openhands.cli.main.create_runtime') # Returns mock_runtime
|
||||
@patch('openhands.cli.main.create_memory')
|
||||
@patch('openhands.cli.main.add_mcp_tools_to_agent')
|
||||
@patch('openhands.cli.main.run_agent_until_done')
|
||||
@patch('openhands.cli.main.cleanup_session')
|
||||
@patch(
|
||||
'openhands.cli.main.read_prompt_input', new_callable=AsyncMock
|
||||
) # For REPL control
|
||||
@patch('openhands.cli.main.handle_commands', new_callable=AsyncMock) # For REPL control
|
||||
@patch('openhands.core.setup.State.restore_from_session') # Key mock
|
||||
@patch('openhands.controller.AgentController.__init__') # To check initial_state
|
||||
@patch('openhands.cli.main.display_runtime_initialization_message') # Cosmetic
|
||||
@patch('openhands.cli.main.display_initialization_animation') # Cosmetic
|
||||
@patch('openhands.cli.main.initialize_repository_for_runtime') # Cosmetic / setup
|
||||
@patch('openhands.cli.main.display_initial_user_prompt') # Cosmetic
|
||||
async def test_run_session_with_name_attempts_state_restore(
|
||||
mock_display_initial_user_prompt,
|
||||
mock_initialize_repo,
|
||||
mock_display_init_anim,
|
||||
mock_display_runtime_init,
|
||||
mock_agent_controller_init,
|
||||
mock_restore_from_session,
|
||||
mock_handle_commands,
|
||||
mock_read_prompt_input,
|
||||
mock_cleanup_session,
|
||||
mock_run_agent_until_done,
|
||||
mock_add_mcp_tools,
|
||||
mock_create_memory,
|
||||
mock_create_runtime,
|
||||
mock_create_agent,
|
||||
mock_generate_sid,
|
||||
mock_config, # Fixture
|
||||
mock_settings_store, # Fixture
|
||||
):
|
||||
"""Test run_session with a session_name attempts to restore state and passes it to AgentController."""
|
||||
loop = asyncio.get_running_loop()
|
||||
test_session_name = 'my_restore_test_session'
|
||||
expected_sid = f'sid_for_{test_session_name}'
|
||||
|
||||
mock_generate_sid.return_value = expected_sid
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent
|
||||
|
||||
mock_runtime = AsyncMock()
|
||||
mock_runtime.event_stream = MagicMock() # This is the EventStream instance
|
||||
mock_runtime.event_stream.sid = expected_sid
|
||||
mock_runtime.event_stream.file_store = (
|
||||
MagicMock()
|
||||
) # Mock the file_store attribute on the EventStream
|
||||
mock_create_runtime.return_value = mock_runtime
|
||||
|
||||
# This is what State.restore_from_session will return
|
||||
mock_loaded_state = MagicMock(spec=State)
|
||||
mock_restore_from_session.return_value = mock_loaded_state
|
||||
|
||||
# AgentController.__init__ should not return a value (it's __init__)
|
||||
mock_agent_controller_init.return_value = None
|
||||
|
||||
# To make run_session exit cleanly after one loop
|
||||
mock_read_prompt_input.return_value = '/exit'
|
||||
mock_handle_commands.return_value = (
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
) # close_repl, reload_microagents, new_session_requested
|
||||
|
||||
# Mock other functions called by run_session to avoid side effects
|
||||
mock_initialize_repo.return_value = '/mocked/repo/dir'
|
||||
mock_create_memory.return_value = AsyncMock() # Memory instance
|
||||
|
||||
await cli.run_session(
|
||||
loop,
|
||||
mock_config,
|
||||
mock_settings_store, # This is FileSettingsStore, not directly used for restore in this path
|
||||
'/test/dir',
|
||||
task_content=None,
|
||||
session_name=test_session_name,
|
||||
)
|
||||
|
||||
mock_generate_sid.assert_called_once_with(mock_config, test_session_name)
|
||||
|
||||
# State.restore_from_session is called from within core.setup.create_controller,
|
||||
# which receives the runtime object (and thus its event_stream with sid and file_store).
|
||||
mock_restore_from_session.assert_called_once_with(
|
||||
expected_sid, mock_runtime.event_stream.file_store
|
||||
)
|
||||
|
||||
# Check that AgentController was initialized with the loaded state
|
||||
mock_agent_controller_init.assert_called_once()
|
||||
args, kwargs = mock_agent_controller_init.call_args
|
||||
assert kwargs.get('initial_state') == mock_loaded_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('openhands.cli.main.parse_arguments')
|
||||
@patch('openhands.cli.main.setup_config_from_args')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user