"""SQL implementation of SharedConversationInfoService. This implementation provides read-only access to shared conversations: - Direct database access without user permission checks - Filters only conversations marked as shared (currently public) - Full async/await support using SQL async db_sessions """ from __future__ import annotations import logging from dataclasses import dataclass from datetime import UTC, datetime from typing import AsyncGenerator from uuid import UUID from fastapi import Request from server.sharing.shared_conversation_info_service import ( SharedConversationInfoService, SharedConversationInfoServiceInjector, ) from server.sharing.shared_conversation_models import ( SharedConversation, SharedConversationPage, SharedConversationSortOrder, ) from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( StoredConversationMetadata, ) from openhands.app_server.services.injector import InjectorState from openhands.integrations.provider import ProviderType from openhands.sdk.llm import MetricsSnapshot from openhands.sdk.llm.utils.metrics import TokenUsage logger = logging.getLogger(__name__) @dataclass class SQLSharedConversationInfoService(SharedConversationInfoService): """SQL implementation of SharedConversationInfoService for shared conversations only.""" db_session: AsyncSession async def search_shared_conversation_info( self, title__contains: str | None = None, created_at__gte: datetime | None = None, created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, include_sub_conversations: bool = False, ) -> SharedConversationPage: """Search for shared conversations.""" query = self._public_select() # Conditionally exclude sub-conversations based on the parameter if not include_sub_conversations: # Exclude sub-conversations (only include top-level conversations) query = query.where( StoredConversationMetadata.parent_conversation_id.is_(None) ) query = self._apply_filters( query=query, title__contains=title__contains, created_at__gte=created_at__gte, created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, ) # Add sort order if sort_order == SharedConversationSortOrder.CREATED_AT: query = query.order_by(StoredConversationMetadata.created_at) elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC: query = query.order_by(StoredConversationMetadata.created_at.desc()) elif sort_order == SharedConversationSortOrder.UPDATED_AT: query = query.order_by(StoredConversationMetadata.last_updated_at) elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC: query = query.order_by(StoredConversationMetadata.last_updated_at.desc()) elif sort_order == SharedConversationSortOrder.TITLE: query = query.order_by(StoredConversationMetadata.title) elif sort_order == SharedConversationSortOrder.TITLE_DESC: query = query.order_by(StoredConversationMetadata.title.desc()) # Apply pagination if page_id is not None: try: offset = int(page_id) query = query.offset(offset) except ValueError: # If page_id is not a valid integer, start from beginning offset = 0 else: offset = 0 # Apply limit and get one extra to check if there are more results query = query.limit(limit + 1) result = await self.db_session.execute(query) rows = result.scalars().all() # Check if there are more results has_more = len(rows) > limit if has_more: rows = rows[:limit] items = [self._to_shared_conversation(row) for row in rows] # Calculate next page ID next_page_id = None if has_more: next_page_id = str(offset + limit) return SharedConversationPage(items=items, next_page_id=next_page_id) async def count_shared_conversation_info( self, title__contains: str | None = None, created_at__gte: datetime | None = None, created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, ) -> int: """Count shared conversations matching the given filters.""" from sqlalchemy import func query = select(func.count(StoredConversationMetadata.conversation_id)) # Only include shared conversations query = query.where(StoredConversationMetadata.public == True) # noqa: E712 query = query.where(StoredConversationMetadata.conversation_version == 'V1') query = self._apply_filters( query=query, title__contains=title__contains, created_at__gte=created_at__gte, created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, ) result = await self.db_session.execute(query) return result.scalar() or 0 async def get_shared_conversation_info( self, conversation_id: UUID ) -> SharedConversation | None: """Get a single public conversation info, returning None if missing or not shared.""" query = self._public_select().where( StoredConversationMetadata.conversation_id == str(conversation_id) ) result = await self.db_session.execute(query) stored = result.scalar_one_or_none() if stored is None: return None return self._to_shared_conversation(stored) def _public_select(self): """Create a select query that only returns public conversations.""" query = select(StoredConversationMetadata).where( StoredConversationMetadata.conversation_version == 'V1' ) # Only include conversations marked as public query = query.where(StoredConversationMetadata.public == True) # noqa: E712 return query def _apply_filters( self, query, title__contains: str | None = None, created_at__gte: datetime | None = None, created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, ): """Apply common filters to a query.""" if title__contains is not None: query = query.where( StoredConversationMetadata.title.contains(title__contains) ) if created_at__gte is not None: query = query.where( StoredConversationMetadata.created_at >= created_at__gte ) if created_at__lt is not None: query = query.where(StoredConversationMetadata.created_at < created_at__lt) if updated_at__gte is not None: query = query.where( StoredConversationMetadata.last_updated_at >= updated_at__gte ) if updated_at__lt is not None: query = query.where( StoredConversationMetadata.last_updated_at < updated_at__lt ) return query def _to_shared_conversation( self, stored: StoredConversationMetadata, sub_conversation_ids: list[UUID] | None = None, ) -> SharedConversation: """Convert StoredConversationMetadata to SharedConversation.""" # V1 conversations should always have a sandbox_id sandbox_id = stored.sandbox_id assert sandbox_id is not None # Rebuild token usage token_usage = TokenUsage( prompt_tokens=stored.prompt_tokens, completion_tokens=stored.completion_tokens, cache_read_tokens=stored.cache_read_tokens, cache_write_tokens=stored.cache_write_tokens, context_window=stored.context_window, per_turn_token=stored.per_turn_token, ) # Rebuild metrics object metrics = MetricsSnapshot( accumulated_cost=stored.accumulated_cost, max_budget_per_task=stored.max_budget_per_task, accumulated_token_usage=token_usage, ) # Get timestamps created_at = self._fix_timezone(stored.created_at) updated_at = self._fix_timezone(stored.last_updated_at) return SharedConversation( id=UUID(stored.conversation_id), created_by_user_id=stored.user_id if stored.user_id else None, sandbox_id=stored.sandbox_id, selected_repository=stored.selected_repository, selected_branch=stored.selected_branch, git_provider=( ProviderType(stored.git_provider) if stored.git_provider else None ), title=stored.title, pr_number=stored.pr_number, llm_model=stored.llm_model, metrics=metrics, parent_conversation_id=( UUID(stored.parent_conversation_id) if stored.parent_conversation_id else None ), sub_conversation_ids=sub_conversation_ids or [], created_at=created_at, updated_at=updated_at, ) def _fix_timezone(self, value: datetime) -> datetime: """Sqlite does not store timezones - and since we can't update the existing models we assume UTC if the timezone is missing.""" if not value.tzinfo: value = value.replace(tzinfo=UTC) return value class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector): async def inject( self, state: InjectorState, request: Request | None = None ) -> AsyncGenerator[SharedConversationInfoService, None]: # Define inline to prevent circular lookup from openhands.app_server.config import get_db_session async with get_db_session(state, request) as db_session: service = SQLSharedConversationInfoService(db_session=db_session) yield service