fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380)

This commit is contained in:
Hiep Le
2026-03-16 22:25:44 +07:00
committed by GitHub
parent aec95ecf3b
commit 238cab4d08
29 changed files with 2668 additions and 22 deletions

View File

@@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import (
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
)
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
from openhands.app_server.sandbox.sandbox_models import (
AGENT_SERVER,
@@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
sandbox_service: SandboxService
sandbox_spec_service: SandboxSpecService
jwt_service: JwtService
pending_message_service: PendingMessageService
sandbox_startup_timeout: int
sandbox_startup_poll_frequency: int
max_num_conversations_per_sandbox: int
@@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
task.app_conversation_id = info.id
yield task
# Process any pending messages queued while waiting for conversation
if sandbox.session_api_key:
await self._process_pending_messages(
task_id=task.id,
conversation_id=info.id,
agent_server_url=agent_server_url,
session_api_key=sandbox.session_api_key,
)
except Exception as exc:
_logger.exception('Error starting conversation', stack_info=True)
task.status = AppConversationStartTaskStatus.ERROR
@@ -1424,6 +1437,89 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
plugins=plugins,
)
async def _process_pending_messages(
    self,
    task_id: UUID,
    conversation_id: UUID,
    agent_server_url: str,
    session_api_key: str,
) -> None:
    """Deliver messages that were queued before the conversation was ready.

    Messages are delivered sequentially, in created_at order, so the agent
    server receives them in the order the user sent them. After processing,
    all messages are deleted from the database regardless of delivery success
    or failure, so a poison message cannot be retried forever.

    Args:
        task_id: The start task ID (may have been used as the conversation_id
            by the frontend before the real conversation existed)
        conversation_id: The real conversation ID
        agent_server_url: URL of the agent server
        session_api_key: API key for authenticating with agent server
    """
    # The frontend queues messages under task-{uuid.hex} (no hyphens),
    # matching OpenHandsUUID serialization.
    task_id_str = f'task-{task_id.hex}'
    # The real conversation ID uses the standard hyphenated format expected
    # by the agent server API.
    conversation_id_str = str(conversation_id)

    # Re-key any messages that were queued against the task ID onto the
    # real conversation ID.
    updated_count = await self.pending_message_service.update_conversation_id(
        old_conversation_id=task_id_str,
        new_conversation_id=conversation_id_str,
    )
    if updated_count > 0:
        _logger.info(
            f'Updated {updated_count} pending messages from task_id={task_id_str} '
            f'to conversation_id={conversation_id_str}'
        )

    # Get all pending messages for this conversation
    pending_messages = await self.pending_message_service.get_pending_messages(
        conversation_id_str
    )
    if not pending_messages:
        return

    _logger.info(
        f'Processing {len(pending_messages)} pending messages for '
        f'conversation {conversation_id_str}'
    )

    # Deliver sequentially to preserve user ordering; a failure on one
    # message is logged and does not block the remaining messages.
    for msg in pending_messages:
        try:
            # Serialize content objects to JSON-compatible dicts
            content_json = [item.model_dump() for item in msg.content]
            # Use the events endpoint which handles message sending
            response = await self.httpx_client.post(
                f'{agent_server_url}/api/conversations/{conversation_id_str}/events',
                json={
                    'role': msg.role,
                    'content': content_json,
                    'run': True,
                },
                headers={'X-Session-API-Key': session_api_key},
                timeout=30.0,
            )
            response.raise_for_status()
            _logger.debug(f'Delivered pending message {msg.id}')
        except Exception as e:
            _logger.warning(f'Failed to deliver pending message {msg.id}: {e}')

    # Best-effort semantics: each message gets exactly one delivery attempt,
    # then everything is dropped (regardless of success/failure).
    deleted_count = (
        await self.pending_message_service.delete_messages_for_conversation(
            conversation_id_str
        )
    )
    _logger.info(
        f'Finished processing pending messages for conversation {conversation_id_str}. '
        f'Deleted {deleted_count} messages.'
    )
async def update_agent_server_conversation_title(
self,
conversation_id: str,
@@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
get_global_config,
get_httpx_client,
get_jwt_service,
get_pending_message_service,
get_sandbox_service,
get_sandbox_spec_service,
get_user_context,
@@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
get_event_service(state, request) as event_service,
get_jwt_service(state, request) as jwt_service,
get_httpx_client(state, request) as httpx_client,
get_pending_message_service(state, request) as pending_message_service,
):
access_token_hard_timeout = None
if self.access_token_hard_timeout:
@@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
event_callback_service=event_callback_service,
event_service=event_service,
jwt_service=jwt_service,
pending_message_service=pending_message_service,
sandbox_startup_timeout=self.sandbox_startup_timeout,
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox,

View File

