OpenHands/enterprise/server/sharing/sql_shared_conversation_info_service.py
2025-12-22 19:27:58 -07:00

283 lines
10 KiB
Python

"""SQL implementation of SharedConversationInfoService.
This implementation provides read-only access to shared conversations:
- Direct database access without user permission checks
- Filters only conversations marked as shared (currently public)
- Full async/await support using SQL async db_sessions
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
SharedConversationInfoServiceInjector,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
# Module-level logger named after this module. Nothing else in this file
# references it at present; it is kept for future diagnostics.
logger = logging.getLogger(__name__)
@dataclass
class SQLSharedConversationInfoService(SharedConversationInfoService):
    """SQL implementation of SharedConversationInfoService for shared conversations only.

    Every query issued here is restricted to rows with ``conversation_version == 'V1'``
    and ``public == True``; no per-user permission checks are performed.
    """

    db_session: AsyncSession

    async def search_shared_conversation_info(
        self,
        title__contains: str | None = None,
        created_at__gte: datetime | None = None,
        created_at__lt: datetime | None = None,
        updated_at__gte: datetime | None = None,
        updated_at__lt: datetime | None = None,
        sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
        page_id: str | None = None,
        limit: int = 100,
        include_sub_conversations: bool = False,
    ) -> SharedConversationPage:
        """Search for shared conversations.

        Args:
            title__contains: Only include conversations whose title contains this text.
            created_at__gte: Only include conversations created at or after this time.
            created_at__lt: Only include conversations created before this time.
            updated_at__gte: Only include conversations updated at or after this time.
            updated_at__lt: Only include conversations updated before this time.
            sort_order: Ordering of the results.
            page_id: Cursor taken from a previous page's ``next_page_id``. It is a
                stringified row offset; missing, non-integer or negative values
                start from the first page.
            limit: Maximum number of items in the returned page.
            include_sub_conversations: When False (the default), conversations that
                have a ``parent_conversation_id`` are excluded.

        Returns:
            A SharedConversationPage whose ``next_page_id`` is set only when more
            results exist beyond this page.
        """
        query = self._public_select()
        if not include_sub_conversations:
            # Only top-level conversations (no parent).
            query = query.where(
                StoredConversationMetadata.parent_conversation_id.is_(None)
            )
        query = self._apply_filters(
            query=query,
            title__contains=title__contains,
            created_at__gte=created_at__gte,
            created_at__lt=created_at__lt,
            updated_at__gte=updated_at__gte,
            updated_at__lt=updated_at__lt,
        )
        query = self._apply_sort_order(query, sort_order)

        offset = self._parse_page_id(page_id)
        if offset:
            query = query.offset(offset)
        # Fetch one extra row so we can tell whether another page exists.
        query = query.limit(limit + 1)

        result = await self.db_session.execute(query)
        rows = result.scalars().all()

        has_more = len(rows) > limit
        if has_more:
            rows = rows[:limit]
        items = [self._to_shared_conversation(row) for row in rows]

        next_page_id = str(offset + limit) if has_more else None
        return SharedConversationPage(items=items, next_page_id=next_page_id)

    async def count_shared_conversation_info(
        self,
        title__contains: str | None = None,
        created_at__gte: datetime | None = None,
        created_at__lt: datetime | None = None,
        updated_at__gte: datetime | None = None,
        updated_at__lt: datetime | None = None,
    ) -> int:
        """Count shared (public, V1) conversations matching the given filters.

        Returns:
            The number of matching conversations, or 0 when the database
            returns no value.
        """
        from sqlalchemy import func

        # NOTE(review): unlike search_shared_conversation_info, which excludes
        # sub-conversations by default, this count has no parent_conversation_id
        # filter and therefore includes sub-conversations. Confirm whether that
        # asymmetry is intended.
        query = select(func.count(StoredConversationMetadata.conversation_id))
        query = self._apply_public_filters(query)
        query = self._apply_filters(
            query=query,
            title__contains=title__contains,
            created_at__gte=created_at__gte,
            created_at__lt=created_at__lt,
            updated_at__gte=updated_at__gte,
            updated_at__lt=updated_at__lt,
        )
        result = await self.db_session.execute(query)
        return result.scalar() or 0

    async def get_shared_conversation_info(
        self, conversation_id: UUID
    ) -> SharedConversation | None:
        """Get a single public conversation info, returning None if missing or not shared."""
        query = self._public_select().where(
            StoredConversationMetadata.conversation_id == str(conversation_id)
        )
        result = await self.db_session.execute(query)
        stored = result.scalar_one_or_none()
        if stored is None:
            return None
        return self._to_shared_conversation(stored)

    def _apply_public_filters(self, query):
        """Restrict ``query`` to V1 conversations that are marked public."""
        return query.where(
            StoredConversationMetadata.conversation_version == 'V1',
            StoredConversationMetadata.public == True,  # noqa: E712
        )

    def _public_select(self):
        """Create a select query that only returns public V1 conversations."""
        return self._apply_public_filters(select(StoredConversationMetadata))

    def _apply_sort_order(self, query, sort_order: SharedConversationSortOrder):
        """Order ``query`` according to ``sort_order``.

        ``conversation_id`` is appended as a secondary key (it is the single-row
        lookup key in get_shared_conversation_info). SQL does not define the
        relative order of rows with equal sort keys, so without it offset-based
        pages can overlap or skip rows when titles or timestamps collide.
        Unrecognised sort orders leave the query unordered.
        """
        # sort order -> (primary column, descending?)
        sort_spec = {
            SharedConversationSortOrder.CREATED_AT: (
                StoredConversationMetadata.created_at,
                False,
            ),
            SharedConversationSortOrder.CREATED_AT_DESC: (
                StoredConversationMetadata.created_at,
                True,
            ),
            SharedConversationSortOrder.UPDATED_AT: (
                StoredConversationMetadata.last_updated_at,
                False,
            ),
            SharedConversationSortOrder.UPDATED_AT_DESC: (
                StoredConversationMetadata.last_updated_at,
                True,
            ),
            SharedConversationSortOrder.TITLE: (
                StoredConversationMetadata.title,
                False,
            ),
            SharedConversationSortOrder.TITLE_DESC: (
                StoredConversationMetadata.title,
                True,
            ),
        }.get(sort_order)
        if sort_spec is None:
            return query
        column, descending = sort_spec
        tie_breaker = StoredConversationMetadata.conversation_id
        if descending:
            return query.order_by(column.desc(), tie_breaker.desc())
        return query.order_by(column, tie_breaker)

    @staticmethod
    def _parse_page_id(page_id: str | None) -> int:
        """Convert a ``page_id`` cursor into a non-negative row offset.

        Missing or non-integer values map to 0 (the first page). Negative values
        are clamped to 0: PostgreSQL rejects a negative OFFSET, and on backends
        that tolerate it the next_page_id computed from a negative offset would
        point back at an earlier page.
        """
        if page_id is None:
            return 0
        try:
            offset = int(page_id)
        except ValueError:
            return 0
        return max(offset, 0)

    def _apply_filters(
        self,
        query,
        title__contains: str | None = None,
        created_at__gte: datetime | None = None,
        created_at__lt: datetime | None = None,
        updated_at__gte: datetime | None = None,
        updated_at__lt: datetime | None = None,
    ):
        """Apply the optional title and time-range filters shared by search and count.

        Each filter is only added when its argument is not None. The time ranges
        are half-open: ``>=`` on the lower bound and ``<`` on the upper bound.
        """
        if title__contains is not None:
            # NOTE(review): '%' and '_' in title__contains are not escaped, so they
            # act as LIKE wildcards. Consider contains(..., autoescape=True) if
            # literal matching is intended.
            query = query.where(
                StoredConversationMetadata.title.contains(title__contains)
            )
        if created_at__gte is not None:
            query = query.where(
                StoredConversationMetadata.created_at >= created_at__gte
            )
        if created_at__lt is not None:
            query = query.where(StoredConversationMetadata.created_at < created_at__lt)
        if updated_at__gte is not None:
            query = query.where(
                StoredConversationMetadata.last_updated_at >= updated_at__gte
            )
        if updated_at__lt is not None:
            query = query.where(
                StoredConversationMetadata.last_updated_at < updated_at__lt
            )
        return query

    def _to_shared_conversation(
        self,
        stored: StoredConversationMetadata,
        sub_conversation_ids: list[UUID] | None = None,
    ) -> SharedConversation:
        """Convert a StoredConversationMetadata row into a SharedConversation.

        Args:
            stored: The database row to convert.
            sub_conversation_ids: Child conversation ids to attach; defaults to an
                empty list.

        Raises:
            AssertionError: If the row has no sandbox_id (only checked when Python
                assertions are enabled).
        """
        # V1 conversations should always have a sandbox_id
        sandbox_id = stored.sandbox_id
        assert sandbox_id is not None

        # Rebuild token usage from the per-column values stored on the row.
        token_usage = TokenUsage(
            prompt_tokens=stored.prompt_tokens,
            completion_tokens=stored.completion_tokens,
            cache_read_tokens=stored.cache_read_tokens,
            cache_write_tokens=stored.cache_write_tokens,
            context_window=stored.context_window,
            per_turn_token=stored.per_turn_token,
        )
        metrics = MetricsSnapshot(
            accumulated_cost=stored.accumulated_cost,
            max_budget_per_task=stored.max_budget_per_task,
            accumulated_token_usage=token_usage,
        )

        created_at = self._fix_timezone(stored.created_at)
        updated_at = self._fix_timezone(stored.last_updated_at)

        return SharedConversation(
            id=UUID(stored.conversation_id),
            # Falsy user ids (e.g. empty string) are normalised to None.
            created_by_user_id=stored.user_id if stored.user_id else None,
            sandbox_id=sandbox_id,
            selected_repository=stored.selected_repository,
            selected_branch=stored.selected_branch,
            git_provider=(
                ProviderType(stored.git_provider) if stored.git_provider else None
            ),
            title=stored.title,
            pr_number=stored.pr_number,
            llm_model=stored.llm_model,
            metrics=metrics,
            parent_conversation_id=(
                UUID(stored.parent_conversation_id)
                if stored.parent_conversation_id
                else None
            ),
            sub_conversation_ids=sub_conversation_ids or [],
            created_at=created_at,
            updated_at=updated_at,
        )

    def _fix_timezone(self, value: datetime) -> datetime:
        """Return ``value`` with UTC attached if it is naive.

        Sqlite does not store timezones - and since we can't update the existing
        models we assume UTC if the timezone is missing. Aware values are returned
        unchanged.
        """
        if not value.tzinfo:
            value = value.replace(tzinfo=UTC)
        return value
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
    """Injector that provides a SQLSharedConversationInfoService."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[SharedConversationInfoService, None]:
        """Yield a SQLSharedConversationInfoService bound to a database session.

        The session comes from ``get_db_session(state, request)``. The yield sits
        inside its ``async with`` block, so the session context stays open while
        the caller uses the service.
        """
        # Imported here rather than at module level to prevent a circular import.
        from openhands.app_server.config import get_db_session

        async with get_db_session(state, request) as session:
            yield SQLSharedConversationInfoService(db_session=session)