mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
419 lines
16 KiB
Python
419 lines
16 KiB
Python
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
|
|
|
from datetime import datetime
|
|
from typing import AsyncGenerator
|
|
from uuid import UUID
|
|
|
|
from fastapi import Request
|
|
from sqlalchemy import func, select
|
|
from storage.stored_conversation_metadata import StoredConversationMetadata
|
|
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
|
from storage.user import User
|
|
|
|
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
|
AppConversationInfoService,
|
|
AppConversationInfoServiceInjector,
|
|
)
|
|
from openhands.app_server.app_conversation.app_conversation_models import (
|
|
AppConversationInfo,
|
|
AppConversationInfoPage,
|
|
AppConversationSortOrder,
|
|
)
|
|
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
|
SQLAppConversationInfoService,
|
|
)
|
|
from openhands.app_server.errors import AuthError
|
|
from openhands.app_server.services.injector import InjectorState
|
|
from openhands.app_server.user.specifiy_user_context import ADMIN
|
|
|
|
|
|
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
|
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
|
|
|
|
async def _get_current_user(self) -> User | None:
    """Load the ``User`` row for the caller through the service's own session.

    The query runs on ``self.db_session`` so that no separate database
    session has to be opened.

    Returns:
        The first ``User`` whose id matches the id reported by
        ``self.user_context``, or ``None`` when the context reports no
        user id or no row matches.
    """
    raw_user_id = await self.user_context.get_user_id()
    if raw_user_id:
        # The context hands back a string; the User.id column is compared
        # against a UUID value.
        user_uuid = UUID(raw_user_id)
        lookup = select(User).where(User.id == user_uuid)
        outcome = await self.db_session.execute(lookup)
        return outcome.scalars().first()
    return None
|
|
|
|
async def _apply_user_and_org_filter(self, query):
    """Restrict ``query`` to conversations the caller is allowed to see.

    Two restrictions are added:

    - user_id: only rows whose SAAS metadata belongs to the current user
    - org_id: only rows in the user's current organization (applied only
      when the user row exists and has a ``current_org_id``)

    Args:
        query: SQLAlchemy query that already joins
            ``StoredConversationMetadataSaas``.

    Returns:
        The query with the user/organization conditions attached, or the
        query unchanged when running under the ``ADMIN`` user context.

    Raises:
        AuthError: If the user context yields no user id. Denying access is
            the secure default; a missing user never means "show everything".
    """
    # Internal operations (for example resolving a conversation from a
    # session_api_key) need an unfiltered mode; the `as_admin()` dependency
    # switches the context to ADMIN to enable it.
    if self.user_context == ADMIN:
        return query

    raw_user_id = await self.user_context.get_user_id()
    if not raw_user_id:
        raise AuthError('User authentication required')

    owner_uuid = UUID(raw_user_id)
    scoped = query.where(StoredConversationMetadataSaas.user_id == owner_uuid)

    # Organization isolation: conversations are additionally partitioned by
    # the organization the user is currently working in.
    current_user = await self._get_current_user()
    org_id = current_user.current_org_id if current_user else None
    if org_id is None:
        return scoped
    return scoped.where(StoredConversationMetadataSaas.org_id == org_id)
|
|
|
|
async def _secure_select(self):
    """Build the caller-scoped select over ``StoredConversationMetadata``.

    The statement joins the SAAS metadata table on ``conversation_id``,
    keeps only rows whose ``conversation_version`` is ``'V1'`` and is then
    passed through ``_apply_user_and_org_filter``.
    """
    same_conversation = (
        StoredConversationMetadata.conversation_id
        == StoredConversationMetadataSaas.conversation_id
    )
    joined = select(StoredConversationMetadata).join(
        StoredConversationMetadataSaas, same_conversation
    )
    v1_only = joined.where(StoredConversationMetadata.conversation_version == 'V1')
    return await self._apply_user_and_org_filter(v1_only)
|
|
|
|
async def _secure_select_with_saas_metadata(self):
    """Build the caller-scoped select that also returns the SAAS metadata row.

    Each result row is a ``(StoredConversationMetadata,
    StoredConversationMetadataSaas)`` pair, which lets callers read the
    owning ``user_id``. Only ``'V1'`` conversations are selected, and the
    statement is passed through ``_apply_user_and_org_filter``.
    """
    same_conversation = (
        StoredConversationMetadata.conversation_id
        == StoredConversationMetadataSaas.conversation_id
    )
    joined = select(
        StoredConversationMetadata, StoredConversationMetadataSaas
    ).join(StoredConversationMetadataSaas, same_conversation)
    v1_only = joined.where(StoredConversationMetadata.conversation_version == 'V1')
    return await self._apply_user_and_org_filter(v1_only)
