mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add conversation start and stop endpoints (#8883)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
a1b3c0c7d6
commit
91e24a4a31
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@ -36,6 +35,7 @@ from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_secrets,
|
||||
get_user_settings,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.server.utils import get_conversation_store
|
||||
@ -45,6 +45,7 @@ from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
@ -68,10 +69,11 @@ class InitSessionRequest(BaseModel):
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
|
||||
class InitSessionResponse(BaseModel):
|
||||
class ConversationResponse(BaseModel):
|
||||
status: str
|
||||
conversation_id: str
|
||||
message: str | None = None
|
||||
conversation_status: ConversationStatus | None = None
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
@ -81,7 +83,7 @@ async def new_conversation(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
user_secrets: UserSecrets = Depends(get_user_secrets),
|
||||
auth_type: AuthType | None = Depends(get_auth_type),
|
||||
) -> InitSessionResponse:
|
||||
) -> ConversationResponse:
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
@ -126,7 +128,7 @@ async def new_conversation(
|
||||
await provider_handler.verify_repo_provider(repository, git_provider)
|
||||
|
||||
conversation_id = getattr(data, 'conversation_id', None) or uuid.uuid4().hex
|
||||
await create_new_conversation(
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
@ -141,9 +143,10 @@ async def new_conversation(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
return InitSessionResponse(
|
||||
return ConversationResponse(
|
||||
status='ok',
|
||||
conversation_id=conversation_id,
|
||||
conversation_status=agent_loop_info.status,
|
||||
)
|
||||
except MissingSettingsError as e:
|
||||
return JSONResponse(
|
||||
@ -303,3 +306,108 @@ async def _get_conversation_info(
|
||||
extra={'session_id': conversation.conversation_id},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@app.post('/conversations/{conversation_id}/start')
|
||||
async def start_conversation(
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
settings: Settings = Depends(get_user_settings),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationResponse:
|
||||
"""Start an agent loop for a conversation.
|
||||
|
||||
This endpoint calls the conversation_manager's maybe_start_agent_loop method
|
||||
to start a conversation. If the conversation is already running, it will
|
||||
return the existing agent loop info.
|
||||
"""
|
||||
logger.info(f'Starting conversation: {conversation_id}')
|
||||
|
||||
try:
|
||||
|
||||
# Check that the conversation exists
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
content={
|
||||
'status': 'error',
|
||||
'conversation_id': conversation_id,
|
||||
},
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Start the agent loop
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
sid=conversation_id,
|
||||
settings=settings,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return ConversationResponse(
|
||||
status='ok',
|
||||
conversation_id=conversation_id,
|
||||
conversation_status=agent_loop_info.status,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error starting conversation {conversation_id}: {str(e)}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
return JSONResponse(
|
||||
content={
|
||||
'status': 'error',
|
||||
'conversation_id': conversation_id,
|
||||
'message': f'Failed to start conversation: {str(e)}',
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.post('/conversations/{conversation_id}/stop')
|
||||
async def stop_conversation(
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> ConversationResponse:
|
||||
"""Stop an agent loop for a conversation.
|
||||
|
||||
This endpoint calls the conversation_manager's close_session method
|
||||
to stop a conversation.
|
||||
"""
|
||||
logger.info(f'Stopping conversation: {conversation_id}')
|
||||
|
||||
try:
|
||||
# Check if the conversation is running
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id})
|
||||
conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
|
||||
|
||||
if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
|
||||
return ConversationResponse(
|
||||
status='ok',
|
||||
conversation_id=conversation_id,
|
||||
message='Conversation was not running',
|
||||
conversation_status=conversation_status,
|
||||
)
|
||||
|
||||
# Stop the conversation
|
||||
await conversation_manager.close_session(conversation_id)
|
||||
|
||||
return ConversationResponse(
|
||||
status='ok',
|
||||
conversation_id=conversation_id,
|
||||
message='Conversation stopped successfully',
|
||||
conversation_status=conversation_status,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error stopping conversation {conversation_id}: {str(e)}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
return JSONResponse(
|
||||
content={
|
||||
'status': 'error',
|
||||
'conversation_id': conversation_id,
|
||||
'message': f'Failed to stop conversation: {str(e)}',
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
@ -20,8 +20,8 @@ from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
ConversationResponse,
|
||||
InitSessionRequest,
|
||||
InitSessionResponse,
|
||||
delete_conversation,
|
||||
get_conversation,
|
||||
new_conversation,
|
||||
@ -250,6 +250,7 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
session_api_key=None,
|
||||
status=ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
@ -263,7 +264,7 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert isinstance(response, ConversationResponse)
|
||||
assert response.status == 'ok'
|
||||
# Don't check the exact conversation_id as it's now generated dynamically
|
||||
assert response.conversation_id is not None
|
||||
@ -293,6 +294,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
session_api_key=None,
|
||||
status=ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
# Mock SuggestedTask.get_prompt_for_task
|
||||
@ -321,7 +323,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert isinstance(response, ConversationResponse)
|
||||
assert response.status == 'ok'
|
||||
# Don't check the exact conversation_id as it's now generated dynamically
|
||||
assert response.conversation_id is not None
|
||||
@ -479,6 +481,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
session_api_key=None,
|
||||
status=ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
# Create the request object
|
||||
@ -492,7 +495,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request, AuthType.BEARER)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert isinstance(response, ConversationResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that create_new_conversation was called with REMOTE_API_KEY trigger
|
||||
@ -516,6 +519,7 @@ async def test_new_conversation_with_null_repository():
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
session_api_key=None,
|
||||
status=ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
# Create the request object with null repository
|
||||
@ -529,7 +533,7 @@ async def test_new_conversation_with_null_repository():
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert isinstance(response, ConversationResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that create_new_conversation was called with None repository
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user