@@ -0,0 +1,39 @@
"""Add pending_messages table for server-side message queuing
Revision ID: 007
Revises: 006
Create Date: 2025-03-15 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '007'
down_revision: Union[str, None] = '006'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create pending_messages table for storing messages before conversation is ready.

    Messages are stored temporarily until the conversation becomes ready, then
    delivered and deleted regardless of success or failure.

    Column definitions mirror the StoredPendingMessage ORM model; in
    particular, created_at carries a server-side default and an index, which
    the ORM model declares (server_default=func.now(), index=True) but the
    original migration omitted.
    """
    op.create_table(
        'pending_messages',
        sa.Column('id', sa.String(), primary_key=True),
        # Holds either a task-{uuid} placeholder or the real conversation
        # UUID; indexed because every service query filters on it.
        sa.Column('conversation_id', sa.String(), nullable=False, index=True),
        sa.Column('role', sa.String(20), nullable=False, server_default='user'),
        sa.Column('content', sa.JSON, nullable=False),
        # Indexed ordering key for FIFO delivery; default matches the ORM model.
        sa.Column(
            'created_at',
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
            index=True,
        ),
    )
def downgrade() -> None:
    """Remove pending_messages table.

    Any queued-but-undelivered messages are lost; they are transient by design.
    """
    op.drop_table('pending_messages')

View File

@@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import (
EventCallbackService,
EventCallbackServiceInjector,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
PendingMessageServiceInjector,
)
from openhands.app_server.sandbox.sandbox_service import (
SandboxService,
SandboxServiceInjector,
@@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel):
app_conversation_info: AppConversationInfoServiceInjector | None = None
app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None
app_conversation: AppConversationServiceInjector | None = None
pending_message: PendingMessageServiceInjector | None = None
user: UserContextInjector | None = None
jwt: JwtServiceInjector | None = None
httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector)
@@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig:
tavily_api_key=tavily_api_key
)
if config.pending_message is None:
from openhands.app_server.pending_messages.pending_message_service import (
SQLPendingMessageServiceInjector,
)
config.pending_message = SQLPendingMessageServiceInjector()
if config.user is None:
config.user = AuthUserContextInjector()
@@ -358,6 +370,14 @@ def get_app_conversation_service(
return injector.context(state, request)
def get_pending_message_service(
    state: InjectorState, request: Request | None = None
) -> AsyncContextManager[PendingMessageService]:
    """Return an async context manager yielding the configured pending-message service."""
    pending_injector = get_global_config().pending_message
    assert pending_injector is not None
    return pending_injector.context(state, request)
def get_user_context(
state: InjectorState, request: Request | None = None
) -> AsyncContextManager[UserContext]:
@@ -433,6 +453,12 @@ def depends_app_conversation_service():
return Depends(injector.depends)
def depends_pending_message_service():
    """Build the FastAPI dependency marker for the pending-message service."""
    pending_injector = get_global_config().pending_message
    assert pending_injector is not None
    return Depends(pending_injector.depends)
def depends_user_context():
injector = get_global_config().user
assert injector is not None

View File

@@ -0,0 +1,21 @@
"""Pending messages module for server-side message queuing."""
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessage,
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
PendingMessageServiceInjector,
SQLPendingMessageService,
SQLPendingMessageServiceInjector,
)
__all__ = [
'PendingMessage',
'PendingMessageResponse',
'PendingMessageService',
'PendingMessageServiceInjector',
'SQLPendingMessageService',
'SQLPendingMessageServiceInjector',
]

View File

@@ -0,0 +1,32 @@
"""Models for pending message queue functionality."""
from datetime import datetime
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.models import ImageContent, TextContent
from openhands.agent_server.utils import utc_now
class PendingMessage(BaseModel):
    """A message queued for delivery when conversation becomes ready.

    Pending messages are stored in the database and delivered to the agent_server
    when the conversation transitions to READY status. Messages are deleted after
    processing, regardless of success or failure.
    """

    # Random UUID string; also used as the primary key in the pending_messages table.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Can be task-{uuid} (before the real conversation exists) or the real
    # conversation UUID.
    conversation_id: str
    # Message author role; defaults to the end user.
    role: str = 'user'
    # Ordered message parts; text and image are the only supported kinds.
    content: list[TextContent | ImageContent]
    # Creation timestamp (from utc_now, presumably UTC — verify helper);
    # used by the service to preserve delivery order.
    created_at: datetime = Field(default_factory=utc_now)
class PendingMessageResponse(BaseModel):
    """Response when queueing a pending message."""

    # ID of the stored pending message.
    id: str
    # True when the message was accepted into the queue.
    queued: bool
    position: int = Field(description='Position in the queue (1-based)')

View File