|
|
|
|
async def search_app_conversation_info(
    self,
    title__contains: str | None = None,
    created_at__gte: datetime | None = None,
    created_at__lt: datetime | None = None,
    updated_at__gte: datetime | None = None,
    updated_at__lt: datetime | None = None,
    sandbox_id__eq: str | None = None,
    sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
    page_id: str | None = None,
    limit: int = 100,
    include_sub_conversations: bool = False,
) -> AppConversationInfoPage:
    """Search the caller's conversations and return one page of results.

    The query is built from ``_secure_select_with_saas_metadata`` so it is
    already restricted to the current user/organization, and every returned
    item carries the ``user_id`` stored in the SAAS metadata.

    Args:
        title__contains: Keep conversations whose title contains this text.
        created_at__gte: Keep conversations created at or after this time.
        created_at__lt: Keep conversations created before this time.
        updated_at__gte: Keep conversations updated at or after this time.
        updated_at__lt: Keep conversations updated before this time.
        sandbox_id__eq: Keep conversations bound to exactly this sandbox.
        sort_order: Column and direction used to order the results.
        page_id: Stringified row offset as returned in ``next_page_id``.
            A missing, non-numeric or negative value starts from the
            beginning.
        limit: Maximum number of items in the returned page.
        include_sub_conversations: When false, only conversations without
            a ``parent_conversation_id`` are returned.

    Returns:
        A page with the matching items and, when more rows exist, the
        ``next_page_id`` to pass in for the following page.
    """
    query = await self._secure_select_with_saas_metadata()

    # Conditionally exclude sub-conversations (keep only top-level ones).
    if not include_sub_conversations:
        query = query.where(
            StoredConversationMetadata.parent_conversation_id.is_(None)
        )

    query = self._apply_filters_with_saas_metadata(
        query=query,
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
        sandbox_id__eq=sandbox_id__eq,
    )

    # Resolve the requested sort column/direction.
    primary_order = None
    if sort_order == AppConversationSortOrder.CREATED_AT:
        primary_order = StoredConversationMetadata.created_at
    elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
        primary_order = StoredConversationMetadata.created_at.desc()
    elif sort_order == AppConversationSortOrder.UPDATED_AT:
        primary_order = StoredConversationMetadata.last_updated_at
    elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
        primary_order = StoredConversationMetadata.last_updated_at.desc()
    elif sort_order == AppConversationSortOrder.TITLE:
        primary_order = StoredConversationMetadata.title
    elif sort_order == AppConversationSortOrder.TITLE_DESC:
        primary_order = StoredConversationMetadata.title.desc()

    # Compare with `is not None`: SQL expressions cannot be used in a plain
    # truth test.
    if primary_order is not None:
        # Timestamps and titles are not unique. Without a tie-breaker the
        # relative order of equal rows may differ between two executions,
        # so offset-based paging could repeat or skip rows. The conversation
        # id makes the ordering total.
        query = query.order_by(
            primary_order, StoredConversationMetadata.conversation_id
        )

    # Pagination: page_id is a row offset encoded as a string.
    offset = 0
    if page_id is not None:
        try:
            # A negative offset is rejected by the database, so treat it like
            # any other invalid page_id and start from the beginning.
            offset = max(int(page_id), 0)
        except ValueError:
            offset = 0
    if offset > 0:
        query = query.offset(offset)

    # Fetch one extra row to learn whether another page exists.
    query = query.limit(limit + 1)

    result = await self.db_session.execute(query)
    rows = result.all()

    has_more = len(rows) > limit
    if has_more:
        rows = rows[:limit]

    items = [
        self._to_info_with_user_id(stored_metadata, saas_metadata)
        for stored_metadata, saas_metadata in rows
    ]

    next_page_id = str(offset + limit) if has_more else None
    return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
|
|
|
async def count_app_conversation_info(
    self,
    title__contains: str | None = None,
    created_at__gte: datetime | None = None,
    created_at__lt: datetime | None = None,
    updated_at__gte: datetime | None = None,
    updated_at__lt: datetime | None = None,
    sandbox_id__eq: str | None = None,
) -> int:
    """Return how many of the caller's V1 conversations match the filters.

    The count runs over ``StoredConversationMetadata`` joined with its SAAS
    metadata, restricted by ``_apply_user_and_org_filter`` and then by the
    optional title/time/sandbox filters.

    Returns:
        The number of matching conversations; ``0`` when the database
        returns no value.
    """
    same_conversation = (
        StoredConversationMetadata.conversation_id
        == StoredConversationMetadataSaas.conversation_id
    )
    counting = (
        select(func.count(StoredConversationMetadata.conversation_id))
        .join(StoredConversationMetadataSaas, same_conversation)
        .where(StoredConversationMetadata.conversation_version == 'V1')
    )

    # Scope to the current user / organization before applying search filters.
    counting = await self._apply_user_and_org_filter(counting)
    counting = self._apply_filters_with_saas_metadata(
        query=counting,
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
        sandbox_id__eq=sandbox_id__eq,
    )

    outcome = await self.db_session.execute(counting)
    return outcome.scalar() or 0
