From 9d0a19cf8f9b45af4d42eb0534cfb9fab18342f2 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Fri, 23 Jan 2026 22:13:08 +0700 Subject: [PATCH] fix(backend): ensure conversation events are written back to google cloud (#12571) --- .../sql_shared_conversation_info_service.py | 56 +++- ...haring_shared_conversation_info_service.py | 263 +++++++++++++++++- 2 files changed, 309 insertions(+), 10 deletions(-) diff --git a/enterprise/server/sharing/sql_shared_conversation_info_service.py b/enterprise/server/sharing/sql_shared_conversation_info_service.py index 42972a0163..bd584168dd 100644 --- a/enterprise/server/sharing/sql_shared_conversation_info_service.py +++ b/enterprise/server/sharing/sql_shared_conversation_info_service.py @@ -26,6 +26,7 @@ from server.sharing.shared_conversation_models import ( ) from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas from openhands.app_server.app_conversation.sql_app_conversation_info_service import ( StoredConversationMetadata, @@ -57,7 +58,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): include_sub_conversations: bool = False, ) -> SharedConversationPage: """Search for shared conversations.""" - query = self._public_select() + query = self._public_select_with_saas_metadata() # Conditionally exclude sub-conversations based on the parameter if not include_sub_conversations: @@ -104,14 +105,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): query = query.limit(limit + 1) result = await self.db_session.execute(query) - rows = result.scalars().all() + rows = result.all() # Check if there are more results has_more = len(rows) > limit if has_more: rows = rows[:limit] - items = [self._to_shared_conversation(row) for row in rows] + items = [ + self._to_shared_conversation(stored, saas_metadata=saas_metadata) + for stored, saas_metadata in rows + ] # Calculate next page ID next_page_id = None @@ -152,17 +156,18 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): self, conversation_id: UUID ) -> SharedConversation | None: """Get a single public conversation info, returning None if missing or not shared.""" - query = self._public_select().where( + query = self._public_select_with_saas_metadata().where( StoredConversationMetadata.conversation_id == str(conversation_id) ) result = await self.db_session.execute(query) - stored = result.scalar_one_or_none() + row = result.first() - if stored is None: + if row is None: return None - return self._to_shared_conversation(stored) + stored, saas_metadata = row + return self._to_shared_conversation(stored, saas_metadata=saas_metadata) def _public_select(self): """Create a select query that only returns public conversations.""" @@ -173,6 +178,25 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): query = query.where(StoredConversationMetadata.public == True) # noqa: E712 return query + def _public_select_with_saas_metadata(self): + """Create a select query that returns public conversations with SAAS metadata. + + This joins with conversation_metadata_saas to retrieve the user_id needed + for constructing the correct event storage path. Uses LEFT OUTER JOIN to + support conversations that may not have SAAS metadata (e.g., in tests). + """ + query = ( + select(StoredConversationMetadata, StoredConversationMetadataSaas) + .outerjoin( + StoredConversationMetadataSaas, + StoredConversationMetadata.conversation_id + == StoredConversationMetadataSaas.conversation_id, + ) + .where(StoredConversationMetadata.conversation_version == 'V1') + .where(StoredConversationMetadata.public == True) # noqa: E712 + ) + return query + def _apply_filters( self, query, @@ -211,9 +235,16 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): def _to_shared_conversation( self, stored: StoredConversationMetadata, + saas_metadata: StoredConversationMetadataSaas | None = None, sub_conversation_ids: list[UUID] | None = None, ) -> SharedConversation: - """Convert StoredConversationMetadata to SharedConversation.""" + """Convert StoredConversationMetadata to SharedConversation. + + Args: + stored: The base conversation metadata from conversation_metadata table. + saas_metadata: Optional SAAS metadata containing user_id and org_id. + sub_conversation_ids: Optional list of sub-conversation IDs. + """ # V1 conversations should always have a sandbox_id sandbox_id = stored.sandbox_id assert sandbox_id is not None @@ -239,9 +270,16 @@ class SQLSharedConversationInfoService(SharedConversationInfoService): created_at = self._fix_timezone(stored.created_at) updated_at = self._fix_timezone(stored.last_updated_at) + # Get user_id from SAAS metadata if available + created_by_user_id = ( + str(saas_metadata.user_id) + if saas_metadata and saas_metadata.user_id + else None + ) + return SharedConversation( id=UUID(stored.conversation_id), - created_by_user_id=None, # user_id is no longer stored in conversation metadata + created_by_user_id=created_by_user_id, sandbox_id=stored.sandbox_id, selected_repository=stored.selected_repository, selected_branch=stored.selected_branch, diff --git a/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py b/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py index e15d417000..60cd0ee29c 100644 --- a/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py +++ b/enterprise/tests/unit/test_sharing/test_sharing_shared_conversation_info_service.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime from typing import AsyncGenerator -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest from server.sharing.shared_conversation_models import ( @@ -13,6 +13,9 @@ from server.sharing.sql_shared_conversation_info_service import ( ) from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool +from storage.org import Org +from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas +from storage.user import User from openhands.app_server.app_conversation.app_conversation_models import ( AppConversationInfo, @@ -428,3 +431,261 @@ class TestSharedConversationInfoService: page1_ids = {item.id for item in result.items} page2_ids = {item.id for item in result2.items} assert page1_ids.isdisjoint(page2_ids) + + +class TestSharedConversationInfoServiceWithSaasMetadata: + """Test cases for SharedConversationInfoService with SAAS metadata. + + These tests verify that created_by_user_id is correctly retrieved from + the conversation_metadata_saas table when it exists. + """ + + @pytest.fixture + async def async_engine_with_saas(self): + """Create an async SQLite engine with all SAAS tables.""" + engine = create_async_engine( + 'sqlite+aiosqlite:///:memory:', + poolclass=StaticPool, + connect_args={'check_same_thread': False}, + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + await engine.dispose() + + @pytest.fixture + async def async_session_with_saas( + self, async_engine_with_saas + ) -> AsyncGenerator[AsyncSession, None]: + """Create an async session for testing with SAAS tables.""" + async_session_maker = async_sessionmaker( + async_engine_with_saas, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session_maker() as db_session: + yield db_session + + @pytest.fixture + async def test_org(self, async_session_with_saas) -> Org: + """Create a test organization.""" + org = Org(id=uuid4(), name=f'test_org_{uuid4().hex[:8]}') + async_session_with_saas.add(org) + await async_session_with_saas.commit() + return org + + @pytest.fixture + async def test_user(self, async_session_with_saas, test_org) -> User: + """Create a test user belonging to the test organization.""" + user = User(id=uuid4(), current_org_id=test_org.id) + async_session_with_saas.add(user) + await async_session_with_saas.commit() + return user + + @pytest.fixture + async def shared_service_with_saas(self, async_session_with_saas): + """Create a SharedConversationInfoService for testing.""" + return SQLSharedConversationInfoService(db_session=async_session_with_saas) + + @pytest.fixture + async def app_service_with_saas(self, async_session_with_saas): + """Create an AppConversationInfoService for creating test data.""" + return SQLAppConversationInfoService( + db_session=async_session_with_saas, + user_context=SpecifyUserContext(user_id=None), + ) + + async def _create_saas_metadata( + self, + db_session: AsyncSession, + conversation_id: UUID, + user_id: UUID, + org_id: UUID, + ) -> StoredConversationMetadataSaas: + """Helper to create SAAS metadata for a conversation.""" + saas_metadata = StoredConversationMetadataSaas( + conversation_id=str(conversation_id), + user_id=user_id, + org_id=org_id, + ) + db_session.add(saas_metadata) + await db_session.commit() + return saas_metadata + + @pytest.mark.asyncio + async def test_get_shared_conversation_returns_user_id_from_saas_metadata( + self, + shared_service_with_saas, + app_service_with_saas, + async_session_with_saas, + test_user, + test_org, + ): + """Test that get_shared_conversation_info returns created_by_user_id from SAAS metadata.""" + # Arrange + conversation_id = uuid4() + conversation = AppConversationInfo( + id=conversation_id, + created_by_user_id=None, + sandbox_id='test_sandbox', + title='Public Conversation With User', + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + await app_service_with_saas.save_app_conversation_info(conversation) + await self._create_saas_metadata( + async_session_with_saas, conversation_id, test_user.id, test_org.id + ) + + # Act + result = await shared_service_with_saas.get_shared_conversation_info( + conversation_id + ) + + # Assert + assert result is not None + assert result.created_by_user_id == str(test_user.id) + + @pytest.mark.asyncio + async def test_search_shared_conversations_returns_user_id_from_saas_metadata( + self, + shared_service_with_saas, + app_service_with_saas, + async_session_with_saas, + test_user, + test_org, + ): + """Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata.""" + # Arrange + conversation_id = uuid4() + conversation = AppConversationInfo( + id=conversation_id, + created_by_user_id=None, + sandbox_id='test_sandbox_search', + title='Searchable Public Conversation', + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + await app_service_with_saas.save_app_conversation_info(conversation) + await self._create_saas_metadata( + async_session_with_saas, conversation_id, test_user.id, test_org.id + ) + + # Act + result = await shared_service_with_saas.search_shared_conversation_info() + + # Assert + assert len(result.items) == 1 + assert result.items[0].created_by_user_id == str(test_user.id) + + @pytest.mark.asyncio + async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata( + self, + shared_service_with_saas, + app_service_with_saas, + async_session_with_saas, + test_user, + test_org, + ): + """Test that batch_get_shared_conversation_info returns created_by_user_id from SAAS metadata.""" + # Arrange + conversation_id = uuid4() + conversation = AppConversationInfo( + id=conversation_id, + created_by_user_id=None, + sandbox_id='test_sandbox_batch', + title='Batch Get Conversation', + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + await app_service_with_saas.save_app_conversation_info(conversation) + await self._create_saas_metadata( + async_session_with_saas, conversation_id, test_user.id, test_org.id + ) + + # Act + result = await shared_service_with_saas.batch_get_shared_conversation_info( + [conversation_id] + ) + + # Assert + assert len(result) == 1 + assert result[0] is not None + assert result[0].created_by_user_id == str(test_user.id) + + @pytest.mark.asyncio + async def test_mixed_conversations_with_and_without_saas_metadata( + self, + shared_service_with_saas, + app_service_with_saas, + async_session_with_saas, + test_user, + test_org, + ): + """Test handling of conversations where some have SAAS metadata and some don't.""" + # Arrange + conv_with_saas_id = uuid4() + conv_without_saas_id = uuid4() + + conv_with_saas = AppConversationInfo( + id=conv_with_saas_id, + created_by_user_id=None, + sandbox_id='sandbox_with_saas', + title='With SAAS Metadata', + created_at=datetime(2023, 1, 2, tzinfo=UTC), + updated_at=datetime(2023, 1, 2, tzinfo=UTC), + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + conv_without_saas = AppConversationInfo( + id=conv_without_saas_id, + created_by_user_id=None, + sandbox_id='sandbox_without_saas', + title='Without SAAS Metadata', + created_at=datetime(2023, 1, 1, tzinfo=UTC), + updated_at=datetime(2023, 1, 1, tzinfo=UTC), + public=True, + metrics=MetricsSnapshot( + accumulated_cost=0.0, + max_budget_per_task=10.0, + accumulated_token_usage=TokenUsage(), + ), + ) + + await app_service_with_saas.save_app_conversation_info(conv_with_saas) + await app_service_with_saas.save_app_conversation_info(conv_without_saas) + await self._create_saas_metadata( + async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id + ) + + # Act + result = await shared_service_with_saas.search_shared_conversation_info( + sort_order=SharedConversationSortOrder.CREATED_AT + ) + + # Assert + assert len(result.items) == 2 + conv_without = next( + item for item in result.items if item.id == conv_without_saas_id + ) + conv_with = next(item for item in result.items if item.id == conv_with_saas_id) + assert conv_without.created_by_user_id is None + assert conv_with.created_by_user_id == str(test_user.id)