@@ -0,0 +1,104 @@
"""REST API router for pending messages."""
import logging
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import TypeAdapter, ValidationError
from openhands.agent_server.models import ImageContent, TextContent
from openhands.app_server.config import depends_pending_message_service
from openhands.app_server.pending_messages.pending_message_models import (
PendingMessageResponse,
)
from openhands.app_server.pending_messages.pending_message_service import (
PendingMessageService,
)
from openhands.server.dependencies import get_dependencies
logger = logging.getLogger(__name__)

# Type adapter for validating the 'content' payload of incoming requests.
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])

# Router nested under a conversation; note the path parameter may still be a
# task ID before the real conversation exists.
router = APIRouter(
    prefix='/conversations/{conversation_id}/pending-messages',
    tags=['Pending Messages'],
    dependencies=get_dependencies(),
)

# Create the Depends(...) marker once at module level so the endpoint's
# default argument below can reference it.
pending_message_service_dependency = depends_pending_message_service()
@router.post(
    '', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED
)
async def queue_pending_message(
    conversation_id: str,
    request: Request,
    pending_service: PendingMessageService = pending_message_service_dependency,
) -> PendingMessageResponse:
    """Queue a message for delivery when conversation becomes ready.

    This endpoint allows users to submit messages even when the conversation's
    WebSocket connection is not yet established. Messages are stored server-side
    and delivered automatically when the conversation transitions to READY status.

    Args:
        conversation_id: The conversation ID (can be task ID before conversation is ready)
        request: The FastAPI request containing message content

    Returns:
        PendingMessageResponse with the message ID and queue position

    Raises:
        HTTPException 400: If the request body is invalid
        HTTPException 429: If too many pending messages are queued (limit: 10)
    """
    # Per-conversation cap; mirrors the docstring and the 429 detail below.
    max_pending = 10

    try:
        body = await request.json()
    except Exception:
        # Malformed JSON; deliberately suppress the parser's context.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        ) from None

    # Guard against non-object JSON (e.g. a bare list or string) — body.get
    # would otherwise raise AttributeError and surface as a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Invalid request body',
        )

    raw_content = body.get('content')
    # NOTE(review): role is not validated against an allowed set; any string
    # is stored and later forwarded to the agent server. Confirm intended.
    role = body.get('role', 'user')

    if not raw_content or not isinstance(raw_content, list):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='content must be a non-empty list',
        )

    # Validate and parse content into typed objects
    try:
        content = _content_type_adapter.validate_python(raw_content)
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'Invalid content format: {e}',
        ) from e

    # Soft rate limit: the count-then-add sequence is not atomic, so
    # concurrent requests can slightly exceed the cap. Acceptable here.
    pending_count = await pending_service.count_pending_messages(conversation_id)
    if pending_count >= max_pending:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail='Too many pending messages. Maximum 10 messages per conversation.',
        )

    response = await pending_service.add_message(
        conversation_id=conversation_id,
        content=content,
        role=role,
    )
    logger.info(
        f'Queued pending message {response.id} for conversation {conversation_id} '
        f'(position: {response.position})'
    )
    return response

View File