|
|
|
|
def _apply_filters_with_saas_metadata(
    self,
    query,
    title__contains: str | None = None,
    created_at__gte: datetime | None = None,
    created_at__lt: datetime | None = None,
    updated_at__gte: datetime | None = None,
    updated_at__lt: datetime | None = None,
    sandbox_id__eq: str | None = None,
):
    """Attach the optional search filters to a query that joins SAAS metadata.

    Each argument that is not ``None`` contributes one condition on
    ``StoredConversationMetadata``; arguments left as ``None`` are ignored.
    The filters are meant to be the same ones the base class applies.

    Args:
        query: SQLAlchemy query to extend.
        title__contains: Text matched with ``LIKE '%<text>%'`` on the title.
        created_at__gte: Inclusive lower bound on ``created_at``.
        created_at__lt: Exclusive upper bound on ``created_at``.
        updated_at__gte: Inclusive lower bound on ``last_updated_at``.
        updated_at__lt: Exclusive upper bound on ``last_updated_at``.
        sandbox_id__eq: Exact ``sandbox_id`` to match.

    Returns:
        The query with all active conditions added, or the query unchanged
        when no filter was supplied.
    """
    meta = StoredConversationMetadata
    # (supplied value, builder for its SQL condition), in application order.
    candidates = (
        (title__contains, lambda text: meta.title.like(f'%{text}%')),
        (created_at__gte, lambda moment: meta.created_at >= moment),
        (created_at__lt, lambda moment: meta.created_at < moment),
        (updated_at__gte, lambda moment: meta.last_updated_at >= moment),
        (updated_at__lt, lambda moment: meta.last_updated_at < moment),
        (sandbox_id__eq, lambda sandbox: meta.sandbox_id == sandbox),
    )
    active = [build(value) for value, build in candidates if value is not None]
    if not active:
        return query
    return query.where(*active)
|
|
|
|
async def get_app_conversation_info(
    self, conversation_id: UUID
) -> AppConversationInfo | None:
    """Fetch one conversation visible to the caller, including its owner id.

    Args:
        conversation_id: Id of the conversation to look up.

    Returns:
        The conversation info (with ``created_by_user_id`` taken from the
        SAAS metadata and its sub-conversation ids filled in), or ``None``
        when the caller-scoped query finds no such conversation.
    """
    lookup = await self._secure_select_with_saas_metadata()
    lookup = lookup.where(
        StoredConversationMetadata.conversation_id == str(conversation_id)
    )
    row = (await self.db_session.execute(lookup)).first()
    if not row:
        return None

    stored_metadata, saas_metadata = row
    # Sub-conversation ids are resolved by a separate lookup.
    child_ids = await self.get_sub_conversation_ids(conversation_id)
    return self._to_info_with_user_id(
        stored_metadata,
        saas_metadata,
        sub_conversation_ids=child_ids,
    )
|
|
|
|
async def batch_get_app_conversation_info(
    self, conversation_ids: list[UUID]
) -> list[AppConversationInfo | None]:
    """Fetch several conversations at once, preserving the request order.

    Args:
        conversation_ids: Ids of the conversations to look up.

    Returns:
        A list aligned with ``conversation_ids``: position ``i`` holds the
        info for ``conversation_ids[i]``, or ``None`` when the caller-scoped
        query did not return that conversation.
    """
    wanted_ids = [str(wanted) for wanted in conversation_ids]

    lookup = await self._secure_select_with_saas_metadata()
    lookup = lookup.where(
        StoredConversationMetadata.conversation_id.in_(wanted_ids)
    )
    fetched = (await self.db_session.execute(lookup)).all()

    # conversation_id -> (metadata, saas_metadata) for every row found.
    pairs_by_id = {
        stored.conversation_id: (stored, saas) for stored, saas in fetched
    }

    results: list[AppConversationInfo | None] = []
    for wanted_id in wanted_ids:
        pair = pairs_by_id.get(wanted_id)
        if pair is None:
            results.append(None)
            continue

        stored, saas = pair
        # Sub-conversation ids are looked up separately for each hit.
        child_ids = await self.get_sub_conversation_ids(UUID(wanted_id))
        results.append(
            self._to_info_with_user_id(
                stored,
                saas,
                sub_conversation_ids=child_ids,
            )
        )

    return results
