mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix(frontend): prevent chat message loss during websocket disconnections or page refresh (#13380)
This commit is contained in:
@@ -59,6 +59,9 @@ from openhands.app_server.event_callback.event_callback_service import (
|
||||
from openhands.app_server.event_callback.set_title_callback_processor import (
|
||||
SetTitleCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.pending_messages.pending_message_service import (
|
||||
PendingMessageService,
|
||||
)
|
||||
from openhands.app_server.sandbox.docker_sandbox_service import DockerSandboxService
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
AGENT_SERVER,
|
||||
@@ -127,6 +130,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
sandbox_service: SandboxService
|
||||
sandbox_spec_service: SandboxSpecService
|
||||
jwt_service: JwtService
|
||||
pending_message_service: PendingMessageService
|
||||
sandbox_startup_timeout: int
|
||||
sandbox_startup_poll_frequency: int
|
||||
max_num_conversations_per_sandbox: int
|
||||
@@ -373,6 +377,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
task.app_conversation_id = info.id
|
||||
yield task
|
||||
|
||||
# Process any pending messages queued while waiting for conversation
|
||||
if sandbox.session_api_key:
|
||||
await self._process_pending_messages(
|
||||
task_id=task.id,
|
||||
conversation_id=info.id,
|
||||
agent_server_url=agent_server_url,
|
||||
session_api_key=sandbox.session_api_key,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
_logger.exception('Error starting conversation', stack_info=True)
|
||||
task.status = AppConversationStartTaskStatus.ERROR
|
||||
@@ -1424,6 +1437,89 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
async def _process_pending_messages(
|
||||
self,
|
||||
task_id: UUID,
|
||||
conversation_id: UUID,
|
||||
agent_server_url: str,
|
||||
session_api_key: str,
|
||||
) -> None:
|
||||
"""Process pending messages queued before conversation was ready.
|
||||
|
||||
Messages are delivered concurrently to the agent server. After processing,
|
||||
all messages are deleted from the database regardless of success or failure.
|
||||
|
||||
Args:
|
||||
task_id: The start task ID (may have been used as conversation_id initially)
|
||||
conversation_id: The real conversation ID
|
||||
agent_server_url: URL of the agent server
|
||||
session_api_key: API key for authenticating with agent server
|
||||
"""
|
||||
# Convert UUIDs to strings for the pending message service
|
||||
# The frontend uses task-{uuid.hex} format (no hyphens), matching OpenHandsUUID serialization
|
||||
task_id_str = f'task-{task_id.hex}'
|
||||
# conversation_id uses standard format (with hyphens) for agent server API compatibility
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
_logger.info(f'task_id={task_id_str} conversation_id={conversation_id_str}')
|
||||
|
||||
# First, update any messages that were queued with the task_id
|
||||
updated_count = await self.pending_message_service.update_conversation_id(
|
||||
old_conversation_id=task_id_str,
|
||||
new_conversation_id=conversation_id_str,
|
||||
)
|
||||
_logger.info(f'updated_count={updated_count} ')
|
||||
if updated_count > 0:
|
||||
_logger.info(
|
||||
f'Updated {updated_count} pending messages from task_id={task_id_str} '
|
||||
f'to conversation_id={conversation_id_str}'
|
||||
)
|
||||
|
||||
# Get all pending messages for this conversation
|
||||
pending_messages = await self.pending_message_service.get_pending_messages(
|
||||
conversation_id_str
|
||||
)
|
||||
|
||||
if not pending_messages:
|
||||
return
|
||||
|
||||
_logger.info(
|
||||
f'Processing {len(pending_messages)} pending messages for '
|
||||
f'conversation {conversation_id_str}'
|
||||
)
|
||||
|
||||
# Process messages sequentially to preserve order
|
||||
for msg in pending_messages:
|
||||
try:
|
||||
# Serialize content objects to JSON-compatible dicts
|
||||
content_json = [item.model_dump() for item in msg.content]
|
||||
# Use the events endpoint which handles message sending
|
||||
response = await self.httpx_client.post(
|
||||
f'{agent_server_url}/api/conversations/{conversation_id_str}/events',
|
||||
json={
|
||||
'role': msg.role,
|
||||
'content': content_json,
|
||||
'run': True,
|
||||
},
|
||||
headers={'X-Session-API-Key': session_api_key},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
_logger.debug(f'Delivered pending message {msg.id}')
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to deliver pending message {msg.id}: {e}')
|
||||
|
||||
# Delete all pending messages after processing (regardless of success/failure)
|
||||
deleted_count = (
|
||||
await self.pending_message_service.delete_messages_for_conversation(
|
||||
conversation_id_str
|
||||
)
|
||||
)
|
||||
_logger.info(
|
||||
f'Finished processing pending messages for conversation {conversation_id_str}. '
|
||||
f'Deleted {deleted_count} messages.'
|
||||
)
|
||||
|
||||
async def update_agent_server_conversation_title(
|
||||
self,
|
||||
conversation_id: str,
|
||||
@@ -1796,6 +1892,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
get_global_config,
|
||||
get_httpx_client,
|
||||
get_jwt_service,
|
||||
get_pending_message_service,
|
||||
get_sandbox_service,
|
||||
get_sandbox_spec_service,
|
||||
get_user_context,
|
||||
@@ -1815,6 +1912,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
get_event_service(state, request) as event_service,
|
||||
get_jwt_service(state, request) as jwt_service,
|
||||
get_httpx_client(state, request) as httpx_client,
|
||||
get_pending_message_service(state, request) as pending_message_service,
|
||||
):
|
||||
access_token_hard_timeout = None
|
||||
if self.access_token_hard_timeout:
|
||||
@@ -1859,6 +1957,7 @@ class LiveStatusAppConversationServiceInjector(AppConversationServiceInjector):
|
||||
event_callback_service=event_callback_service,
|
||||
event_service=event_service,
|
||||
jwt_service=jwt_service,
|
||||
pending_message_service=pending_message_service,
|
||||
sandbox_startup_timeout=self.sandbox_startup_timeout,
|
||||
sandbox_startup_poll_frequency=self.sandbox_startup_poll_frequency,
|
||||
max_num_conversations_per_sandbox=self.max_num_conversations_per_sandbox,
|
||||
|
||||
39
openhands/app_server/app_lifespan/alembic/versions/007.py
Normal file
39
openhands/app_server/app_lifespan/alembic/versions/007.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Add pending_messages table for server-side message queuing
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2025-03-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '007'
|
||||
down_revision: Union[str, None] = '006'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create pending_messages table for storing messages before conversation is ready.
|
||||
|
||||
Messages are stored temporarily until the conversation becomes ready, then
|
||||
delivered and deleted regardless of success or failure.
|
||||
"""
|
||||
op.create_table(
|
||||
'pending_messages',
|
||||
sa.Column('id', sa.String(), primary_key=True),
|
||||
sa.Column('conversation_id', sa.String(), nullable=False, index=True),
|
||||
sa.Column('role', sa.String(20), nullable=False, server_default='user'),
|
||||
sa.Column('content', sa.JSON, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove pending_messages table."""
|
||||
op.drop_table('pending_messages')
|
||||
@@ -33,6 +33,10 @@ from openhands.app_server.event_callback.event_callback_service import (
|
||||
EventCallbackService,
|
||||
EventCallbackServiceInjector,
|
||||
)
|
||||
from openhands.app_server.pending_messages.pending_message_service import (
|
||||
PendingMessageService,
|
||||
PendingMessageServiceInjector,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_service import (
|
||||
SandboxService,
|
||||
SandboxServiceInjector,
|
||||
@@ -114,6 +118,7 @@ class AppServerConfig(OpenHandsModel):
|
||||
app_conversation_info: AppConversationInfoServiceInjector | None = None
|
||||
app_conversation_start_task: AppConversationStartTaskServiceInjector | None = None
|
||||
app_conversation: AppConversationServiceInjector | None = None
|
||||
pending_message: PendingMessageServiceInjector | None = None
|
||||
user: UserContextInjector | None = None
|
||||
jwt: JwtServiceInjector | None = None
|
||||
httpx: HttpxClientInjector = Field(default_factory=HttpxClientInjector)
|
||||
@@ -280,6 +285,13 @@ def config_from_env() -> AppServerConfig:
|
||||
tavily_api_key=tavily_api_key
|
||||
)
|
||||
|
||||
if config.pending_message is None:
|
||||
from openhands.app_server.pending_messages.pending_message_service import (
|
||||
SQLPendingMessageServiceInjector,
|
||||
)
|
||||
|
||||
config.pending_message = SQLPendingMessageServiceInjector()
|
||||
|
||||
if config.user is None:
|
||||
config.user = AuthUserContextInjector()
|
||||
|
||||
@@ -358,6 +370,14 @@ def get_app_conversation_service(
|
||||
return injector.context(state, request)
|
||||
|
||||
|
||||
def get_pending_message_service(
|
||||
state: InjectorState, request: Request | None = None
|
||||
) -> AsyncContextManager[PendingMessageService]:
|
||||
injector = get_global_config().pending_message
|
||||
assert injector is not None
|
||||
return injector.context(state, request)
|
||||
|
||||
|
||||
def get_user_context(
|
||||
state: InjectorState, request: Request | None = None
|
||||
) -> AsyncContextManager[UserContext]:
|
||||
@@ -433,6 +453,12 @@ def depends_app_conversation_service():
|
||||
return Depends(injector.depends)
|
||||
|
||||
|
||||
def depends_pending_message_service():
|
||||
injector = get_global_config().pending_message
|
||||
assert injector is not None
|
||||
return Depends(injector.depends)
|
||||
|
||||
|
||||
def depends_user_context():
|
||||
injector = get_global_config().user
|
||||
assert injector is not None
|
||||
|
||||
21
openhands/app_server/pending_messages/__init__.py
Normal file
21
openhands/app_server/pending_messages/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Pending messages module for server-side message queuing."""
|
||||
|
||||
from openhands.app_server.pending_messages.pending_message_models import (
|
||||
PendingMessage,
|
||||
PendingMessageResponse,
|
||||
)
|
||||
from openhands.app_server.pending_messages.pending_message_service import (
|
||||
PendingMessageService,
|
||||
PendingMessageServiceInjector,
|
||||
SQLPendingMessageService,
|
||||
SQLPendingMessageServiceInjector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'PendingMessage',
|
||||
'PendingMessageResponse',
|
||||
'PendingMessageService',
|
||||
'PendingMessageServiceInjector',
|
||||
'SQLPendingMessageService',
|
||||
'SQLPendingMessageServiceInjector',
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Models for pending message queue functionality."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.agent_server.models import ImageContent, TextContent
|
||||
from openhands.agent_server.utils import utc_now
|
||||
|
||||
|
||||
class PendingMessage(BaseModel):
|
||||
"""A message queued for delivery when conversation becomes ready.
|
||||
|
||||
Pending messages are stored in the database and delivered to the agent_server
|
||||
when the conversation transitions to READY status. Messages are deleted after
|
||||
processing, regardless of success or failure.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
conversation_id: str # Can be task-{uuid} or real conversation UUID
|
||||
role: str = 'user'
|
||||
content: list[TextContent | ImageContent]
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class PendingMessageResponse(BaseModel):
|
||||
"""Response when queueing a pending message."""
|
||||
|
||||
id: str
|
||||
queued: bool
|
||||
position: int = Field(description='Position in the queue (1-based)')
|
||||
104
openhands/app_server/pending_messages/pending_message_router.py
Normal file
104
openhands/app_server/pending_messages/pending_message_router.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""REST API router for pending messages."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from openhands.agent_server.models import ImageContent, TextContent
|
||||
from openhands.app_server.config import depends_pending_message_service
|
||||
from openhands.app_server.pending_messages.pending_message_models import (
|
||||
PendingMessageResponse,
|
||||
)
|
||||
from openhands.app_server.pending_messages.pending_message_service import (
|
||||
PendingMessageService,
|
||||
)
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type adapter for validating content from request
|
||||
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
|
||||
|
||||
# Create router with authentication dependencies
|
||||
router = APIRouter(
|
||||
prefix='/conversations/{conversation_id}/pending-messages',
|
||||
tags=['Pending Messages'],
|
||||
dependencies=get_dependencies(),
|
||||
)
|
||||
|
||||
# Create dependency at module level
|
||||
pending_message_service_dependency = depends_pending_message_service()
|
||||
|
||||
|
||||
@router.post(
|
||||
'', response_model=PendingMessageResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def queue_pending_message(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
pending_service: PendingMessageService = pending_message_service_dependency,
|
||||
) -> PendingMessageResponse:
|
||||
"""Queue a message for delivery when conversation becomes ready.
|
||||
|
||||
This endpoint allows users to submit messages even when the conversation's
|
||||
WebSocket connection is not yet established. Messages are stored server-side
|
||||
and delivered automatically when the conversation transitions to READY status.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID (can be task ID before conversation is ready)
|
||||
request: The FastAPI request containing message content
|
||||
|
||||
Returns:
|
||||
PendingMessageResponse with the message ID and queue position
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If the request body is invalid
|
||||
HTTPException 429: If too many pending messages are queued (limit: 10)
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid request body',
|
||||
)
|
||||
|
||||
raw_content = body.get('content')
|
||||
role = body.get('role', 'user')
|
||||
|
||||
if not raw_content or not isinstance(raw_content, list):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='content must be a non-empty list',
|
||||
)
|
||||
|
||||
# Validate and parse content into typed objects
|
||||
try:
|
||||
content = _content_type_adapter.validate_python(raw_content)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f'Invalid content format: {e}',
|
||||
)
|
||||
|
||||
# Rate limit: max 10 pending messages per conversation
|
||||
pending_count = await pending_service.count_pending_messages(conversation_id)
|
||||
if pending_count >= 10:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail='Too many pending messages. Maximum 10 messages per conversation.',
|
||||
)
|
||||
|
||||
response = await pending_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=content,
|
||||
role=role,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Queued pending message {response.id} for conversation {conversation_id} '
|
||||
f'(position: {response.position})'
|
||||
)
|
||||
|
||||
return response
|
||||
200
openhands/app_server/pending_messages/pending_message_service.py
Normal file
200
openhands/app_server/pending_messages/pending_message_service.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Service for managing pending messages in SQL database."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import JSON, Column, String, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.agent_server.models import ImageContent, TextContent
|
||||
from openhands.app_server.pending_messages.pending_message_models import (
|
||||
PendingMessage,
|
||||
PendingMessageResponse,
|
||||
)
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.utils.sql_utils import Base, UtcDateTime
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
# Type adapter for deserializing content from JSON
|
||||
_content_type_adapter = TypeAdapter(list[TextContent | ImageContent])
|
||||
|
||||
|
||||
class StoredPendingMessage(Base): # type: ignore
|
||||
"""SQLAlchemy model for pending messages."""
|
||||
|
||||
__tablename__ = 'pending_messages'
|
||||
id = Column(String, primary_key=True)
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
role = Column(String(20), nullable=False, default='user')
|
||||
content = Column(JSON, nullable=False)
|
||||
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||
|
||||
|
||||
class PendingMessageService(ABC):
|
||||
"""Abstract service for managing pending messages."""
|
||||
|
||||
@abstractmethod
|
||||
async def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
content: list[TextContent | ImageContent],
|
||||
role: str = 'user',
|
||||
) -> PendingMessageResponse:
|
||||
"""Queue a message for delivery when conversation becomes ready."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
|
||||
"""Get all pending messages for a conversation, ordered by created_at."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_pending_messages(self, conversation_id: str) -> int:
|
||||
"""Count pending messages for a conversation."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_messages_for_conversation(self, conversation_id: str) -> int:
|
||||
"""Delete all pending messages for a conversation, returning count deleted."""
|
||||
|
||||
@abstractmethod
|
||||
async def update_conversation_id(
|
||||
self, old_conversation_id: str, new_conversation_id: str
|
||||
) -> int:
|
||||
"""Update conversation_id when task-id transitions to real conversation-id.
|
||||
|
||||
Returns the number of messages updated.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLPendingMessageService(PendingMessageService):
|
||||
"""SQL implementation of PendingMessageService."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
content: list[TextContent | ImageContent],
|
||||
role: str = 'user',
|
||||
) -> PendingMessageResponse:
|
||||
"""Queue a message for delivery when conversation becomes ready."""
|
||||
# Create the pending message
|
||||
pending_message = PendingMessage(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Count existing pending messages for position
|
||||
count_stmt = select(func.count()).where(
|
||||
StoredPendingMessage.conversation_id == conversation_id
|
||||
)
|
||||
result = await self.db_session.execute(count_stmt)
|
||||
position = result.scalar() or 0
|
||||
|
||||
# Serialize content to JSON-compatible format for storage
|
||||
content_json = [item.model_dump() for item in content]
|
||||
|
||||
# Store in database
|
||||
stored_message = StoredPendingMessage(
|
||||
id=str(pending_message.id),
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content_json,
|
||||
created_at=pending_message.created_at,
|
||||
)
|
||||
self.db_session.add(stored_message)
|
||||
await self.db_session.commit()
|
||||
|
||||
return PendingMessageResponse(
|
||||
id=pending_message.id,
|
||||
queued=True,
|
||||
position=position + 1,
|
||||
)
|
||||
|
||||
async def get_pending_messages(self, conversation_id: str) -> list[PendingMessage]:
|
||||
"""Get all pending messages for a conversation, ordered by created_at."""
|
||||
stmt = (
|
||||
select(StoredPendingMessage)
|
||||
.where(StoredPendingMessage.conversation_id == conversation_id)
|
||||
.order_by(StoredPendingMessage.created_at.asc())
|
||||
)
|
||||
result = await self.db_session.execute(stmt)
|
||||
stored_messages = result.scalars().all()
|
||||
|
||||
return [
|
||||
PendingMessage(
|
||||
id=msg.id,
|
||||
conversation_id=msg.conversation_id,
|
||||
role=msg.role,
|
||||
content=_content_type_adapter.validate_python(msg.content),
|
||||
created_at=msg.created_at,
|
||||
)
|
||||
for msg in stored_messages
|
||||
]
|
||||
|
||||
async def count_pending_messages(self, conversation_id: str) -> int:
|
||||
"""Count pending messages for a conversation."""
|
||||
count_stmt = select(func.count()).where(
|
||||
StoredPendingMessage.conversation_id == conversation_id
|
||||
)
|
||||
result = await self.db_session.execute(count_stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def delete_messages_for_conversation(self, conversation_id: str) -> int:
|
||||
"""Delete all pending messages for a conversation, returning count deleted."""
|
||||
stmt = select(StoredPendingMessage).where(
|
||||
StoredPendingMessage.conversation_id == conversation_id
|
||||
)
|
||||
result = await self.db_session.execute(stmt)
|
||||
stored_messages = result.scalars().all()
|
||||
|
||||
count = len(stored_messages)
|
||||
for msg in stored_messages:
|
||||
await self.db_session.delete(msg)
|
||||
|
||||
if count > 0:
|
||||
await self.db_session.commit()
|
||||
|
||||
return count
|
||||
|
||||
async def update_conversation_id(
|
||||
self, old_conversation_id: str, new_conversation_id: str
|
||||
) -> int:
|
||||
"""Update conversation_id when task-id transitions to real conversation-id."""
|
||||
stmt = select(StoredPendingMessage).where(
|
||||
StoredPendingMessage.conversation_id == old_conversation_id
|
||||
)
|
||||
result = await self.db_session.execute(stmt)
|
||||
stored_messages = result.scalars().all()
|
||||
|
||||
count = len(stored_messages)
|
||||
for msg in stored_messages:
|
||||
msg.conversation_id = new_conversation_id
|
||||
|
||||
if count > 0:
|
||||
await self.db_session.commit()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class PendingMessageServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[PendingMessageService], ABC
|
||||
):
|
||||
"""Abstract injector for PendingMessageService."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SQLPendingMessageServiceInjector(PendingMessageServiceInjector):
|
||||
"""SQL-based injector for PendingMessageService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[PendingMessageService, None]:
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state) as db_session:
|
||||
yield SQLPendingMessageService(db_session=db_session)
|
||||
@@ -5,6 +5,9 @@ from openhands.app_server.event import event_router
|
||||
from openhands.app_server.event_callback import (
|
||||
webhook_router,
|
||||
)
|
||||
from openhands.app_server.pending_messages.pending_message_router import (
|
||||
router as pending_message_router,
|
||||
)
|
||||
from openhands.app_server.sandbox import sandbox_router, sandbox_spec_router
|
||||
from openhands.app_server.user import user_router
|
||||
from openhands.app_server.web_client import web_client_router
|
||||
@@ -13,6 +16,7 @@ from openhands.app_server.web_client import web_client_router
|
||||
router = APIRouter(prefix='/api/v1')
|
||||
router.include_router(event_router.router)
|
||||
router.include_router(app_conversation_router.router)
|
||||
router.include_router(pending_message_router)
|
||||
router.include_router(sandbox_router.router)
|
||||
router.include_router(sandbox_spec_router.router)
|
||||
router.include_router(user_router.router)
|
||||
|
||||
Reference in New Issue
Block a user