@@ -0,0 +1,200 @@
"""Service for managing pending messages in SQL database."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator

from fastapi import Request
from pydantic import TypeAdapter
from sqlalchemy import JSON, Column, String, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from openhands.agent_server.models import ImageContent, TextContent
from openhands.app_server.pending_messages.pending_message_models import (
    PendingMessage,
    PendingMessageResponse,
)
from openhands.app_server.services.injector import Injector, InjectorState
from openhands.app_server.utils.sql_utils import Base, UtcDateTime
from openhands.sdk.utils.models import DiscriminatedUnionMixin
# Type adapter for deserializing content from JSON
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
class StoredPendingMessage(Base):  # type: ignore
    """SQLAlchemy model for pending messages."""

    __tablename__ = 'pending_messages'

    # UUID string primary key (mirrors PendingMessage.id).
    id = Column(String, primary_key=True)
    # task-{uuid} before the real conversation exists, then the real UUID;
    # indexed because every service query filters on it.
    conversation_id = Column(String, nullable=False, index=True)
    # Message author role; defaults to the end user.
    role = Column(String(20), nullable=False, default='user')
    # JSON-serialized list of TextContent/ImageContent model dumps.
    content = Column(JSON, nullable=False)
    # Indexed ordering key for FIFO delivery.
    created_at = Column(UtcDateTime, server_default=func.now(), index=True)
class PendingMessageService(ABC):
    """Abstract service for managing pending messages.

    Implementations persist messages queued while a conversation is starting
    and expose the operations needed to deliver and then discard them.
    """

    @abstractmethod
    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message for delivery when conversation becomes ready."""

    @abstractmethod
    async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
        """Get all pending messages for a conversation, ordered by created_at."""

    @abstractmethod
    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages for a conversation."""

    @abstractmethod
    async def delete_messages_for_conversation(self, conversation_id: str) -> int:
        """Delete all pending messages for a conversation, returning count deleted."""

    @abstractmethod
    async def update_conversation_id(
        self, old_conversation_id: str, new_conversation_id: str
    ) -> int:
        """Update conversation_id when task-id transitions to real conversation-id.

        Returns the number of messages updated.
        """
@dataclass
class SQLPendingMessageService(PendingMessageService):
    """SQL implementation of PendingMessageService.

    Delete and re-key operations use bulk DELETE/UPDATE statements (one
    round-trip) instead of loading every row and mutating it individually.
    """

    db_session: AsyncSession

    async def add_message(
        self,
        conversation_id: str,
        content: list[TextContent | ImageContent],
        role: str = 'user',
    ) -> PendingMessageResponse:
        """Queue a message for delivery when conversation becomes ready."""
        # Create the pending message (generates id and created_at)
        pending_message = PendingMessage(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )
        # Count existing pending messages so we can report a 1-based position
        count_stmt = select(func.count()).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(count_stmt)
        position = result.scalar() or 0
        # Serialize content to JSON-compatible format for storage
        content_json = [item.model_dump() for item in content]
        # Store in database (id is already a str; no conversion needed)
        stored_message = StoredPendingMessage(
            id=pending_message.id,
            conversation_id=conversation_id,
            role=role,
            content=content_json,
            created_at=pending_message.created_at,
        )
        self.db_session.add(stored_message)
        await self.db_session.commit()
        return PendingMessageResponse(
            id=pending_message.id,
            queued=True,
            position=position + 1,
        )

    async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
        """Get all pending messages for a conversation, ordered by created_at."""
        stmt = (
            select(StoredPendingMessage)
            .where(StoredPendingMessage.conversation_id == conversation_id)
            .order_by(StoredPendingMessage.created_at.asc())
        )
        result = await self.db_session.execute(stmt)
        stored_messages = result.scalars().all()
        return [
            PendingMessage(
                id=msg.id,
                conversation_id=msg.conversation_id,
                role=msg.role,
                # Re-hydrate the stored JSON into typed content objects
                content=_content_type_adapter.validate_python(msg.content),
                created_at=msg.created_at,
            )
            for msg in stored_messages
        ]

    async def count_pending_messages(self, conversation_id: str) -> int:
        """Count pending messages for a conversation."""
        count_stmt = select(func.count()).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(count_stmt)
        return result.scalar() or 0

    async def delete_messages_for_conversation(self, conversation_id: str) -> int:
        """Delete all pending messages for a conversation, returning count deleted."""
        # Bulk DELETE: a single statement instead of loading every row and
        # deleting them one by one. CursorResult.rowcount is defined for
        # DELETE statements.
        stmt = delete(StoredPendingMessage).where(
            StoredPendingMessage.conversation_id == conversation_id
        )
        result = await self.db_session.execute(stmt)
        count = result.rowcount or 0
        if count > 0:
            await self.db_session.commit()
        return count

    async def update_conversation_id(
        self, old_conversation_id: str, new_conversation_id: str
    ) -> int:
        """Update conversation_id when task-id transitions to real conversation-id."""
        # Bulk UPDATE: a single statement instead of mutating loaded ORM rows.
        stmt = (
            update(StoredPendingMessage)
            .where(StoredPendingMessage.conversation_id == old_conversation_id)
            .values(conversation_id=new_conversation_id)
        )
        result = await self.db_session.execute(stmt)
        count = result.rowcount or 0
        if count > 0:
            await self.db_session.commit()
        return count
class PendingMessageServiceInjector(
    DiscriminatedUnionMixin, Injector[PendingMessageService], ABC
):
    """Abstract injector for PendingMessageService.

    Marker base for the discriminated union of concrete injectors; carries no
    behavior of its own.
    """

    pass
class SQLPendingMessageServiceInjector(PendingMessageServiceInjector):
    """SQL-based injector for PendingMessageService."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[PendingMessageService, None]:
        # Imported inside the method (presumably to avoid a circular import
        # with the config module — confirm).
        from openhands.app_server.config import get_db_session

        async with get_db_session(state) as session:
            yield SQLPendingMessageService(db_session=session)

View File

@@ -5,6 +5,9 @@ from openhands.app_server.event import event_router
from openhands.app_server.event_callback import (
webhook_router,
)
from openhands.app_server.pending_messages.pending_message_router import (
router as pending_message_router,
)
from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router
from openhands.app_server.user import user_router
from openhands.app_server.web_client import web_client_router
@@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router
router = APIRouter(prefix='/api/v1')
router.include_router(event_router.router)
router.include_router(app_conversation_router.router)
router.include_router(pending_message_router)
router.include_router(sandbox_router.router)
router.include_router(sandbox_spec_router.router)
router.include_router(user_router.router)