|
|
|
|
async def save_app_conversation_info(
    self, info: AppConversationInfo
) -> AppConversationInfo:
    """Save conversation info and make sure its SAAS metadata row exists.

    The base metadata is stored through the parent class first. Then, if a
    user id can be determined, a ``StoredConversationMetadataSaas`` row
    linking the conversation to that user and the user's current
    organization is created unless one already exists.

    Args:
        info: Conversation info to persist.

    Returns:
        The same ``info`` object that was passed in.

    Raises:
        AssertionError: If the resolved user id has no ``User`` row, or if
            existing SAAS metadata belongs to a different user or
            organization than the one saving.
    """
    # Save the base conversation metadata
    await super().save_app_conversation_info(info)

    # Get current user_id for SAAS metadata.
    # Fall back to info.created_by_user_id for webhook callbacks (which use
    # the ADMIN context and therefore report no user id of their own).
    user_id_str = await self.user_context.get_user_id()
    if not user_id_str and info.created_by_user_id:
        user_id_str = info.created_by_user_id

    if user_id_str:
        # Convert string user_id to UUID
        user_id_uuid = UUID(user_id_str)
        user_query = select(User).where(User.id == user_id_uuid)
        result = await self.db_session.execute(user_query)
        user = result.scalar_one_or_none()
        # Raised explicitly rather than with `assert`: assert statements are
        # removed under `python -O`, which would silently drop this check.
        # AssertionError is kept so the exception type is unchanged.
        if not user:
            raise AssertionError('No user record found for conversation owner')

        # Check if SAAS metadata already exists
        saas_query = select(StoredConversationMetadataSaas).where(
            StoredConversationMetadataSaas.conversation_id == str(info.id)
        )
        result = await self.db_session.execute(saas_query)
        existing_saas_metadata = result.scalar_one_or_none()

        # Ownership guard: an existing row must already belong to this user
        # and to the user's current organization. This must not depend on
        # `assert`, otherwise optimized mode would disable the guard.
        if existing_saas_metadata is not None and not (
            existing_saas_metadata.user_id == user_id_uuid
            and existing_saas_metadata.org_id == user.current_org_id
        ):
            raise AssertionError(
                'Conversation SAAS metadata belongs to a different user or organization'
            )

        if existing_saas_metadata is None:
            # Create new SAAS metadata; org_id is the organization the user
            # is currently working in.
            saas_metadata = StoredConversationMetadataSaas(
                conversation_id=str(info.id),
                user_id=user_id_uuid,
                org_id=user.current_org_id,
            )
            self.db_session.add(saas_metadata)

        await self.db_session.commit()

    return info
|
|
|
|
def _to_info_with_user_id(
    self,
    stored: StoredConversationMetadata,
    saas_metadata: StoredConversationMetadataSaas,
    sub_conversation_ids: list[UUID] | None = None,
) -> AppConversationInfo:
    """Build an ``AppConversationInfo`` whose owner comes from SAAS metadata.

    Args:
        stored: Base conversation metadata row.
        saas_metadata: SAAS metadata row for the same conversation.
        sub_conversation_ids: Optional sub-conversation ids forwarded to
            the base ``_to_info`` conversion.

    Returns:
        The info produced by ``_to_info`` with ``created_by_user_id``
        replaced by the stringified SAAS ``user_id`` (``None`` when the
        SAAS row has no user id).
    """
    # The base conversion supplies everything except the owner.
    info = self._to_info(stored, sub_conversation_ids=sub_conversation_ids)

    owner_id = saas_metadata.user_id
    if owner_id:
        info.created_by_user_id = str(owner_id)
    else:
        info.created_by_user_id = None

    return info
|
|
|
|
|
|
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
    """Injector that provides the SAAS-filtered SQL conversation info service."""

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[AppConversationInfoService, None]:
        """Yield a ``SaasSQLAppConversationInfoService`` for this request.

        The user context is entered first and the database session second;
        both stay open while the consumer uses the yielded service.

        Args:
            state: Injector state passed through to the config helpers.
            request: Optional incoming request passed through likewise.
        """
        # Imported at call time rather than at module import time.
        # NOTE(review): presumably to avoid a circular import — confirm.
        from openhands.app_server.config import get_db_session, get_user_context

        async with get_user_context(state, request) as user_context:
            async with get_db_session(state, request) as db_session:
                yield SaasSQLAppConversationInfoService(
                    db_session=db_session, user_context=user_context
                )
|