diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 61d63544a4..850f193034 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -1,4 +1,3 @@ -import asyncio import os import uuid from datetime import datetime, timezone @@ -36,6 +35,7 @@ from openhands.server.user_auth import ( get_provider_tokens, get_user_id, get_user_secrets, + get_user_settings, ) from openhands.server.user_auth.user_auth import AuthType from openhands.server.utils import get_conversation_store @@ -45,6 +45,7 @@ from openhands.storage.data_models.conversation_metadata import ( ConversationTrigger, ) from openhands.storage.data_models.conversation_status import ConversationStatus +from openhands.storage.data_models.settings import Settings from openhands.storage.data_models.user_secrets import UserSecrets from openhands.utils.async_utils import wait_all from openhands.utils.conversation_summary import get_default_conversation_title @@ -68,10 +69,11 @@ class InitSessionRequest(BaseModel): model_config = {'extra': 'forbid'} -class InitSessionResponse(BaseModel): +class ConversationResponse(BaseModel): status: str conversation_id: str message: str | None = None + conversation_status: ConversationStatus | None = None @app.post('/conversations') @@ -81,7 +83,7 @@ async def new_conversation( provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens), user_secrets: UserSecrets = Depends(get_user_secrets), auth_type: AuthType | None = Depends(get_auth_type), -) -> InitSessionResponse: +) -> ConversationResponse: """Initialize a new session or join an existing one. After successful initialization, the client should connect to the WebSocket @@ -126,7 +128,7 @@ async def new_conversation( await provider_handler.verify_repo_provider(repository, git_provider) conversation_id = getattr(data, 'conversation_id', None) or uuid.uuid4().hex - await create_new_conversation( + agent_loop_info = await create_new_conversation( user_id=user_id, git_provider_tokens=provider_tokens, custom_secrets=user_secrets.custom_secrets if user_secrets else None, @@ -141,9 +143,10 @@ async def new_conversation( conversation_id=conversation_id, ) - return InitSessionResponse( + return ConversationResponse( status='ok', conversation_id=conversation_id, + conversation_status=agent_loop_info.status, ) except MissingSettingsError as e: return JSONResponse( @@ -303,3 +306,108 @@ async def _get_conversation_info( extra={'session_id': conversation.conversation_id}, ) return None + + +@app.post('/conversations/{conversation_id}/start') +async def start_conversation( + conversation_id: str, + user_id: str = Depends(get_user_id), + settings: Settings = Depends(get_user_settings), + conversation_store: ConversationStore = Depends(get_conversation_store), +) -> ConversationResponse: + """Start an agent loop for a conversation. + + This endpoint calls the conversation_manager's maybe_start_agent_loop method + to start a conversation. If the conversation is already running, it will + return the existing agent loop info. + """ + logger.info(f'Starting conversation: {conversation_id}') + + try: + + # Check that the conversation exists + try: + await conversation_store.get_metadata(conversation_id) + except Exception: + return JSONResponse( + content={ + 'status': 'error', + 'conversation_id': conversation_id, + }, + status_code=status.HTTP_404_NOT_FOUND, + ) + + # Start the agent loop + agent_loop_info = await conversation_manager.maybe_start_agent_loop( + sid=conversation_id, + settings=settings, + user_id=user_id, + ) + + return ConversationResponse( + status='ok', + conversation_id=conversation_id, + conversation_status=agent_loop_info.status, + ) + except Exception as e: + logger.error( + f'Error starting conversation {conversation_id}: {str(e)}', + extra={'session_id': conversation_id}, + ) + return JSONResponse( + content={ + 'status': 'error', + 'conversation_id': conversation_id, + 'message': f'Failed to start conversation: {str(e)}', + }, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@app.post('/conversations/{conversation_id}/stop') +async def stop_conversation( + conversation_id: str, + user_id: str = Depends(get_user_id), +) -> ConversationResponse: + """Stop an agent loop for a conversation. + + This endpoint calls the conversation_manager's close_session method + to stop a conversation. + """ + logger.info(f'Stopping conversation: {conversation_id}') + + try: + # Check if the conversation is running + agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id}) + conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED + + if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING): + return ConversationResponse( + status='ok', + conversation_id=conversation_id, + message='Conversation was not running', + conversation_status=conversation_status, + ) + + # Stop the conversation + await conversation_manager.close_session(conversation_id) + + return ConversationResponse( + status='ok', + conversation_id=conversation_id, + message='Conversation stopped successfully', + conversation_status=conversation_status, + ) + except Exception as e: + logger.error( + f'Error stopping conversation {conversation_id}: {str(e)}', + extra={'session_id': conversation_id}, + ) + return JSONResponse( + content={ + 'status': 'error', + 'conversation_id': conversation_id, + 'message': f'Failed to stop conversation: {str(e)}', + }, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/tests/unit/test_conversation.py b/tests/unit/test_conversation.py index 0cd7ecd837..79808edbfb 100644 --- a/tests/unit/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -20,8 +20,8 @@ from openhands.server.data_models.conversation_info_result_set import ( ConversationInfoResultSet, ) from openhands.server.routes.manage_conversations import ( + ConversationResponse, InitSessionRequest, - InitSessionResponse, delete_conversation, get_conversation, new_conversation, @@ -250,6 +250,7 @@ async def test_new_conversation_success(provider_handler_mock): conversation_id='test_conversation_id', url='https://my-conversation.com', session_api_key=None, + status=ConversationStatus.RUNNING, ) test_request = InitSessionRequest( @@ -263,7 +264,7 @@ async def test_new_conversation_success(provider_handler_mock): response = await create_new_test_conversation(test_request) # Verify the response - assert isinstance(response, InitSessionResponse) + assert isinstance(response, ConversationResponse) assert response.status == 'ok' # Don't check the exact conversation_id as it's now generated dynamically assert response.conversation_id is not None @@ -293,6 +294,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock): conversation_id='test_conversation_id', url='https://my-conversation.com', session_api_key=None, + status=ConversationStatus.RUNNING, ) # Mock SuggestedTask.get_prompt_for_task @@ -321,7 +323,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock): response = await create_new_test_conversation(test_request) # Verify the response - assert isinstance(response, InitSessionResponse) + assert isinstance(response, ConversationResponse) assert response.status == 'ok' # Don't check the exact conversation_id as it's now generated dynamically assert response.conversation_id is not None @@ -479,6 +481,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock): conversation_id='test_conversation_id', url='https://my-conversation.com', session_api_key=None, + status=ConversationStatus.RUNNING, ) # Create the request object @@ -492,7 +495,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock): response = await create_new_test_conversation(test_request, AuthType.BEARER) # Verify the response - assert isinstance(response, InitSessionResponse) + assert isinstance(response, ConversationResponse) assert response.status == 'ok' # Verify that create_new_conversation was called with REMOTE_API_KEY trigger @@ -516,6 +519,7 @@ async def test_new_conversation_with_null_repository(): conversation_id='test_conversation_id', url='https://my-conversation.com', session_api_key=None, + status=ConversationStatus.RUNNING, ) # Create the request object with null repository @@ -529,7 +533,7 @@ async def test_new_conversation_with_null_repository(): response = await create_new_test_conversation(test_request) # Verify the response - assert isinstance(response, InitSessionResponse) + assert isinstance(response, ConversationResponse) assert response.status == 'ok' # Verify that create_new_conversation was called with None repository