diff --git a/enterprise/migrations/versions/085_add_public_column_to_conversation_metadata.py b/enterprise/migrations/versions/085_add_public_column_to_conversation_metadata.py new file mode 100644 index 0000000000..71324b0068 --- /dev/null +++ b/enterprise/migrations/versions/085_add_public_column_to_conversation_metadata.py @@ -0,0 +1,41 @@ +"""add public column to conversation_metadata + +Revision ID: 085 +Revises: 084 +Create Date: 2025-01-27 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '085' +down_revision: Union[str, None] = '084' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'conversation_metadata', + sa.Column('public', sa.Boolean(), nullable=True), + ) + op.create_index( + op.f('ix_conversation_metadata_public'), + 'conversation_metadata', + ['public'], + unique=False, + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_index( + op.f('ix_conversation_metadata_public'), + table_name='conversation_metadata', + ) + op.drop_column('conversation_metadata', 'public') diff --git a/enterprise/poetry.lock b/enterprise/poetry.lock index bd2c55c317..2535aef566 100644 --- a/enterprise/poetry.lock +++ b/enterprise/poetry.lock @@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0" [[package]] name = "openhands-ai" -version = "0.0.0-post.5687+7853b41ad" +version = "0.0.0-post.5750+f19fb1043" description = "OpenHands: Code Less, Make More" optional = false python-versions = "^3.12,<3.14" diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 96e19a9815..ec1480cbda 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402 from server.routes.oauth_device import oauth_device_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 +from server.sharing.shared_conversation_router import ( # noqa: E402 + router as shared_conversation_router, +) +from server.sharing.shared_event_router import ( # noqa: E402 + router as shared_event_router, +) from openhands.server.app import app as base_app # noqa: E402 from openhands.server.listen_socket import sio # noqa: E402 @@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call base_app.include_router( billing_router ) # Add routes for credit management and Stripe payment integration +base_app.include_router(shared_conversation_router) +base_app.include_router(shared_event_router) # Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set if GITHUB_APP_CLIENT_ID: @@ -99,6 +107,7 @@ base_app.include_router( event_webhook_router ) # Add routes for Events in nested runtimes + base_app.add_middleware( CORSMiddleware, allow_origins=PERMITTED_CORS_ORIGINS, diff --git a/enterprise/server/sharing/README.md b/enterprise/server/sharing/README.md new file mode 100644 index 0000000000..5eb5474d21 --- /dev/null +++ b/enterprise/server/sharing/README.md @@ -0,0 +1,20 @@ +# Sharing Package + +This package contains functionality for sharing conversations. + +## Components + +- **shared.py**: Data models for shared conversations +- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info +- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service +- **shared_event_service.py**: Service interface for accessing shared events +- **shared_event_service_impl.py**: Implementation of the shared event service +- **shared_conversation_router.py**: REST API endpoints for shared conversations +- **shared_event_router.py**: REST API endpoints for shared events + +## Features + +- Read-only access to shared conversations +- Event access for shared conversations +- Search and filtering capabilities +- Pagination support diff --git a/enterprise/server/sharing/filesystem_shared_event_service.py b/enterprise/server/sharing/filesystem_shared_event_service.py new file mode 100644 index 0000000000..e39f880bdf --- /dev/null +++ b/enterprise/server/sharing/filesystem_shared_event_service.py @@ -0,0 +1,142 @@ +"""Implementation of SharedEventService. + +This implementation provides read-only access to events from shared conversations: +- Validates that the conversation is shared before returning events +- Uses existing EventService for actual event retrieval +- Uses SharedConversationInfoService for shared conversation validation +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import AsyncGenerator +from uuid import UUID + +from fastapi import Request +from server.sharing.shared_conversation_info_service import ( + SharedConversationInfoService, +) +from server.sharing.shared_event_service import ( + SharedEventService, + SharedEventServiceInjector, +) +from server.sharing.sql_shared_conversation_info_service import ( + SQLSharedConversationInfoService, +) + +from openhands.agent_server.models import EventPage, EventSortOrder +from openhands.app_server.event.event_service import EventService +from openhands.app_server.event_callback.event_callback_models import EventKind +from openhands.app_server.services.injector import InjectorState +from openhands.sdk import Event + +logger = logging.getLogger(__name__) + + +@dataclass +class SharedEventServiceImpl(SharedEventService): + """Implementation of SharedEventService that validates shared access.""" + + shared_conversation_info_service: SharedConversationInfoService + event_service: EventService + + async def get_shared_event( + self, conversation_id: UUID, event_id: str + ) -> Event | None: + """Given a conversation_id and event_id, retrieve an event if the conversation is shared.""" + # First check if the conversation is shared + shared_conversation_info = ( + await self.shared_conversation_info_service.get_shared_conversation_info( + conversation_id + ) + ) + if shared_conversation_info is None: + return None + + # If conversation is shared, get the event + return await self.event_service.get_event(event_id) + + async def search_shared_events( + self, + conversation_id: UUID, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + page_id: str | None = None, + limit: int = 100, + ) -> EventPage: + """Search events for a specific shared conversation.""" + # First check if the conversation is shared + shared_conversation_info = ( + await self.shared_conversation_info_service.get_shared_conversation_info( + conversation_id + ) + ) + if shared_conversation_info is None: + # Return empty page if conversation is not shared + return EventPage(items=[], next_page_id=None) + + # If conversation is shared, search events for this conversation + return await self.event_service.search_events( + conversation_id__eq=conversation_id, + kind__eq=kind__eq, + timestamp__gte=timestamp__gte, + timestamp__lt=timestamp__lt, + sort_order=sort_order, + page_id=page_id, + limit=limit, + ) + + async def count_shared_events( + self, + conversation_id: UUID, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + ) -> int: + """Count events for a specific shared conversation.""" + # First check if the conversation is shared + shared_conversation_info = ( + await self.shared_conversation_info_service.get_shared_conversation_info( + conversation_id + ) + ) + if shared_conversation_info is None: + return 0 + + # If conversation is shared, count events for this conversation + return await self.event_service.count_events( + conversation_id__eq=conversation_id, + kind__eq=kind__eq, + timestamp__gte=timestamp__gte, + timestamp__lt=timestamp__lt, + sort_order=sort_order, + ) + + +class SharedEventServiceImplInjector(SharedEventServiceInjector): + async def inject( + self, state: InjectorState, request: Request | None = None + ) -> AsyncGenerator[SharedEventService, None]: + # Define inline to prevent circular lookup + from openhands.app_server.config import ( + get_db_session, + get_event_service, + ) + + async with ( + get_db_session(state, request) as db_session, + get_event_service(state, request) as event_service, + ): + shared_conversation_info_service = SQLSharedConversationInfoService( + db_session=db_session + ) + service = SharedEventServiceImpl( + shared_conversation_info_service=shared_conversation_info_service, + event_service=event_service, + ) + yield service diff --git a/enterprise/server/sharing/shared_conversation_info_service.py b/enterprise/server/sharing/shared_conversation_info_service.py new file mode 100644 index 0000000000..a1fdec6718 --- /dev/null +++ b/enterprise/server/sharing/shared_conversation_info_service.py @@ -0,0 +1,66 @@ +import asyncio +from abc import ABC, abstractmethod +from datetime import datetime +from uuid import UUID + +from server.sharing.shared_conversation_models import ( + SharedConversation, + SharedConversationPage, + SharedConversationSortOrder, +) + +from openhands.app_server.services.injector import Injector +from openhands.sdk.utils.models import DiscriminatedUnionMixin + + +class SharedConversationInfoService(ABC): + """Service for accessing shared conversation info without user restrictions.""" + + @abstractmethod + async def search_shared_conversation_info( + self, + title__contains: str | None = None, + created_at__gte: datetime | None = None, + created_at__lt: datetime | None = None, + updated_at__gte: datetime | None = None, + updated_at__lt: datetime | None = None, + sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC, + page_id: str | None = None, + limit: int = 100, + include_sub_conversations: bool = False, + ) -> SharedConversationPage: + """Search for shared conversations.""" + + @abstractmethod + async def count_shared_conversation_info( + self, + title__contains: str | None = None, + created_at__gte: datetime | None = None, + created_at__lt: datetime | None = None, + updated_at__gte: datetime | None = None, + updated_at__lt: datetime | None = None, + ) -> int: + """Count shared conversations.""" + + @abstractmethod + async def get_shared_conversation_info( + self, conversation_id: UUID + ) -> SharedConversation | None: + """Get a single shared conversation info, returning None if missing or not shared.""" + + async def batch_get_shared_conversation_info( + self, conversation_ids: list[UUID] + ) -> list[SharedConversation | None]: + """Get a batch of shared conversation info, return None for any missing or non-shared.""" + return await asyncio.gather( + *[ + self.get_shared_conversation_info(conversation_id) + for conversation_id in conversation_ids + ] + ) + + +class SharedConversationInfoServiceInjector( + DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC +): + pass diff --git a/enterprise/server/sharing/shared_conversation_models.py b/enterprise/server/sharing/shared_conversation_models.py new file mode 100644 index 0000000000..806ddefb12 --- /dev/null +++ b/enterprise/server/sharing/shared_conversation_models.py @@ -0,0 +1,56 @@ +from datetime import datetime +from enum import Enum + +# Simplified imports to avoid dependency chain issues +# from openhands.integrations.service_types import ProviderType +# from openhands.sdk.llm import MetricsSnapshot +# from openhands.storage.data_models.conversation_metadata import ConversationTrigger +# For now, use Any to avoid import issues +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from openhands.agent_server.utils import OpenHandsUUID, utc_now + +ProviderType = Any +MetricsSnapshot = Any +ConversationTrigger = Any + + +class SharedConversation(BaseModel): + """Shared conversation info model with all fields from AppConversationInfo.""" + + id: OpenHandsUUID = Field(default_factory=uuid4) + + created_by_user_id: str | None + sandbox_id: str + + selected_repository: str | None = None + selected_branch: str | None = None + git_provider: ProviderType | None = None + title: str | None = None + pr_number: list[int] = Field(default_factory=list) + llm_model: str | None = None + + metrics: MetricsSnapshot | None = None + + parent_conversation_id: OpenHandsUUID | None = None + sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list) + + created_at: datetime = Field(default_factory=utc_now) + updated_at: datetime = Field(default_factory=utc_now) + + +class SharedConversationSortOrder(Enum): + CREATED_AT = 'CREATED_AT' + CREATED_AT_DESC = 'CREATED_AT_DESC' + UPDATED_AT = 'UPDATED_AT' + UPDATED_AT_DESC = 'UPDATED_AT_DESC' + TITLE = 'TITLE' + TITLE_DESC = 'TITLE_DESC' + + +class SharedConversationPage(BaseModel): + items: list[SharedConversation] + next_page_id: str | None = None diff --git a/enterprise/server/sharing/shared_conversation_router.py b/enterprise/server/sharing/shared_conversation_router.py new file mode 100644 index 0000000000..26fe047e6d --- /dev/null +++ b/enterprise/server/sharing/shared_conversation_router.py @@ -0,0 +1,135 @@ +"""Shared Conversation router for OpenHands Server.""" + +from datetime import datetime +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from server.sharing.shared_conversation_info_service import ( + SharedConversationInfoService, +) +from server.sharing.shared_conversation_models import ( + SharedConversation, + SharedConversationPage, + SharedConversationSortOrder, +) +from server.sharing.sql_shared_conversation_info_service import ( + SQLSharedConversationInfoServiceInjector, +) + +router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing']) +shared_conversation_info_service_dependency = Depends( + SQLSharedConversationInfoServiceInjector().depends +) + +# Read methods + + +@router.get('/search') +async def search_shared_conversations( + title__contains: Annotated[ + str | None, + Query(title='Filter by title containing this string'), + ] = None, + created_at__gte: Annotated[ + datetime | None, + Query(title='Filter by created_at greater than or equal to this datetime'), + ] = None, + created_at__lt: Annotated[ + datetime | None, + Query(title='Filter by created_at less than this datetime'), + ] = None, + updated_at__gte: Annotated[ + datetime | None, + Query(title='Filter by updated_at greater than or equal to this datetime'), + ] = None, + updated_at__lt: Annotated[ + datetime | None, + Query(title='Filter by updated_at less than this datetime'), + ] = None, + sort_order: Annotated[ + SharedConversationSortOrder, + Query(title='Sort order for results'), + ] = SharedConversationSortOrder.CREATED_AT_DESC, + page_id: Annotated[ + str | None, + Query(title='Optional next_page_id from the previously returned page'), + ] = None, + limit: Annotated[ + int, + Query( + title='The max number of results in the page', + gt=0, + lte=100, + ), + ] = 100, + include_sub_conversations: Annotated[ + bool, + Query( + title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.' + ), + ] = False, + shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency, +) -> SharedConversationPage: + """Search / List shared conversations.""" + assert limit > 0 + assert limit <= 100 + return await shared_conversation_service.search_shared_conversation_info( + title__contains=title__contains, + created_at__gte=created_at__gte, + created_at__lt=created_at__lt, + updated_at__gte=updated_at__gte, + updated_at__lt=updated_at__lt, + sort_order=sort_order, + page_id=page_id, + limit=limit, + include_sub_conversations=include_sub_conversations, + ) + + +@router.get('/count') +async def count_shared_conversations( + title__contains: Annotated[ + str | None, + Query(title='Filter by title containing this string'), + ] = None, + created_at__gte: Annotated[ + datetime | None, + Query(title='Filter by created_at greater than or equal to this datetime'), + ] = None, + created_at__lt: Annotated[ + datetime | None, + Query(title='Filter by created_at less than this datetime'), + ] = None, + updated_at__gte: Annotated[ + datetime | None, + Query(title='Filter by updated_at greater than or equal to this datetime'), + ] = None, + updated_at__lt: Annotated[ + datetime | None, + Query(title='Filter by updated_at less than this datetime'), + ] = None, + shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency, +) -> int: + """Count shared conversations matching the given filters.""" + return await shared_conversation_service.count_shared_conversation_info( + title__contains=title__contains, + created_at__gte=created_at__gte, + created_at__lt=created_at__lt, + updated_at__gte=updated_at__gte, + updated_at__lt=updated_at__lt, + ) + + +@router.get('') +async def batch_get_shared_conversations( + ids: Annotated[list[str], Query()], + shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency, +) -> list[SharedConversation | None]: + """Get a batch of shared conversations given their ids. Return None for any missing or non-shared.""" + assert len(ids) <= 100 + uuids = [UUID(id_) for id_ in ids] + shared_conversation_info = ( + await shared_conversation_service.batch_get_shared_conversation_info(uuids) + ) + return shared_conversation_info diff --git a/enterprise/server/sharing/shared_event_router.py b/enterprise/server/sharing/shared_event_router.py new file mode 100644 index 0000000000..4fc579196c --- /dev/null +++ b/enterprise/server/sharing/shared_event_router.py @@ -0,0 +1,126 @@ +"""Shared Event router for OpenHands Server.""" + +from datetime import datetime +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from server.sharing.filesystem_shared_event_service import ( + SharedEventServiceImplInjector, +) +from server.sharing.shared_event_service import SharedEventService + +from openhands.agent_server.models import EventPage, EventSortOrder +from openhands.app_server.event_callback.event_callback_models import EventKind +from openhands.sdk import Event + +router = APIRouter(prefix='/api/shared-events', tags=['Sharing']) +shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends) + + +# Read methods + + +@router.get('/search') +async def search_shared_events( + conversation_id: Annotated[ + str, + Query(title='Conversation ID to search events for'), + ], + kind__eq: Annotated[ + EventKind | None, + Query(title='Optional filter by event kind'), + ] = None, + timestamp__gte: Annotated[ + datetime | None, + Query(title='Optional filter by timestamp greater than or equal to'), + ] = None, + timestamp__lt: Annotated[ + datetime | None, + Query(title='Optional filter by timestamp less than'), + ] = None, + sort_order: Annotated[ + EventSortOrder, + Query(title='Sort order for results'), + ] = EventSortOrder.TIMESTAMP, + page_id: Annotated[ + str | None, + Query(title='Optional next_page_id from the previously returned page'), + ] = None, + limit: Annotated[ + int, + Query(title='The max number of results in the page', gt=0, lte=100), + ] = 100, + shared_event_service: SharedEventService = shared_event_service_dependency, +) -> EventPage: + """Search / List events for a shared conversation.""" + assert limit > 0 + assert limit <= 100 + return await shared_event_service.search_shared_events( + conversation_id=UUID(conversation_id), + kind__eq=kind__eq, + timestamp__gte=timestamp__gte, + timestamp__lt=timestamp__lt, + sort_order=sort_order, + page_id=page_id, + limit=limit, + ) + + +@router.get('/count') +async def count_shared_events( + conversation_id: Annotated[ + str, + Query(title='Conversation ID to count events for'), + ], + kind__eq: Annotated[ + EventKind | None, + Query(title='Optional filter by event kind'), + ] = None, + timestamp__gte: Annotated[ + datetime | None, + Query(title='Optional filter by timestamp greater than or equal to'), + ] = None, + timestamp__lt: Annotated[ + datetime | None, + Query(title='Optional filter by timestamp less than'), + ] = None, + sort_order: Annotated[ + EventSortOrder, + Query(title='Sort order for results'), + ] = EventSortOrder.TIMESTAMP, + shared_event_service: SharedEventService = shared_event_service_dependency, +) -> int: + """Count events for a shared conversation matching the given filters.""" + return await shared_event_service.count_shared_events( + conversation_id=UUID(conversation_id), + kind__eq=kind__eq, + timestamp__gte=timestamp__gte, + timestamp__lt=timestamp__lt, + sort_order=sort_order, + ) + + +@router.get('') +async def batch_get_shared_events( + conversation_id: Annotated[ + UUID, + Query(title='Conversation ID to get events for'), + ], + id: Annotated[list[str], Query()], + shared_event_service: SharedEventService = shared_event_service_dependency, +) -> list[Event | None]: + """Get a batch of events for a shared conversation given their ids, returning null for any missing event.""" + assert len(id) <= 100 + events = await shared_event_service.batch_get_shared_events(conversation_id, id) + return events + + +@router.get('/{conversation_id}/{event_id}') +async def get_shared_event( + conversation_id: UUID, + event_id: str, + shared_event_service: SharedEventService = shared_event_service_dependency, +) -> Event | None: + """Get a single event from a shared conversation by conversation_id and event_id.""" + return await shared_event_service.get_shared_event(conversation_id, event_id) diff --git a/enterprise/server/sharing/shared_event_service.py b/enterprise/server/sharing/shared_event_service.py new file mode 100644 index 0000000000..054153d03f --- /dev/null +++ b/enterprise/server/sharing/shared_event_service.py @@ -0,0 +1,64 @@ +import asyncio +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from uuid import UUID + +from openhands.agent_server.models import EventPage, EventSortOrder +from openhands.app_server.event_callback.event_callback_models import EventKind +from openhands.app_server.services.injector import Injector +from openhands.sdk import Event +from openhands.sdk.utils.models import DiscriminatedUnionMixin + +_logger = logging.getLogger(__name__) + + +class SharedEventService(ABC): + """Event Service for getting events from shared conversations only.""" + + @abstractmethod + async def get_shared_event( + self, conversation_id: UUID, event_id: str + ) -> Event | None: + """Given a conversation_id and event_id, retrieve an event if the conversation is shared.""" + + @abstractmethod + async def search_shared_events( + self, + conversation_id: UUID, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + page_id: str | None = None, + limit: int = 100, + ) -> EventPage: + """Search events for a specific shared conversation.""" + + @abstractmethod + async def count_shared_events( + self, + conversation_id: UUID, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + ) -> int: + """Count events for a specific shared conversation.""" + + async def batch_get_shared_events( + self, conversation_id: UUID, event_ids: list[str] + ) -> list[Event | None]: + """Given a conversation_id and list of event_ids, get events if the conversation is shared.""" + return await asyncio.gather( + *[ + self.get_shared_event(conversation_id, event_id) + for event_id in event_ids + ] + ) + + +class SharedEventServiceInjector( + DiscriminatedUnionMixin, Injector[SharedEventService], ABC +): + pass diff --git a/enterprise/server/sharing/sql_shared_conversation_info_service.py b/enterprise/server/sharing/sql_shared_conversation_info_service.py new file mode 100644 index 0000000000..f86a6045bb --- /dev/null +++ b/enterprise/server/sharing/sql_shared_conversation_info_service.py @@ -0,0 +1,282 @@ +"""SQL implementation of SharedConversationInfoService. + +This implementation provides read-only access to shared conversations: +- Direct database access without user permission checks +- Filters only conversations marked as shared (currently public) +- Full async/await support using SQL async db_sessions +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import AsyncGenerator +from uuid import UUID + +from fastapi import Request +from server.sharing.shared_conversation_info_service import ( + SharedConversationInfoService, + SharedConversationInfoServiceInjector, +) +from server.sharing.shared_conversation_models import ( + SharedConversation, + SharedConversationPage, + SharedConversationSortOrder, +) +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( + StoredConversationMetadata, +) +from openhands.app_server.services.injector import InjectorState +from openhands.integrations.provider import ProviderType +from openhands.sdk.llm import MetricsSnapshot +from openhands.sdk.llm.utils.metrics import TokenUsage + +logger = logging.getLogger(__name__) + + +@dataclass +class SQLSharedConversationInfoService(SharedConversationInfoService): + """SQL implementation of SharedConversationInfoService for shared conversations only.""" + + db_session: AsyncSession + + async def search_shared_conversation_info( + self, + title__contains: str | None = None, + created_at__gte: datetime | None = None, + created_at__lt: datetime | None = None, + updated_at__gte: datetime | None = None, + updated_at__lt: datetime | None = None, + sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC, + page_id: str | None = None, + limit: int = 100, + include_sub_conversations: bool = False, + ) -> SharedConversationPage: + """Search for shared conversations.""" + query = self._public_select() + + # Conditionally exclude sub-conversations based on the parameter + if not include_sub_conversations: + # Exclude sub-conversations (only include top-level conversations) + query = query.where( + StoredConversationMetadata.parent_conversation_id.is_(None) + ) + + query = self._apply_filters( + query=query, + title__contains=title__contains, + created_at__gte=created_at__gte, + created_at__lt=created_at__lt, + updated_at__gte=updated_at__gte, + updated_at__lt=updated_at__lt, + ) + + # Add sort order + if sort_order == SharedConversationSortOrder.CREATED_AT: + query = query.order_by(StoredConversationMetadata.created_at) + elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC: + query = query.order_by(StoredConversationMetadata.created_at.desc()) + elif sort_order == SharedConversationSortOrder.UPDATED_AT: + query = query.order_by(StoredConversationMetadata.last_updated_at) + elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC: + query = query.order_by(StoredConversationMetadata.last_updated_at.desc()) + elif sort_order == SharedConversationSortOrder.TITLE: + query = query.order_by(StoredConversationMetadata.title) + elif sort_order == SharedConversationSortOrder.TITLE_DESC: + query = query.order_by(StoredConversationMetadata.title.desc()) + + # Apply pagination + if page_id is not None: + try: + offset = int(page_id) + query = query.offset(offset) + except ValueError: + # If page_id is not a valid integer, start from beginning + offset = 0 + else: + offset = 0 + + # Apply limit and get one extra to check if there are more results + query = query.limit(limit + 1) + + result = await self.db_session.execute(query) + rows = result.scalars().all() + + # Check if there are more results + has_more = len(rows) > limit + if has_more: + rows = rows[:limit] + + items = [self._to_shared_conversation(row) for row in rows] + + # Calculate next page ID + next_page_id = None + if has_more: + next_page_id = str(offset + limit) + + return SharedConversationPage(items=items, next_page_id=next_page_id) + + async def count_shared_conversation_info( + self, + title__contains: str | None = None, + created_at__gte: datetime | None = None, + created_at__lt: datetime | None = None, + updated_at__gte: datetime | None = None, + updated_at__lt: datetime | None = None, + ) -> int: + """Count shared conversations matching the given filters.""" + from sqlalchemy import func + + query = select(func.count(StoredConversationMetadata.conversation_id)) + # Only include shared conversations + query = query.where(StoredConversationMetadata.public == True) # noqa: E712 + query = query.where(StoredConversationMetadata.conversation_version == 'V1') + + query = self._apply_filters( + query=query, + title__contains=title__contains, + created_at__gte=created_at__gte, + created_at__lt=created_at__lt, + updated_at__gte=updated_at__gte, + updated_at__lt=updated_at__lt, + ) + + result = await self.db_session.execute(query) + return result.scalar() or 0 + + async def get_shared_conversation_info( + self, conversation_id: UUID + ) -> SharedConversation | None: + """Get a single public conversation info, returning None if missing or not shared.""" + query = self._public_select().where( + StoredConversationMetadata.conversation_id == str(conversation_id) + ) + + result = await self.db_session.execute(query) + stored = result.scalar_one_or_none() + + if stored is None: + return None + + return self._to_shared_conversation(stored) + + def _public_select(self): + """Create a select query that only returns public conversations.""" + query = select(StoredConversationMetadata).where( + StoredConversationMetadata.conversation_version == 'V1' + ) + # Only include conversations marked as public + query = query.where(StoredConversationMetadata.public == True) # noqa: E712 + return query + + def _apply_filters( + self, + query, + title__contains: str | None = None, + created_at__gte: datetime | None = None, + created_at__lt: datetime | None = None, + updated_at__gte: datetime | None = None, + updated_at__lt: datetime | None = None, + ): + """Apply common filters to a query.""" + if title__contains is not None: + query = query.where( + StoredConversationMetadata.title.contains(title__contains) + ) + + if created_at__gte is not None: + query = query.where( + StoredConversationMetadata.created_at >= created_at__gte + ) + + if created_at__lt is not None: + query = query.where(StoredConversationMetadata.created_at < created_at__lt) + + if updated_at__gte is not None: + query = query.where( + StoredConversationMetadata.last_updated_at >= updated_at__gte + ) + + if updated_at__lt is not None: + query = query.where( + StoredConversationMetadata.last_updated_at < updated_at__lt + ) + + return query + + def _to_shared_conversation( + self, + stored: StoredConversationMetadata, + sub_conversation_ids: list[UUID] | None = None, + ) -> SharedConversation: + """Convert StoredConversationMetadata to SharedConversation.""" + # V1 conversations should always have a sandbox_id + sandbox_id = stored.sandbox_id + assert sandbox_id is not None + + # Rebuild token usage + token_usage = TokenUsage( + prompt_tokens=stored.prompt_tokens, + completion_tokens=stored.completion_tokens, + cache_read_tokens=stored.cache_read_tokens, + cache_write_tokens=stored.cache_write_tokens, + context_window=stored.context_window, + per_turn_token=stored.per_turn_token, + ) + + # Rebuild metrics object + metrics = MetricsSnapshot( + accumulated_cost=stored.accumulated_cost, + max_budget_per_task=stored.max_budget_per_task, + accumulated_token_usage=token_usage, + ) + + # Get timestamps + created_at = self._fix_timezone(stored.created_at) + updated_at = self._fix_timezone(stored.last_updated_at) + + return SharedConversation( + id=UUID(stored.conversation_id), + created_by_user_id=stored.user_id if stored.user_id else None, + sandbox_id=stored.sandbox_id, + selected_repository=stored.selected_repository, + selected_branch=stored.selected_branch, + git_provider=( + ProviderType(stored.git_provider) if stored.git_provider else None + ), + title=stored.title, + pr_number=stored.pr_number, + llm_model=stored.llm_model, + metrics=metrics, + parent_conversation_id=( + UUID(stored.parent_conversation_id) + if stored.parent_conversation_id + else None + ), + sub_conversation_ids=sub_conversation_ids or [], + created_at=created_at, + updated_at=updated_at, + ) + + def _fix_timezone(self, value: datetime) -> datetime: + """Sqlite does not store timezones - and since we can't update the existing models + we assume UTC if the timezone is missing.""" + if not value.tzinfo: + value = value.replace(tzinfo=UTC) + return value + + +class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector): + async def inject( + self, state: InjectorState, request: Request | None = None + ) -> AsyncGenerator[SharedConversationInfoService, None]: + # Define inline to prevent circular lookup + from openhands.app_server.config import get_db_session + + async with get_db_session(state, request) as db_session: + service = SQLSharedConversationInfoService(db_session=db_session) + yield service diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index 160c3a80a2..3a0756dd5a 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore): kwargs.pop('context_window', None) kwargs.pop('per_turn_token', None) kwargs.pop('parent_conversation_id', None) + kwargs.pop('public') return ConversationMetadata(**kwargs) diff --git a/enterprise/tests/unit/test_sharing/__init__.py b/enterprise/tests/unit/test_sharing/__init__.py new file mode 100644 index 0000000000..8958107d4e --- /dev/null +++ b/enterprise/tests/unit/test_sharing/__init__.py @@ -0,0 +1 @@ +"""Tests for sharing package.""" diff --git a/enterprise/tests/unit/test_sharing/test_shared_conversation_models.py b/enterprise/tests/unit/test_sharing/test_shared_conversation_models.py new file mode 100644 index 0000000000..ec8ae8ce8e --- /dev/null +++ b/enterprise/tests/unit/test_sharing/test_shared_conversation_models.py @@ -0,0 +1,91 @@ +"""Tests for public conversation models.""" + +from datetime import datetime +from uuid import uuid4 + +from server.sharing.shared_conversation_models import ( + SharedConversation, + SharedConversationPage, + SharedConversationSortOrder, +) + + +def test_public_conversation_creation(): + """Test that SharedConversation can be created with all required fields.""" + conversation_id = uuid4() + now = datetime.utcnow() + + conversation = SharedConversation( + id=conversation_id, + created_by_user_id='test_user', + sandbox_id='test_sandbox', + title='Test Conversation', + created_at=now, + updated_at=now, + selected_repository=None, + parent_conversation_id=None, + ) + + assert conversation.id == conversation_id + assert conversation.title == 'Test Conversation' + assert conversation.created_by_user_id == 'test_user' + assert conversation.sandbox_id == 'test_sandbox' + + +def test_public_conversation_page_creation(): + """Test that SharedConversationPage can be created.""" + conversation_id = uuid4() + now = datetime.utcnow() + + conversation = SharedConversation( + id=conversation_id, + created_by_user_id='test_user', + sandbox_id='test_sandbox', + title='Test Conversation', + created_at=now, + updated_at=now, + selected_repository=None, + parent_conversation_id=None, + ) + + page = SharedConversationPage( + items=[conversation], + next_page_id='next_page', + ) + + assert len(page.items) == 1 + assert page.items[0].id == conversation_id + assert page.next_page_id == 'next_page' + + +def test_public_conversation_sort_order_enum(): + """Test that SharedConversationSortOrder enum has expected values.""" + assert hasattr(SharedConversationSortOrder, 'CREATED_AT') + assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC') + assert hasattr(SharedConversationSortOrder, 'UPDATED_AT') + assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC') + assert hasattr(SharedConversationSortOrder, 'TITLE') + assert hasattr(SharedConversationSortOrder, 'TITLE_DESC') + + +def test_public_conversation_optional_fields(): + """Test that SharedConversation works with optional fields.""" + conversation_id = uuid4() + parent_id = uuid4() + now = datetime.utcnow() + + conversation = SharedConversation( + id=conversation_id, + created_by_user_id='test_user', + sandbox_id='test_sandbox', + title='Test Conversation', + created_at=now, + updated_at=now, + selected_repository='owner/repo', + parent_conversation_id=parent_id, + llm_model='gpt-4', + ) + + assert conversation.selected_repository == 'owner/repo' + assert conversation.parent_conversation_id == parent_id + assert conversation.llm_model == 'gpt-4' diff --git a/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py b/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py new file mode 100644 index 0000000000..bacb9edb58 --- /dev/null +++ b/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py @@ -0,0 +1,430 @@ +"""Tests for SharedConversationInfoService.""" + +from datetime import UTC, datetime +from typing import AsyncGenerator +from uuid import uuid4 + +import pytest +from server.sharing.shared_conversation_models import ( + SharedConversationSortOrder, +) +from server.sharing.sql_shared_conversation_info_service import ( + SQLSharedConversationInfoService, +) +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from openhands.app_server.app_conversation.app_conversation_models import ( + AppConversationInfo, +) +from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( + SQLAppConversationInfoService, +) +from openhands.app_server.user.specifiy_user_context import SpecifyUserContext +from openhands.app_server.utils.sql_utils import Base +from openhands.integrations.provider import ProviderType +from openhands.sdk.llm import MetricsSnapshot +from openhands.sdk.llm.utils.metrics import TokenUsage +from openhands.storage.data_models.conversation_metadata import ConversationTrigger + + +@pytest.fixture +async def async_engine(): + """Create an async SQLite engine for testing.""" + engine = create_async_engine( + 'sqlite+aiosqlite:///:memory:', + poolclass=StaticPool, + connect_args={'check_same_thread': False}, + echo=False, + ) + + # Create all tables + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + +@pytest.fixture +async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]: + """Create an async session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session_maker() as db_session: + yield db_session + + +@pytest.fixture +async def shared_conversation_info_service(async_session): + """Create a SharedConversationInfoService for testing.""" + return SQLSharedConversationInfoService(db_session=async_session) + + +@pytest.fixture +async def app_conversation_service(async_session): + """Create an AppConversationInfoService for creating test data.""" + return SQLAppConversationInfoService( + db_session=async_session, user_context=SpecifyUserContext(user_id=None) + ) + + +@pytest.fixture +def sample_conversation_info(): + """Create a sample conversation info for testing.""" + return AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id='test_sandbox', + selected_repository='test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title='Test Conversation', + trigger=ConversationTrigger.GUI, + pr_number=[123], + llm_model='gpt-4', + metrics=MetricsSnapshot( + accumulated_cost=1.5, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=4096, + per_turn_token=150, + ), + ), + parent_conversation_id=None, + sub_conversation_ids=[], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + public=True, # Make it public for testing + ) + + +@pytest.fixture +def sample_private_conversation_info(): + """Create a sample private conversation info for testing.""" + return AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id='test_sandbox_private', + selected_repository='test/private_repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + title='Private Conversation', + trigger=ConversationTrigger.GUI, + pr_number=[124], + llm_model='gpt-4', + metrics=MetricsSnapshot( + accumulated_cost=2.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage( + prompt_tokens=200, + completion_tokens=100, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=4096, + per_turn_token=300, + ), + ), + parent_conversation_id=None, + sub_conversation_ids=[], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + public=False, # Make it private + ) + + +class TestSharedConversationInfoService: + """Test cases for SharedConversationInfoService.""" + + @pytest.mark.asyncio + @pytest.mark.asyncio + async def test_get_shared_conversation_info_returns_public_conversation( + self, + shared_conversation_info_service, + app_conversation_service, + sample_conversation_info, + ): + """Test that get_shared_conversation_info returns a public conversation.""" + # Create a public conversation + await app_conversation_service.save_app_conversation_info( + sample_conversation_info + ) + + # Retrieve it via public service + result = await shared_conversation_info_service.get_shared_conversation_info( + sample_conversation_info.id + ) + + assert result is not None + assert result.id == sample_conversation_info.id + assert result.title == sample_conversation_info.title + assert result.created_by_user_id == sample_conversation_info.created_by_user_id + + @pytest.mark.asyncio + async def test_get_shared_conversation_info_returns_none_for_private_conversation( + self, + shared_conversation_info_service, + app_conversation_service, + sample_private_conversation_info, + ): + """Test that get_shared_conversation_info returns None for private conversations.""" + # Create a private conversation + await app_conversation_service.save_app_conversation_info( + sample_private_conversation_info + ) + + # Try to retrieve it via public service + result = await shared_conversation_info_service.get_shared_conversation_info( + sample_private_conversation_info.id + ) + + assert result is None + + @pytest.mark.asyncio + async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation( + self, shared_conversation_info_service + ): + """Test that get_shared_conversation_info returns None for nonexistent conversations.""" + nonexistent_id = uuid4() + result = await shared_conversation_info_service.get_shared_conversation_info( + nonexistent_id + ) + assert result is None + + @pytest.mark.asyncio + async def test_search_shared_conversation_info_returns_only_public_conversations( + self, + shared_conversation_info_service, + app_conversation_service, + sample_conversation_info, + sample_private_conversation_info, + ): + """Test that search only returns public conversations.""" + # Create both public and private conversations + await app_conversation_service.save_app_conversation_info( + sample_conversation_info + ) + await app_conversation_service.save_app_conversation_info( + sample_private_conversation_info + ) + + # Search for all conversations + result = ( + await shared_conversation_info_service.search_shared_conversation_info() + ) + + # Should only return the public conversation + assert len(result.items) == 1 + assert result.items[0].id == sample_conversation_info.id + assert result.items[0].title == sample_conversation_info.title + + @pytest.mark.asyncio + async def test_search_shared_conversation_info_with_title_filter( + self, + shared_conversation_info_service, + app_conversation_service, + sample_conversation_info, + ): + """Test searching with title filter.""" + # Create a public conversation + await app_conversation_service.save_app_conversation_info( + sample_conversation_info + ) + + # Search with matching title + result = await shared_conversation_info_service.search_shared_conversation_info( + title__contains='Test' + ) + assert len(result.items) == 1 + + # Search with non-matching title + result = await shared_conversation_info_service.search_shared_conversation_info( + title__contains='NonExistent' + ) + assert len(result.items) == 0 + + @pytest.mark.asyncio + async def test_search_shared_conversation_info_with_sort_order( + self, + shared_conversation_info_service, + app_conversation_service, + ): + """Test searching with different sort orders.""" + # Create multiple public conversations with different titles and timestamps + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id='test_sandbox_1', + title='A First Conversation', + created_at=datetime(2023, 1, 1, tzinfo=UTC), + updated_at=datetime(2023, 1, 1, tzinfo=UTC), + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id='test_sandbox_2', + title='B Second Conversation', + created_at=datetime(2023, 1, 2, tzinfo=UTC), + updated_at=datetime(2023, 1, 2, tzinfo=UTC), + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + + await app_conversation_service.save_app_conversation_info(conv1) + await app_conversation_service.save_app_conversation_info(conv2) + + # Test sort by title ascending + result = await shared_conversation_info_service.search_shared_conversation_info( + sort_order=SharedConversationSortOrder.TITLE + ) + assert len(result.items) == 2 + assert result.items[0].title == 'A First Conversation' + assert result.items[1].title == 'B Second Conversation' + + # Test sort by title descending + result = await shared_conversation_info_service.search_shared_conversation_info( + sort_order=SharedConversationSortOrder.TITLE_DESC + ) + assert len(result.items) == 2 + assert result.items[0].title == 'B Second Conversation' + assert result.items[1].title == 'A First Conversation' + + # Test sort by created_at ascending + result = await shared_conversation_info_service.search_shared_conversation_info( + sort_order=SharedConversationSortOrder.CREATED_AT + ) + assert len(result.items) == 2 + assert result.items[0].id == conv1.id + assert result.items[1].id == conv2.id + + # Test sort by created_at descending (default) + result = await shared_conversation_info_service.search_shared_conversation_info( + sort_order=SharedConversationSortOrder.CREATED_AT_DESC + ) + assert len(result.items) == 2 + assert result.items[0].id == conv2.id + assert result.items[1].id == conv1.id + + @pytest.mark.asyncio + async def test_count_shared_conversation_info( + self, + shared_conversation_info_service, + app_conversation_service, + sample_conversation_info, + sample_private_conversation_info, + ): + """Test counting public conversations.""" + # Initially should be 0 + count = await shared_conversation_info_service.count_shared_conversation_info() + assert count == 0 + + # Create a public conversation + await app_conversation_service.save_app_conversation_info( + sample_conversation_info + ) + count = await shared_conversation_info_service.count_shared_conversation_info() + assert count == 1 + + # Create a private conversation - count should remain 1 + await app_conversation_service.save_app_conversation_info( + sample_private_conversation_info + ) + count = await shared_conversation_info_service.count_shared_conversation_info() + assert count == 1 + + @pytest.mark.asyncio + async def test_batch_get_shared_conversation_info( + self, + shared_conversation_info_service, + app_conversation_service, + sample_conversation_info, + sample_private_conversation_info, + ): + """Test batch getting public conversations.""" + # Create both public and private conversations + await app_conversation_service.save_app_conversation_info( + sample_conversation_info + ) + await app_conversation_service.save_app_conversation_info( + sample_private_conversation_info + ) + + # Batch get both conversations + result = ( + await shared_conversation_info_service.batch_get_shared_conversation_info( + [sample_conversation_info.id, sample_private_conversation_info.id] + ) + ) + + # Should return the public one and None for the private one + assert len(result) == 2 + assert result[0] is not None + assert result[0].id == sample_conversation_info.id + assert result[1] is None + + @pytest.mark.asyncio + async def test_search_with_pagination( + self, + shared_conversation_info_service, + app_conversation_service, + ): + """Test search with pagination.""" + # Create multiple public conversations + conversations = [] + for i in range(5): + conv = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id=f'test_sandbox_{i}', + title=f'Conversation {i}', + created_at=datetime(2023, 1, i + 1, tzinfo=UTC), + updated_at=datetime(2023, 1, i + 1, tzinfo=UTC), + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + conversations.append(conv) + await app_conversation_service.save_app_conversation_info(conv) + + # Get first page with limit 2 + result = await shared_conversation_info_service.search_shared_conversation_info( + limit=2, sort_order=SharedConversationSortOrder.CREATED_AT + ) + assert len(result.items) == 2 + assert result.next_page_id is not None + + # Get next page + result2 = ( + await shared_conversation_info_service.search_shared_conversation_info( + limit=2, + page_id=result.next_page_id, + sort_order=SharedConversationSortOrder.CREATED_AT, + ) + ) + assert len(result2.items) == 2 + assert result2.next_page_id is not None + + # Verify no overlap between pages + page1_ids = {item.id for item in result.items} + page2_ids = {item.id for item in result2.items} + assert page1_ids.isdisjoint(page2_ids) diff --git a/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py b/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py new file mode 100644 index 0000000000..e12e8f0fad --- /dev/null +++ b/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py @@ -0,0 +1,365 @@ +"""Tests for SharedEventService.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest +from server.sharing.filesystem_shared_event_service import ( + SharedEventServiceImpl, +) +from server.sharing.shared_conversation_info_service import ( + SharedConversationInfoService, +) +from server.sharing.shared_conversation_models import SharedConversation + +from openhands.agent_server.models import EventPage, EventSortOrder +from openhands.app_server.event.event_service import EventService +from openhands.sdk.llm import MetricsSnapshot +from openhands.sdk.llm.utils.metrics import TokenUsage + + +@pytest.fixture +def mock_shared_conversation_info_service(): + """Create a mock SharedConversationInfoService.""" + return AsyncMock(spec=SharedConversationInfoService) + + +@pytest.fixture +def mock_event_service(): + """Create a mock EventService.""" + return AsyncMock(spec=EventService) + + +@pytest.fixture +def shared_event_service(mock_shared_conversation_info_service, mock_event_service): + """Create a SharedEventService for testing.""" + return SharedEventServiceImpl( + shared_conversation_info_service=mock_shared_conversation_info_service, + event_service=mock_event_service, + ) + + +@pytest.fixture +def sample_public_conversation(): + """Create a sample public conversation.""" + return SharedConversation( + id=uuid4(), + created_by_user_id='test_user', + sandbox_id='test_sandbox', + title='Test Public Conversation', + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + + +@pytest.fixture +def sample_event(): + """Create a sample event.""" + # For testing purposes, we'll just use a mock that the EventPage can accept + # The actual event creation is complex and not the focus of these tests + return None + + +class TestSharedEventService: + """Test cases for SharedEventService.""" + + async def test_get_shared_event_returns_event_for_public_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + sample_public_conversation, + sample_event, + ): + """Test that get_shared_event returns an event for a public conversation.""" + conversation_id = sample_public_conversation.id + event_id = 'test_event_id' + + # Mock the public conversation service to return a public conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Mock the event service to return an event + mock_event_service.get_event.return_value = sample_event + + # Call the method + result = await shared_event_service.get_shared_event(conversation_id, event_id) + + # Verify the result + assert result == sample_event + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + mock_event_service.get_event.assert_called_once_with(event_id) + + async def test_get_shared_event_returns_none_for_private_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + ): + """Test that get_shared_event returns None for a private conversation.""" + conversation_id = uuid4() + event_id = 'test_event_id' + + # Mock the public conversation service to return None (private conversation) + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None + + # Call the method + result = await shared_event_service.get_shared_event(conversation_id, event_id) + + # Verify the result + assert result is None + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + # Event service should not be called + mock_event_service.get_event.assert_not_called() + + async def test_search_shared_events_returns_events_for_public_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + sample_public_conversation, + sample_event, + ): + """Test that search_shared_events returns events for a public conversation.""" + conversation_id = sample_public_conversation.id + + # Mock the public conversation service to return a public conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Mock the event service to return events + mock_event_page = EventPage(items=[], next_page_id=None) + mock_event_service.search_events.return_value = mock_event_page + + # Call the method + result = await shared_event_service.search_shared_events( + conversation_id=conversation_id, + kind__eq='ActionEvent', + limit=10, + ) + + # Verify the result + assert result == mock_event_page + assert len(result.items) == 0 # Empty list as we mocked + + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + mock_event_service.search_events.assert_called_once_with( + conversation_id__eq=conversation_id, + kind__eq='ActionEvent', + timestamp__gte=None, + timestamp__lt=None, + sort_order=EventSortOrder.TIMESTAMP, + page_id=None, + limit=10, + ) + + async def test_search_shared_events_returns_empty_for_private_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + ): + """Test that search_shared_events returns empty page for a private conversation.""" + conversation_id = uuid4() + + # Mock the public conversation service to return None (private conversation) + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None + + # Call the method + result = await shared_event_service.search_shared_events( + conversation_id=conversation_id, + limit=10, + ) + + # Verify the result + assert isinstance(result, EventPage) + assert len(result.items) == 0 + assert result.next_page_id is None + + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + # Event service should not be called + mock_event_service.search_events.assert_not_called() + + async def test_count_shared_events_returns_count_for_public_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + sample_public_conversation, + ): + """Test that count_shared_events returns count for a public conversation.""" + conversation_id = sample_public_conversation.id + + # Mock the public conversation service to return a public conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Mock the event service to return a count + mock_event_service.count_events.return_value = 5 + + # Call the method + result = await shared_event_service.count_shared_events( + conversation_id=conversation_id, + kind__eq='ActionEvent', + ) + + # Verify the result + assert result == 5 + + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + mock_event_service.count_events.assert_called_once_with( + conversation_id__eq=conversation_id, + kind__eq='ActionEvent', + timestamp__gte=None, + timestamp__lt=None, + sort_order=EventSortOrder.TIMESTAMP, + ) + + async def test_count_shared_events_returns_zero_for_private_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + ): + """Test that count_shared_events returns 0 for a private conversation.""" + conversation_id = uuid4() + + # Mock the public conversation service to return None (private conversation) + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None + + # Call the method + result = await shared_event_service.count_shared_events( + conversation_id=conversation_id, + ) + + # Verify the result + assert result == 0 + + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + # Event service should not be called + mock_event_service.count_events.assert_not_called() + + async def test_batch_get_shared_events_returns_events_for_public_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + sample_public_conversation, + sample_event, + ): + """Test that batch_get_shared_events returns events for a public conversation.""" + conversation_id = sample_public_conversation.id + event_ids = ['event1', 'event2'] + + # Mock the public conversation service to return a public conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Mock the event service to return events + mock_event_service.get_event.side_effect = [sample_event, None] + + # Call the method + result = await shared_event_service.batch_get_shared_events( + conversation_id, event_ids + ) + + # Verify the result + assert len(result) == 2 + assert result[0] == sample_event + assert result[1] is None + + # Verify that get_shared_conversation_info was called for each event + assert ( + mock_shared_conversation_info_service.get_shared_conversation_info.call_count + == 2 + ) + # Verify that get_event was called for each event + assert mock_event_service.get_event.call_count == 2 + + async def test_batch_get_shared_events_returns_none_for_private_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + ): + """Test that batch_get_shared_events returns None for a private conversation.""" + conversation_id = uuid4() + event_ids = ['event1', 'event2'] + + # Mock the public conversation service to return None (private conversation) + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None + + # Call the method + result = await shared_event_service.batch_get_shared_events( + conversation_id, event_ids + ) + + # Verify the result + assert len(result) == 2 + assert result[0] is None + assert result[1] is None + + # Verify that get_shared_conversation_info was called for each event + assert ( + mock_shared_conversation_info_service.get_shared_conversation_info.call_count + == 2 + ) + # Event service should not be called + mock_event_service.get_event.assert_not_called() + + async def test_search_shared_events_with_all_parameters( + self, + shared_event_service, + mock_shared_conversation_info_service, + mock_event_service, + sample_public_conversation, + ): + """Test search_shared_events with all parameters.""" + conversation_id = sample_public_conversation.id + timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC) + timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC) + + # Mock the public conversation service to return a public conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Mock the event service to return events + mock_event_page = EventPage(items=[], next_page_id='next_page') + mock_event_service.search_events.return_value = mock_event_page + + # Call the method with all parameters + result = await shared_event_service.search_shared_events( + conversation_id=conversation_id, + kind__eq='ObservationEvent', + timestamp__gte=timestamp_gte, + timestamp__lt=timestamp_lt, + sort_order=EventSortOrder.TIMESTAMP_DESC, + page_id='current_page', + limit=50, + ) + + # Verify the result + assert result == mock_event_page + + mock_event_service.search_events.assert_called_once_with( + conversation_id__eq=conversation_id, + kind__eq='ObservationEvent', + timestamp__gte=timestamp_gte, + timestamp__lt=timestamp_lt, + sort_order=EventSortOrder.TIMESTAMP_DESC, + page_id='current_page', + limit=50, + ) diff --git a/openhands/app_server/app_conversation/app_conversation_models.py b/openhands/app_server/app_conversation/app_conversation_models.py index 1c0ba914cb..58a63a95d6 100644 --- a/openhands/app_server/app_conversation/app_conversation_models.py +++ b/openhands/app_server/app_conversation/app_conversation_models.py @@ -45,6 +45,8 @@ class AppConversationInfo(BaseModel): parent_conversation_id: OpenHandsUUID | None = None sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list) + public: bool | None = None + created_at: datetime = Field(default_factory=utc_now) updated_at: datetime = Field(default_factory=utc_now) @@ -114,6 +116,12 @@ class AppConversationStartRequest(BaseModel): parent_conversation_id: OpenHandsUUID | None = None agent_type: AgentType = Field(default=AgentType.DEFAULT) + public: bool | None = None + + +class AppConversationUpdateRequest(BaseModel): + public: bool + class AppConversationStartTaskStatus(Enum): WORKING = 'WORKING' diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index f68b80ba4e..29ae3f69d7 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -40,6 +40,7 @@ from openhands.app_server.app_conversation.app_conversation_models import ( AppConversationStartTask, AppConversationStartTaskPage, AppConversationStartTaskSortOrder, + AppConversationUpdateRequest, SkillResponse, ) from openhands.app_server.app_conversation.app_conversation_service import ( @@ -222,6 +223,22 @@ async def start_app_conversation( raise +@router.patch('/{conversation_id}') +async def update_app_conversation( + conversation_id: str, + update_request: AppConversationUpdateRequest, + app_conversation_service: AppConversationService = ( + app_conversation_service_dependency + ), +) -> AppConversation: + info = await app_conversation_service.update_app_conversation( + UUID(conversation_id), update_request + ) + if info is None: + raise HTTPException(404, 'unknown_app_conversation') + return info + + @router.post('/stream-start') async def stream_app_conversation_start( request: AppConversationStartRequest, diff --git a/openhands/app_server/app_conversation/app_conversation_service.py b/openhands/app_server/app_conversation/app_conversation_service.py index b1b10c39ba..dd98dd44c9 100644 --- a/openhands/app_server/app_conversation/app_conversation_service.py +++ b/openhands/app_server/app_conversation/app_conversation_service.py @@ -10,6 +10,7 @@ from openhands.app_server.app_conversation.app_conversation_models import ( AppConversationSortOrder, AppConversationStartRequest, AppConversationStartTask, + AppConversationUpdateRequest, ) from openhands.app_server.sandbox.sandbox_models import SandboxInfo from openhands.app_server.services.injector import Injector @@ -98,6 +99,13 @@ class AppConversationService(ABC): """Run the setup scripts for the project and yield status updates""" yield task + @abstractmethod + async def update_app_conversation( + self, conversation_id: UUID, request: AppConversationUpdateRequest + ) -> AppConversation | None: + """Update an app conversation and return it. Return None if the conversation + did not exist.""" + @abstractmethod async def delete_app_conversation(self, conversation_id: UUID) -> bool: """Delete a V1 conversation and all its associated data. diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 84f20de07a..11d9e4fef8 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import ( AppConversationStartRequest, AppConversationStartTask, AppConversationStartTaskStatus, + AppConversationUpdateRequest, ) from openhands.app_server.app_conversation.app_conversation_service import ( AppConversationService, @@ -1049,6 +1050,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase): f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"' ) + async def update_app_conversation( + self, conversation_id: UUID, request: AppConversationUpdateRequest + ) -> AppConversation | None: + """Update an app conversation and return it. Return None if the conversation + did not exist.""" + info = await self.app_conversation_info_service.get_app_conversation_info( + conversation_id + ) + if info is None: + return None + for field_name in request.model_fields: + value = getattr(request, field_name) + setattr(info, field_name, value) + info = await self.app_conversation_info_service.save_app_conversation_info(info) + conversations = await self._build_app_conversations([info]) + return conversations[0] + async def delete_app_conversation(self, conversation_id: UUID) -> bool: """Delete a V1 conversation and all its associated data. diff --git a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py index 83e2d1915b..7764b99e77 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py @@ -25,7 +25,17 @@ from typing import AsyncGenerator from uuid import UUID from fastapi import Request -from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Float, + Integer, + Select, + String, + func, + select, +) from sqlalchemy.ext.asyncio import AsyncSession from openhands.agent_server.utils import utc_now @@ -91,6 +101,7 @@ class StoredConversationMetadata(Base): # type: ignore conversation_version = Column(String, nullable=False, default='V0', index=True) sandbox_id = Column(String, nullable=True, index=True) parent_conversation_id = Column(String, nullable=True, index=True) + public = Column(Boolean, nullable=True, index=True) @dataclass @@ -350,6 +361,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): if info.parent_conversation_id else None ), + public=info.public, ) await self.db_session.merge(stored) @@ -541,6 +553,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): else None ), sub_conversation_ids=sub_conversation_ids or [], + public=stored.public, created_at=created_at, updated_at=updated_at, ) diff --git a/openhands/app_server/app_lifespan/alembic/versions/004.py b/openhands/app_server/app_lifespan/alembic/versions/004.py new file mode 100644 index 0000000000..2d5ef07f41 --- /dev/null +++ b/openhands/app_server/app_lifespan/alembic/versions/004.py @@ -0,0 +1,41 @@ +"""add public column to conversation_metadata + +Revision ID: 004 +Revises: 003 +Create Date: 2025-01-27 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '004' +down_revision: Union[str, None] = '003' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'conversation_metadata', + sa.Column('public', sa.Boolean(), nullable=True), + ) + op.create_index( + op.f('ix_conversation_metadata_public'), + 'conversation_metadata', + ['public'], + unique=False, + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_index( + op.f('ix_conversation_metadata_public'), + table_name='conversation_metadata', + ) + op.drop_column('conversation_metadata', 'public') diff --git a/openhands/app_server/event/event_router.py b/openhands/app_server/event/event_router.py index 3476c155e9..3431bf2815 100644 --- a/openhands/app_server/event/event_router.py +++ b/openhands/app_server/event/event_router.py @@ -22,7 +22,7 @@ event_service_dependency = depends_event_service() @router.get('/search') async def search_events( conversation_id__eq: Annotated[ - UUID | None, + str | None, Query(title='Optional filter by conversation ID'), ] = None, kind__eq: Annotated[ @@ -55,7 +55,7 @@ async def search_events( assert limit > 0 assert limit <= 100 return await event_service.search_events( - conversation_id__eq=conversation_id__eq, + conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None, kind__eq=kind__eq, timestamp__gte=timestamp__gte, timestamp__lt=timestamp__lt, @@ -68,7 +68,7 @@ async def search_events( @router.get('/count') async def count_events( conversation_id__eq: Annotated[ - UUID | None, + str | None, Query(title='Optional filter by conversation ID'), ] = None, kind__eq: Annotated[ @@ -91,7 +91,7 @@ async def count_events( ) -> int: """Count events matching the given filters.""" return await event_service.count_events( - conversation_id__eq=conversation_id__eq, + conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None, kind__eq=kind__eq, timestamp__gte=timestamp__gte, timestamp__lt=timestamp__lt, diff --git a/openhands/app_server/event/filesystem_event_service.py b/openhands/app_server/event/filesystem_event_service.py index 05e2ed9350..1f98fcec05 100644 --- a/openhands/app_server/event/filesystem_event_service.py +++ b/openhands/app_server/event/filesystem_event_service.py @@ -1,32 +1,27 @@ """Filesystem-based EventService implementation.""" -import asyncio -import glob import json -import logging from dataclasses import dataclass -from datetime import datetime from pathlib import Path from typing import AsyncGenerator from uuid import UUID from fastapi import Request -from openhands.agent_server.models import EventPage, EventSortOrder from openhands.app_server.app_conversation.app_conversation_info_service import ( AppConversationInfoService, ) from openhands.app_server.errors import OpenHandsError from openhands.app_server.event.event_service import EventService, EventServiceInjector -from openhands.app_server.event_callback.event_callback_models import EventKind +from openhands.app_server.event.filesystem_event_service_base import ( + FilesystemEventServiceBase, +) from openhands.app_server.services.injector import InjectorState from openhands.sdk import Event -_logger = logging.getLogger(__name__) - @dataclass -class FilesystemEventService(EventService): +class FilesystemEventService(FilesystemEventServiceBase, EventService): """Filesystem-based implementation of EventService. Events are stored in files with the naming format: @@ -47,25 +42,6 @@ class FilesystemEventService(EventService): events_path.mkdir(parents=True, exist_ok=True) return events_path - def _timestamp_to_str(self, timestamp: datetime | str) -> str: - """Convert timestamp to YYYYMMDDHHMMSS format.""" - if isinstance(timestamp, str): - # Parse ISO format timestamp string - dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) - return dt.strftime('%Y%m%d%H%M%S') - return timestamp.strftime('%Y%m%d%H%M%S') - - def _get_event_filename(self, conversation_id: UUID, event: Event) -> str: - """Generate filename using YYYYMMDDHHMMSS_kind_id.hex format.""" - timestamp_str = self._timestamp_to_str(event.timestamp) - kind = event.__class__.__name__ - # Handle both UUID objects and string UUIDs - if isinstance(event.id, str): - id_hex = event.id.replace('-', '') - else: - id_hex = event.id.hex - return f'{timestamp_str}_{kind}_{id_hex}' - def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None: """Save an event to a file.""" events_path = self._ensure_events_dir(conversation_id) @@ -77,60 +53,17 @@ class FilesystemEventService(EventService): data = event.model_dump(mode='json') f.write(json.dumps(data, indent=2)) - def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]: - events = [] - for file_path in file_paths: - event = self._load_event_from_file(file_path) - if event is not None: - events.append(event) - return events - - def _load_event_from_file(self, filepath: Path) -> Event | None: - """Load an event from a file.""" - try: - json_data = filepath.read_text() - return Event.model_validate_json(json_data) - except Exception: - return None - - def _get_event_files_by_pattern( - self, pattern: str, conversation_id: UUID | None = None - ) -> list[Path]: - """Get event files matching a glob pattern, sorted by timestamp.""" - if conversation_id: - search_path = self.events_dir / str(conversation_id) / pattern - else: - search_path = self.events_dir / '*' / pattern - - files = glob.glob(str(search_path)) - return sorted([Path(f) for f in files]) - - def _parse_filename(self, filename: str) -> dict[str, str] | None: - """Parse filename to extract timestamp, kind, and event_id.""" - try: - parts = filename.split('_') - if len(parts) >= 3: - timestamp_str = parts[0] - kind = '_'.join(parts[1:-1]) # Handle kinds with underscores - event_id = parts[-1] - return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id} - except Exception: - pass - return None - - def _get_conversation_id(self, file: Path) -> UUID | None: - try: - return UUID(file.parent.name) - except Exception: - return None - - def _get_conversation_ids(self, files: list[Path]) -> set[UUID]: - result = set() - for file in files: - conversation_id = self._get_conversation_id(file) - if conversation_id: - result.add(conversation_id) - return result + async def save_event(self, conversation_id: UUID, event: Event): + """Save an event. Internal method intended not be part of the REST api.""" + conversation = ( + await self.app_conversation_info_service.get_app_conversation_info( + conversation_id + ) + ) + if not conversation: + # This is either an illegal state or somebody is trying to hack + raise OpenHandsError('No such conversation: {conversaiont_id}') + self._save_event_to_file(conversation_id, event) async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]: conversation_ids = list(self._get_conversation_ids(files)) @@ -150,161 +83,6 @@ class FilesystemEventService(EventService): ] return result - def _filter_files_by_criteria( - self, - files: list[Path], - conversation_id__eq: UUID | None = None, - kind__eq: EventKind | None = None, - timestamp__gte: datetime | None = None, - timestamp__lt: datetime | None = None, - ) -> list[Path]: - """Filter files based on search criteria.""" - filtered_files = [] - - for file_path in files: - # Check conversation_id filter - if conversation_id__eq: - if str(conversation_id__eq) not in str(file_path): - continue - - # Parse filename for additional filtering - filename_info = self._parse_filename(file_path.name) - if not filename_info: - continue - - # Check kind filter - if kind__eq and filename_info['kind'] != kind__eq: - continue - - # Check timestamp filters - if timestamp__gte or timestamp__lt: - try: - file_timestamp = datetime.strptime( - filename_info['timestamp'], '%Y%m%d%H%M%S' - ) - if timestamp__gte and file_timestamp < timestamp__gte: - continue - if timestamp__lt and file_timestamp >= timestamp__lt: - continue - except ValueError: - continue - - filtered_files.append(file_path) - - return filtered_files - - async def get_event(self, event_id: str) -> Event | None: - """Get the event with the given id, or None if not found.""" - # Convert event_id to hex format (remove dashes) for filename matching - if isinstance(event_id, str) and '-' in event_id: - id_hex = event_id.replace('-', '') - else: - id_hex = event_id - - # Use glob pattern to find files ending with the event_id - pattern = f'*_{id_hex}' - files = self._get_event_files_by_pattern(pattern) - - if not files: - return None - - # If there is no access to the conversation do not return the event - file = files[0] - conversation_id = self._get_conversation_id(file) - if not conversation_id: - return None - conversation = ( - await self.app_conversation_info_service.get_app_conversation_info( - conversation_id - ) - ) - if not conversation: - return None - - # Load and return the first matching event - return self._load_event_from_file(file) - - async def search_events( - self, - conversation_id__eq: UUID | None = None, - kind__eq: EventKind | None = None, - timestamp__gte: datetime | None = None, - timestamp__lt: datetime | None = None, - sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, - page_id: str | None = None, - limit: int = 100, - ) -> EventPage: - """Search for events matching the given filters.""" - # Build the search pattern - pattern = '*' - files = self._get_event_files_by_pattern(pattern, conversation_id__eq) - - files = await self._filter_files_by_conversation(files) - - files = self._filter_files_by_criteria( - files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt - ) - - files.sort( - key=lambda f: f.name, - reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC), - ) - - # Handle pagination - start_index = 0 - if page_id: - for i, file_path in enumerate(files): - if file_path.name == page_id: - start_index = i + 1 - break - - # Collect items for this page - page_files = files[start_index : start_index + limit] - next_page_id = None - if start_index + limit < len(files): - next_page_id = files[start_index + limit].name - - # Load all events from files in a background thread. - loop = asyncio.get_running_loop() - page_events = await loop.run_in_executor( - None, self._load_events_from_files, page_files - ) - - return EventPage(items=page_events, next_page_id=next_page_id) - - async def count_events( - self, - conversation_id__eq: UUID | None = None, - kind__eq: EventKind | None = None, - timestamp__gte: datetime | None = None, - timestamp__lt: datetime | None = None, - sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, - ) -> int: - """Count events matching the given filters.""" - # Build the search pattern - pattern = '*' - files = self._get_event_files_by_pattern(pattern, conversation_id__eq) - - files = await self._filter_files_by_conversation(files) - - files = self._filter_files_by_criteria( - files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt - ) - - return len(files) - - async def save_event(self, conversation_id: UUID, event: Event): - """Save an event. Internal method intended not be part of the REST api.""" - conversation = ( - await self.app_conversation_info_service.get_app_conversation_info( - conversation_id - ) - ) - if not conversation: - # This is either an illegal state or somebody is trying to hack - raise OpenHandsError('No such conversation: {conversaiont_id}') - self._save_event_to_file(conversation_id, event) - class FilesystemEventServiceInjector(EventServiceInjector): async def inject( diff --git a/openhands/app_server/event/filesystem_event_service_base.py b/openhands/app_server/event/filesystem_event_service_base.py new file mode 100644 index 0000000000..b957f5f24a --- /dev/null +++ b/openhands/app_server/event/filesystem_event_service_base.py @@ -0,0 +1,224 @@ +import asyncio +import glob +from abc import abstractmethod +from datetime import datetime +from pathlib import Path +from uuid import UUID + +from openhands.agent_server.models import EventPage, EventSortOrder +from openhands.app_server.event_callback.event_callback_models import EventKind +from openhands.sdk import Event + + +class FilesystemEventServiceBase: + events_dir: Path + + async def get_event(self, event_id: str) -> Event | None: + """Get the event with the given id, or None if not found.""" + # Convert event_id to hex format (remove dashes) for filename matching + if isinstance(event_id, str) and '-' in event_id: + id_hex = event_id.replace('-', '') + else: + id_hex = event_id + + # Use glob pattern to find files ending with the event_id + pattern = f'*_{id_hex}' + files = self._get_event_files_by_pattern(pattern) + + files = await self._filter_files_by_conversation(files) + + if not files: + return None + + # Load and return the first matching event + return self._load_event_from_file(files[0]) + + async def search_events( + self, + conversation_id__eq: UUID | None = None, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + page_id: str | None = None, + limit: int = 100, + ) -> EventPage: + """Search for events matching the given filters.""" + # Build the search pattern + pattern = '*' + files = self._get_event_files_by_pattern(pattern, conversation_id__eq) + + files = await self._filter_files_by_conversation(files) + + files = self._filter_files_by_criteria( + files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt + ) + + files.sort( + key=lambda f: f.name, + reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC), + ) + + # Handle pagination + start_index = 0 + if page_id: + for i, file_path in enumerate(files): + if file_path.name == page_id: + start_index = i + 1 + break + + # Collect items for this page + page_files = files[start_index : start_index + limit] + next_page_id = None + if start_index + limit < len(files): + next_page_id = files[start_index + limit].name + + # Load all events from files in a background thread. + loop = asyncio.get_running_loop() + page_events = await loop.run_in_executor( + None, self._load_events_from_files, page_files + ) + + return EventPage(items=page_events, next_page_id=next_page_id) + + async def count_events( + self, + conversation_id__eq: UUID | None = None, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + sort_order: EventSortOrder = EventSortOrder.TIMESTAMP, + ) -> int: + """Count events matching the given filters.""" + # Build the search pattern + pattern = '*' + files = self._get_event_files_by_pattern(pattern, conversation_id__eq) + + files = await self._filter_files_by_conversation(files) + + files = self._filter_files_by_criteria( + files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt + ) + + return len(files) + + def _get_event_filename(self, conversation_id: UUID, event: Event) -> str: + """Generate filename using YYYYMMDDHHMMSS_kind_id.hex format.""" + timestamp_str = self._timestamp_to_str(event.timestamp) + kind = event.__class__.__name__ + # Handle both UUID objects and string UUIDs + if isinstance(event.id, str): + id_hex = event.id.replace('-', '') + else: + id_hex = event.id.hex + return f'{timestamp_str}_{kind}_{id_hex}' + + def _timestamp_to_str(self, timestamp: datetime | str) -> str: + """Convert timestamp to YYYYMMDDHHMMSS format.""" + if isinstance(timestamp, str): + # Parse ISO format timestamp string + dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) + return dt.strftime('%Y%m%d%H%M%S') + return timestamp.strftime('%Y%m%d%H%M%S') + + def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]: + events = [] + for file_path in file_paths: + event = self._load_event_from_file(file_path) + if event is not None: + events.append(event) + return events + + def _load_event_from_file(self, filepath: Path) -> Event | None: + """Load an event from a file.""" + try: + json_data = filepath.read_text() + return Event.model_validate_json(json_data) + except Exception: + return None + + def _get_event_files_by_pattern( + self, pattern: str, conversation_id: UUID | None = None + ) -> list[Path]: + """Get event files matching a glob pattern, sorted by timestamp.""" + if conversation_id: + search_path = self.events_dir / str(conversation_id) / pattern + else: + search_path = self.events_dir / '*' / pattern + + files = glob.glob(str(search_path)) + return sorted([Path(f) for f in files]) + + def _parse_filename(self, filename: str) -> dict[str, str] | None: + """Parse filename to extract timestamp, kind, and event_id.""" + try: + parts = filename.split('_') + if len(parts) >= 3: + timestamp_str = parts[0] + kind = '_'.join(parts[1:-1]) # Handle kinds with underscores + event_id = parts[-1] + return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id} + except Exception: + pass + return None + + def _get_conversation_id(self, file: Path) -> UUID | None: + try: + return UUID(file.parent.name) + except Exception: + return None + + def _get_conversation_ids(self, files: list[Path]) -> set[UUID]: + result = set() + for file in files: + conversation_id = self._get_conversation_id(file) + if conversation_id: + result.add(conversation_id) + return result + + @abstractmethod + async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]: + """Filter files by conversation.""" + + def _filter_files_by_criteria( + self, + files: list[Path], + conversation_id__eq: UUID | None = None, + kind__eq: EventKind | None = None, + timestamp__gte: datetime | None = None, + timestamp__lt: datetime | None = None, + ) -> list[Path]: + """Filter files based on search criteria.""" + filtered_files = [] + + for file_path in files: + # Check conversation_id filter + if conversation_id__eq: + if str(conversation_id__eq) not in str(file_path): + continue + + # Parse filename for additional filtering + filename_info = self._parse_filename(file_path.name) + if not filename_info: + continue + + # Check kind filter + if kind__eq and filename_info['kind'] != kind__eq: + continue + + # Check timestamp filters + if timestamp__gte or timestamp__lt: + try: + file_timestamp = datetime.strptime( + filename_info['timestamp'], '%Y%m%d%H%M%S' + ) + if timestamp__gte and file_timestamp < timestamp__gte: + continue + if timestamp__lt and file_timestamp >= timestamp__lt: + continue + except ValueError: + continue + + filtered_files.append(file_path) + + return filtered_files diff --git a/openhands/server/data_models/conversation_info.py b/openhands/server/data_models/conversation_info.py index 78af0e3dc1..5ca7b80b08 100644 --- a/openhands/server/data_models/conversation_info.py +++ b/openhands/server/data_models/conversation_info.py @@ -29,3 +29,4 @@ class ConversationInfo: pr_number: list[int] = field(default_factory=list) conversation_version: str = 'V0' sub_conversation_ids: list[str] = field(default_factory=list) + public: bool | None = None diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 1793b07e7d..b88c2851e2 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -1501,4 +1501,5 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo sub_conversation_ids=[ sub_id.hex for sub_id in app_conversation.sub_conversation_ids ], + public=app_conversation.public, ) diff --git a/openhands/storage/data_models/conversation_metadata.py b/openhands/storage/data_models/conversation_metadata.py index 8febc9afbd..5e08907303 100644 --- a/openhands/storage/data_models/conversation_metadata.py +++ b/openhands/storage/data_models/conversation_metadata.py @@ -39,3 +39,4 @@ class ConversationMetadata: # V1 compatibility sandbox_id: str | None = None conversation_version: str | None = None + public: bool | None = None diff --git a/poetry.lock b/poetry.lock index 23789d3285..aba2364232 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12707,18 +12707,19 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.1.0" +version = "1.3.0" description = "Pytest support for asyncio" optional = false -python-versions = ">=3.9" -groups = ["test"] +python-versions = ">=3.10" +groups = ["dev", "test"] files = [ - {file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"}, - {file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"}, + {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, + {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, ] [package.dependencies] -pytest = ">=8.2,<9" +pytest = ">=8.2,<10" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] @@ -16823,4 +16824,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api [metadata] lock-version = "2.1" python-versions = "^3.12,<3.14" -content-hash = "9764f3b69ec8ed35feebd78a826bbc6bfa4ac6d5b56bc999be8bc738b644e538" +content-hash = "e24ceb52bccd0c80f52c408215ccf007475eb69e10b895053ea49c7e3e4be3b8" diff --git a/pyproject.toml b/pyproject.toml index c70c110dcc..fda3cc9b96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ pre-commit = "4.2.0" build = "*" types-setuptools = "*" pytest = "^8.4.0" +pytest-asyncio = "^1.3.0" [tool.poetry.group.test] optional = true