diff --git a/enterprise/server/utils/saas_app_conversation_info_injector.py b/enterprise/server/utils/saas_app_conversation_info_injector.py index 16a7952a33..987f42ca10 100644 --- a/enterprise/server/utils/saas_app_conversation_info_injector.py +++ b/enterprise/server/utils/saas_app_conversation_info_injector.py @@ -119,6 +119,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, @@ -141,6 +142,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, + sandbox_id__eq=sandbox_id__eq, ) # Add sort order @@ -198,6 +200,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, ) -> int: """Count conversations matching the given filters with SAAS metadata.""" query = ( @@ -220,6 +223,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, + sandbox_id__eq=sandbox_id__eq, ) result = await self.db_session.execute(query) @@ -234,6 +238,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, ): """Apply filters to query that includes SAAS metadata.""" # Apply the same filters as the base class @@ -259,6 +264,9 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService): StoredConversationMetadata.last_updated_at < updated_at__lt ) + if sandbox_id__eq is not None: + conditions.append(StoredConversationMetadata.sandbox_id == sandbox_id__eq) + if conditions: query = query.where(*conditions) return query diff --git a/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py b/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py index 2ce5e3599c..0e0d1e9d35 100644 --- a/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py +++ b/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py @@ -791,3 +791,202 @@ class TestSaasSQLAppConversationInfoServiceWebhookFallback: assert len(user1_page.items) == 1 assert user1_page.items[0].id == conv_id assert user1_page.items[0].title == 'E2E Webhook Conversation' + + +class TestSandboxIdFilterSaas: + """Test suite for sandbox_id__eq filter parameter in SAAS service.""" + + @pytest.mark.asyncio + async def test_search_by_sandbox_id( + self, + async_session_with_users: AsyncSession, + ): + """Test searching conversations by exact sandbox_id match with SAAS user filtering.""" + # Create service for user1 + user1_service = SaasSQLAppConversationInfoService( + db_session=async_session_with_users, + user_context=SpecifyUserContext(user_id=str(USER1_ID)), + ) + + # Create conversations with different sandbox IDs for user1 + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_alpha', + title='Conversation Alpha', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_beta', + title='Conversation Beta', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_alpha', + title='Conversation Gamma', + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await user1_service.save_app_conversation_info(conv1) + await user1_service.save_app_conversation_info(conv2) + await user1_service.save_app_conversation_info(conv3) + + # Search for sandbox_alpha - should return 2 conversations + page = await user1_service.search_app_conversation_info( + sandbox_id__eq='sandbox_alpha' + ) + assert len(page.items) == 2 + sandbox_ids = {item.sandbox_id for item in page.items} + assert sandbox_ids == {'sandbox_alpha'} + conversation_ids = {item.id for item in page.items} + assert conv1.id in conversation_ids + assert conv3.id in conversation_ids + + # Search for sandbox_beta - should return 1 conversation + page = await user1_service.search_app_conversation_info( + sandbox_id__eq='sandbox_beta' + ) + assert len(page.items) == 1 + assert page.items[0].id == conv2.id + + # Search for non-existent sandbox - should return 0 conversations + page = await user1_service.search_app_conversation_info( + sandbox_id__eq='sandbox_nonexistent' + ) + assert len(page.items) == 0 + + @pytest.mark.asyncio + async def test_count_by_sandbox_id( + self, + async_session_with_users: AsyncSession, + ): + """Test counting conversations by exact sandbox_id match with SAAS user filtering.""" + # Create service for user1 + user1_service = SaasSQLAppConversationInfoService( + db_session=async_session_with_users, + user_context=SpecifyUserContext(user_id=str(USER1_ID)), + ) + + # Create conversations with different sandbox IDs + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_x', + title='Conversation X1', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_y', + title='Conversation Y1', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id='sandbox_x', + title='Conversation X2', + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await user1_service.save_app_conversation_info(conv1) + await user1_service.save_app_conversation_info(conv2) + await user1_service.save_app_conversation_info(conv3) + + # Count for sandbox_x - should be 2 + count = await user1_service.count_app_conversation_info( + sandbox_id__eq='sandbox_x' + ) + assert count == 2 + + # Count for sandbox_y - should be 1 + count = await user1_service.count_app_conversation_info( + sandbox_id__eq='sandbox_y' + ) + assert count == 1 + + # Count for non-existent sandbox - should be 0 + count = await user1_service.count_app_conversation_info( + sandbox_id__eq='sandbox_nonexistent' + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_sandbox_id_filter_respects_user_isolation( + self, + async_session_with_users: AsyncSession, + ): + """Test that sandbox_id filter respects user isolation in SAAS environment.""" + # Create services for both users + user1_service = SaasSQLAppConversationInfoService( + db_session=async_session_with_users, + user_context=SpecifyUserContext(user_id=str(USER1_ID)), + ) + user2_service = SaasSQLAppConversationInfoService( + db_session=async_session_with_users, + user_context=SpecifyUserContext(user_id=str(USER2_ID)), + ) + + # Create conversation with same sandbox_id for both users + shared_sandbox_id = 'sandbox_shared' + + conv_user1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER1_ID), + sandbox_id=shared_sandbox_id, + title='User1 Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv_user2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=str(USER2_ID), + sandbox_id=shared_sandbox_id, + title='User2 Conversation', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + # Save conversations + await user1_service.save_app_conversation_info(conv_user1) + await user2_service.save_app_conversation_info(conv_user2) + + # User1 should only see their own conversation with this sandbox_id + page = await user1_service.search_app_conversation_info( + sandbox_id__eq=shared_sandbox_id + ) + assert len(page.items) == 1 + assert page.items[0].id == conv_user1.id + assert page.items[0].title == 'User1 Conversation' + + # User2 should only see their own conversation with this sandbox_id + page = await user2_service.search_app_conversation_info( + sandbox_id__eq=shared_sandbox_id + ) + assert len(page.items) == 1 + assert page.items[0].id == conv_user2.id + assert page.items[0].title == 'User2 Conversation' + + # Count should also respect user isolation + count = await user1_service.count_app_conversation_info( + sandbox_id__eq=shared_sandbox_id + ) + assert count == 1 + + count = await user2_service.count_app_conversation_info( + sandbox_id__eq=shared_sandbox_id + ) + assert count == 1 diff --git a/openhands/app_server/app_conversation/app_conversation_info_service.py b/openhands/app_server/app_conversation/app_conversation_info_service.py index 8e9f1ffe68..bb83ab5801 100644 --- a/openhands/app_server/app_conversation/app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/app_conversation_info_service.py @@ -24,6 +24,7 @@ class AppConversationInfoService(ABC): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, @@ -39,6 +40,7 @@ class AppConversationInfoService(ABC): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, ) -> int: """Count sandboxed conversations.""" diff --git a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py index af4528c9a4..c7c9e1935e 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py @@ -119,6 +119,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, @@ -141,6 +142,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, + sandbox_id__eq=sandbox_id__eq, ) # Add sort order @@ -195,6 +197,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, ) -> int: """Count sandboxed conversations matching the given filters.""" query = select(func.count(StoredConversationMetadata.conversation_id)).where( @@ -208,6 +211,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): created_at__lt=created_at__lt, updated_at__gte=updated_at__gte, updated_at__lt=updated_at__lt, + sandbox_id__eq=sandbox_id__eq, ) result = await self.db_session.execute(query) @@ -222,6 +226,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): created_at__lt: datetime | None = None, updated_at__gte: datetime | None = None, updated_at__lt: datetime | None = None, + sandbox_id__eq: str | None = None, ) -> Select: # Apply the same filters as search_app_conversations conditions = [] @@ -246,6 +251,9 @@ class SQLAppConversationInfoService(AppConversationInfoService): StoredConversationMetadata.last_updated_at < updated_at__lt ) + if sandbox_id__eq is not None: + conditions.append(StoredConversationMetadata.sandbox_id == sandbox_id__eq) + if conditions: query = query.where(*conditions) return query diff --git a/tests/unit/app_server/test_sql_app_conversation_info_service.py b/tests/unit/app_server/test_sql_app_conversation_info_service.py index 2b741d984f..a491fa93af 100644 --- a/tests/unit/app_server/test_sql_app_conversation_info_service.py +++ b/tests/unit/app_server/test_sql_app_conversation_info_service.py @@ -943,3 +943,254 @@ class TestSQLAppConversationInfoService: assert parent_id in all_ids for sub_info in sub_conversations: assert sub_info.id in all_ids + + +class TestSandboxIdFilter: + """Test suite for sandbox_id__eq filter parameter.""" + + @pytest.mark.asyncio + async def test_search_by_sandbox_id( + self, + service: SQLAppConversationInfoService, + ): + """Test searching conversations by exact sandbox_id match.""" + # Create conversations with different sandbox IDs + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_alpha', + title='Conversation Alpha', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_beta', + title='Conversation Beta', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_alpha', + title='Conversation Gamma', + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(conv1) + await service.save_app_conversation_info(conv2) + await service.save_app_conversation_info(conv3) + + # Search for sandbox_alpha - should return 2 conversations + page = await service.search_app_conversation_info( + sandbox_id__eq='sandbox_alpha' + ) + assert len(page.items) == 2 + sandbox_ids = {item.sandbox_id for item in page.items} + assert sandbox_ids == {'sandbox_alpha'} + conversation_ids = {item.id for item in page.items} + assert conv1.id in conversation_ids + assert conv3.id in conversation_ids + + # Search for sandbox_beta - should return 1 conversation + page = await service.search_app_conversation_info(sandbox_id__eq='sandbox_beta') + assert len(page.items) == 1 + assert page.items[0].id == conv2.id + assert page.items[0].sandbox_id == 'sandbox_beta' + + # Search for non-existent sandbox - should return 0 conversations + page = await service.search_app_conversation_info( + sandbox_id__eq='sandbox_nonexistent' + ) + assert len(page.items) == 0 + + @pytest.mark.asyncio + async def test_count_by_sandbox_id( + self, + service: SQLAppConversationInfoService, + ): + """Test counting conversations by exact sandbox_id match.""" + # Create conversations with different sandbox IDs + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_x', + title='Conversation X1', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_y', + title='Conversation Y1', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_x', + title='Conversation X2', + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + conv4 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_x', + title='Conversation X3', + created_at=datetime(2024, 1, 1, 15, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 15, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(conv1) + await service.save_app_conversation_info(conv2) + await service.save_app_conversation_info(conv3) + await service.save_app_conversation_info(conv4) + + # Count for sandbox_x - should be 3 + count = await service.count_app_conversation_info(sandbox_id__eq='sandbox_x') + assert count == 3 + + # Count for sandbox_y - should be 1 + count = await service.count_app_conversation_info(sandbox_id__eq='sandbox_y') + assert count == 1 + + # Count for non-existent sandbox - should be 0 + count = await service.count_app_conversation_info( + sandbox_id__eq='sandbox_nonexistent' + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_sandbox_id_filter_combined_with_title_filter( + self, + service: SQLAppConversationInfoService, + ): + """Test sandbox_id filter combined with title filter.""" + # Create conversations with different sandbox IDs and titles + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_project', + title='Feature: User Authentication', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_project', + title='Bug Fix: Login Issue', + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_other', + title='Feature: Payment System', + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(conv1) + await service.save_app_conversation_info(conv2) + await service.save_app_conversation_info(conv3) + + # Search for Feature in sandbox_project - should return 1 + page = await service.search_app_conversation_info( + sandbox_id__eq='sandbox_project', title__contains='Feature' + ) + assert len(page.items) == 1 + assert page.items[0].id == conv1.id + + # Search for Feature in sandbox_other - should return 1 + page = await service.search_app_conversation_info( + sandbox_id__eq='sandbox_other', title__contains='Feature' + ) + assert len(page.items) == 1 + assert page.items[0].id == conv3.id + + # Count for Bug in sandbox_project - should be 1 + count = await service.count_app_conversation_info( + sandbox_id__eq='sandbox_project', title__contains='Bug' + ) + assert count == 1 + + # Count for Bug in sandbox_other - should be 0 + count = await service.count_app_conversation_info( + sandbox_id__eq='sandbox_other', title__contains='Bug' + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_sandbox_id_filter_with_date_filters( + self, + service: SQLAppConversationInfoService, + ): + """Test sandbox_id filter combined with date range filters.""" + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + # Create conversations in the same sandbox but at different times + conv1 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_time_test', + title='Conversation Early', + created_at=base_time, + updated_at=base_time, + ) + conv2 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_time_test', + title='Conversation Middle', + created_at=base_time.replace(hour=15), + updated_at=base_time.replace(hour=15), + ) + conv3 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_time_test', + title='Conversation Late', + created_at=base_time.replace(hour=18), + updated_at=base_time.replace(hour=18), + ) + conv4 = AppConversationInfo( + id=uuid4(), + created_by_user_id=None, + sandbox_id='sandbox_other_time', + title='Conversation Other', + created_at=base_time.replace(hour=15), + updated_at=base_time.replace(hour=15), + ) + + # Save all conversations + await service.save_app_conversation_info(conv1) + await service.save_app_conversation_info(conv2) + await service.save_app_conversation_info(conv3) + await service.save_app_conversation_info(conv4) + + # Search for sandbox_time_test with date filter - should return 2 + cutoff = base_time.replace(hour=14) + page = await service.search_app_conversation_info( + sandbox_id__eq='sandbox_time_test', created_at__gte=cutoff + ) + assert len(page.items) == 2 + conversation_ids = {item.id for item in page.items} + assert conv2.id in conversation_ids + assert conv3.id in conversation_ids + + # Count for sandbox_time_test with date filter + count = await service.count_app_conversation_info( + sandbox_id__eq='sandbox_time_test', created_at__gte=cutoff + ) + assert count == 2