From 46464391083a65feb7162e3845a1051052eefe3c Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 27 Oct 2025 23:46:27 +0000 Subject: [PATCH] Separate SaaS-specific fields from StoredConversationMetadata - Create new ConversationMetadataSaas model with conversation_id, user_id, org_id - Remove github_user_id, user_id, org_id from StoredConversationMetadata - Update all enterprise clients to use ConversationMetadataSaas for user/org lookups - Add database migration to create new table and migrate existing data - Maintain backward compatibility in OpenHands core components Co-authored-by: openhands --- ..._saas_fields_from_conversation_metadata.py | 81 +++++++++++++++++++ .../server/clustered_conversation_manager.py | 9 ++- enterprise/server/routes/event_webhook.py | 9 ++- enterprise/server/routes/feedback.py | 7 +- .../saas_nested_conversation_manager.py | 18 +++-- enterprise/storage/__init__.py | 2 + .../storage/conversation_metadata_saas.py | 29 +++++++ enterprise/storage/org.py | 1 + enterprise/storage/saas_conversation_store.py | 54 +++++++++++-- enterprise/storage/user.py | 1 + .../sql_app_conversation_info_service.py | 32 +------- 11 files changed, 186 insertions(+), 57 deletions(-) create mode 100644 enterprise/migrations/versions/081_separate_saas_fields_from_conversation_metadata.py create mode 100644 enterprise/storage/conversation_metadata_saas.py diff --git a/enterprise/migrations/versions/081_separate_saas_fields_from_conversation_metadata.py b/enterprise/migrations/versions/081_separate_saas_fields_from_conversation_metadata.py new file mode 100644 index 0000000000..8a2ee87f5d --- /dev/null +++ b/enterprise/migrations/versions/081_separate_saas_fields_from_conversation_metadata.py @@ -0,0 +1,81 @@ +"""separate saas fields from conversation metadata + +Revision ID: 081 +Revises: 080 +Create Date: 2025-01-27 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '081' +down_revision: Union[str, None] = '080' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create conversation_metadata_saas table + op.create_table( + 'conversation_metadata_saas', + sa.Column('conversation_id', sa.String(), nullable=False), + sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'), + sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'), + sa.PrimaryKeyConstraint('conversation_id'), + ) + + # Migrate existing data from conversation_metadata to conversation_metadata_saas + # First, we need to handle the case where user_id might be a string that needs to be converted to UUID + op.execute(""" + INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id) + SELECT + conversation_id, + CASE + WHEN user_id ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + THEN user_id::uuid + ELSE gen_random_uuid() -- Generate a new UUID for invalid user_id values + END as user_id, + COALESCE(org_id, gen_random_uuid()) as org_id -- Use existing org_id or generate new one + FROM conversation_metadata + WHERE user_id IS NOT NULL + """) + + # Remove columns from conversation_metadata table + op.drop_constraint('conversation_metadata_org_fkey', 'conversation_metadata', type_='foreignkey') + op.drop_column('conversation_metadata', 'github_user_id') + op.drop_column('conversation_metadata', 'user_id') + op.drop_column('conversation_metadata', 'org_id') + + +def downgrade() -> None: + # Add columns back to conversation_metadata table + op.add_column('conversation_metadata', sa.Column('github_user_id', sa.String(), nullable=True)) + op.add_column('conversation_metadata', sa.Column('user_id', sa.String(), nullable=False)) + op.add_column('conversation_metadata', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)) + + # Recreate foreign key constraint + op.create_foreign_key( + 'conversation_metadata_org_fkey', + 'conversation_metadata', + 'org', + ['org_id'], + ['id'], + ) + + # Migrate data back from conversation_metadata_saas to conversation_metadata + op.execute(""" + UPDATE conversation_metadata + SET user_id = cms.user_id::text, org_id = cms.org_id + FROM conversation_metadata_saas cms + WHERE conversation_metadata.conversation_id = cms.conversation_id + """) + + # Drop conversation_metadata_saas table + op.drop_table('conversation_metadata_saas') \ No newline at end of file diff --git a/enterprise/server/clustered_conversation_manager.py b/enterprise/server/clustered_conversation_manager.py index 1eae6d19da..685a2169c9 100644 --- a/enterprise/server/clustered_conversation_manager.py +++ b/enterprise/server/clustered_conversation_manager.py @@ -7,6 +7,7 @@ from uuid import uuid4 import socketio from server.logger import logger from server.utils.conversation_callback_utils import invoke_conversation_callbacks +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.database import session_maker from storage.saas_settings_store import SaasSettingsStore from storage.stored_conversation_metadata import StoredConversationMetadata @@ -525,16 +526,16 @@ class ClusteredConversationManager(StandaloneConversationManager): ) # Look up the user_id from the database with session_maker() as session: - conversation_metadata = ( - session.query(StoredConversationMetadata) + conversation_metadata_saas = ( + session.query(ConversationMetadataSaas) .filter( - StoredConversationMetadata.conversation_id + ConversationMetadataSaas.conversation_id == conversation_id ) .first() ) user_id = ( - conversation_metadata.user_id if conversation_metadata else None + str(conversation_metadata_saas.user_id) if conversation_metadata_saas else None ) # Handle the stopped conversation asynchronously asyncio.create_task( diff --git a/enterprise/server/routes/event_webhook.py b/enterprise/server/routes/event_webhook.py index b4f8b71f68..08bdb97be3 100644 --- a/enterprise/server/routes/event_webhook.py +++ b/enterprise/server/routes/event_webhook.py @@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import ( update_conversation_metadata, update_conversation_stats, ) +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.database import session_maker from storage.stored_conversation_metadata import StoredConversationMetadata @@ -226,12 +227,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]: def _get_user_id(conversation_id: str) -> str: with session_maker() as session: - conversation_metadata = ( - session.query(StoredConversationMetadata) - .filter(StoredConversationMetadata.conversation_id == conversation_id) + conversation_metadata_saas = ( + session.query(ConversationMetadataSaas) + .filter(ConversationMetadataSaas.conversation_id == conversation_id) .first() ) - return conversation_metadata.user_id + return str(conversation_metadata_saas.user_id) async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None: diff --git a/enterprise/server/routes/feedback.py b/enterprise/server/routes/feedback.py index dc37af242f..8a60d93564 100644 --- a/enterprise/server/routes/feedback.py +++ b/enterprise/server/routes/feedback.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field from sqlalchemy.future import select +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.database import session_maker from storage.feedback import ConversationFeedback from storage.stored_conversation_metadata import StoredConversationMetadata @@ -33,10 +34,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]: def _verify_conversation(): with session_maker() as session: metadata = ( - session.query(StoredConversationMetadata) + session.query(ConversationMetadataSaas) .filter( - StoredConversationMetadata.conversation_id == conversation_id, - StoredConversationMetadata.user_id == user_id, + ConversationMetadataSaas.conversation_id == conversation_id, + ConversationMetadataSaas.user_id == user_id, ) .first() ) diff --git a/enterprise/server/saas_nested_conversation_manager.py b/enterprise/server/saas_nested_conversation_manager.py index 6eb03a66e3..2f27dfa76f 100644 --- a/enterprise/server/saas_nested_conversation_manager.py +++ b/enterprise/server/saas_nested_conversation_manager.py @@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import ( from sqlalchemy import orm from storage.api_key_store import ApiKeyStore from storage.database import session_maker +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.stored_conversation_metadata import StoredConversationMetadata from openhands.controller.agent import Agent @@ -522,16 +523,16 @@ class SaasNestedConversationManager(ConversationManager): """ with session_maker() as session: - conversation_metadata = ( - session.query(StoredConversationMetadata) - .filter(StoredConversationMetadata.conversation_id == conversation_id) + conversation_metadata_saas = ( + session.query(ConversationMetadataSaas) + .filter(ConversationMetadataSaas.conversation_id == conversation_id) .first() ) - if not conversation_metadata: + if not conversation_metadata_saas: raise ValueError(f'No conversation found {conversation_id}') - return conversation_metadata.user_id + return str(conversation_metadata_saas.user_id) async def _get_runtime_status_from_nested_runtime( self, session_api_key: Any | None, nested_url: str, conversation_id: str @@ -853,8 +854,11 @@ class SaasNestedConversationManager(ConversationManager): with session_maker() as session: # Only include conversations updated in the past week one_week_ago = datetime.now(UTC) - timedelta(days=7) - query = session.query(StoredConversationMetadata.conversation_id).filter( - StoredConversationMetadata.user_id == user_id, + query = session.query(StoredConversationMetadata.conversation_id).join( + ConversationMetadataSaas, + StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id + ).filter( + ConversationMetadataSaas.user_id == user_id, StoredConversationMetadata.last_updated_at >= one_week_ago, ) user_conversation_ids = set(query) diff --git a/enterprise/storage/__init__.py b/enterprise/storage/__init__.py index 90162047cd..1c72b0c063 100644 --- a/enterprise/storage/__init__.py +++ b/enterprise/storage/__init__.py @@ -3,6 +3,7 @@ from storage.auth_tokens import AuthTokens from storage.billing_session import BillingSession from storage.billing_session_type import BillingSessionType from storage.conversation_callback import CallbackStatus, ConversationCallback +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.conversation_work import ConversationWork from storage.experiment_assignment import ExperimentAssignment from storage.feedback import ConversationFeedback, Feedback @@ -45,6 +46,7 @@ __all__ = [ 'CallbackStatus', 'ConversationCallback', 'ConversationFeedback', + 'ConversationMetadataSaas', 'ConversationWork', 'ExperimentAssignment', 'Feedback', diff --git a/enterprise/storage/conversation_metadata_saas.py b/enterprise/storage/conversation_metadata_saas.py new file mode 100644 index 0000000000..8efa37a12a --- /dev/null +++ b/enterprise/storage/conversation_metadata_saas.py @@ -0,0 +1,29 @@ +""" +SQLAlchemy model for ConversationMetadataSaas. + +This model stores the SaaS-specific metadata for conversations, +containing only the conversation_id, user_id, and org_id. +""" + +from uuid import UUID + +from sqlalchemy import UUID as SQL_UUID, Column, ForeignKey, String +from sqlalchemy.orm import relationship +from storage.base import Base + + +class ConversationMetadataSaas(Base): # type: ignore + """SaaS conversation metadata model containing user and org associations.""" + + __tablename__ = 'conversation_metadata_saas' + + conversation_id = Column(String, primary_key=True) + user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False) + org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False) + + # Relationships + user = relationship('User', back_populates='conversation_metadata_saas') + org = relationship('Org', back_populates='conversation_metadata_saas') + + +__all__ = ['ConversationMetadataSaas'] \ No newline at end of file diff --git a/enterprise/storage/org.py b/enterprise/storage/org.py index 81580a5bca..af1e169b61 100644 --- a/enterprise/storage/org.py +++ b/enterprise/storage/org.py @@ -51,6 +51,7 @@ class Org(Base): # type: ignore conversation_metadata = relationship( 'StoredConversationMetadata', back_populates='org' ) + conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='org') user_secrets = relationship('StoredUserSecrets', back_populates='org') api_keys = relationship('ApiKey', back_populates='org') slack_conversations = relationship('SlackConversation', back_populates='org') diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index 10371810aa..8538f73cfd 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -7,6 +7,7 @@ from datetime import UTC from uuid import UUID from sqlalchemy.orm import sessionmaker +from storage.conversation_metadata_saas import ConversationMetadataSaas from storage.database import session_maker from storage.stored_conversation_metadata import StoredConversationMetadata from storage.user_store import UserStore @@ -43,10 +44,15 @@ class SaasConversationStore(ConversationStore): self.org_id = user.current_org_id def _select_by_id(self, session, conversation_id: str): + # Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org return ( session.query(StoredConversationMetadata) - .filter(StoredConversationMetadata.user_id == self.user_id) - .filter(StoredConversationMetadata.org_id == self.org_id) + .join( + ConversationMetadataSaas, + StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id + ) + .filter(ConversationMetadataSaas.user_id == self.user_id) + .filter(ConversationMetadataSaas.org_id == self.org_id) .filter(StoredConversationMetadata.conversation_id == conversation_id) ) @@ -54,8 +60,6 @@ class SaasConversationStore(ConversationStore): kwargs = { c.name: getattr(conversation_metadata, c.name) for c in StoredConversationMetadata.__table__.columns - if c.name - not in ['github_user_id', 'org_id'] # Skip github_user_id and org_id fields } # TODO: I'm not sure why the timezone is not set on the dates coming back out of the db kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC) @@ -78,8 +82,10 @@ class SaasConversationStore(ConversationStore): async def save_metadata(self, metadata: ConversationMetadata): kwargs = dataclasses.asdict(metadata) - kwargs['user_id'] = self.user_id - kwargs['org_id'] = self.org_id + + # Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata + kwargs.pop('user_id', None) + kwargs.pop('org_id', None) # Convert ProviderType enum to string for storage if kwargs.get('git_provider') is not None: @@ -93,7 +99,26 @@ class SaasConversationStore(ConversationStore): def _save_metadata(): with self.session_maker() as session: + # Save the main conversation metadata session.merge(stored_metadata) + + # Create or update the SaaS metadata record + saas_metadata = session.query(ConversationMetadataSaas).filter( + ConversationMetadataSaas.conversation_id == stored_metadata.conversation_id + ).first() + + if not saas_metadata: + saas_metadata = ConversationMetadataSaas( + conversation_id=stored_metadata.conversation_id, + user_id=self.user_id, + org_id=self.org_id + ) + session.add(saas_metadata) + else: + # Update existing record + saas_metadata.user_id = self.user_id + saas_metadata.org_id = self.org_id + session.commit() await call_sync_from_async(_save_metadata) @@ -113,7 +138,16 @@ class SaasConversationStore(ConversationStore): async def delete_metadata(self, conversation_id: str) -> None: def _delete_metadata(): with self.session_maker() as session: + # Delete the main conversation metadata self._select_by_id(session, conversation_id).delete() + + # Delete the SaaS metadata record + session.query(ConversationMetadataSaas).filter( + ConversationMetadataSaas.conversation_id == conversation_id, + ConversationMetadataSaas.user_id == self.user_id, + ConversationMetadataSaas.org_id == self.org_id + ).delete() + session.commit() await call_sync_from_async(_delete_metadata) @@ -137,8 +171,12 @@ class SaasConversationStore(ConversationStore): with self.session_maker() as session: conversations = ( session.query(StoredConversationMetadata) - .filter(StoredConversationMetadata.user_id == self.user_id) - .filter(StoredConversationMetadata.org_id == self.org_id) + .join( + ConversationMetadataSaas, + StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id + ) + .filter(ConversationMetadataSaas.user_id == self.user_id) + .filter(ConversationMetadataSaas.org_id == self.org_id) .order_by(StoredConversationMetadata.created_at.desc()) .offset(offset) .limit(limit + 1) diff --git a/enterprise/storage/user.py b/enterprise/storage/user.py index f274b8703a..27596608a4 100644 --- a/enterprise/storage/user.py +++ b/enterprise/storage/user.py @@ -36,3 +36,4 @@ class User(Base): # type: ignore role = relationship('Role', back_populates='users') org_members = relationship('OrgMember', back_populates='user') current_org = relationship('Org', back_populates='current_users') + conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='user') diff --git a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py index 03f3d03c7c..1bc1729a6d 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_info_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_info_service.py @@ -25,14 +25,10 @@ from typing import AsyncGenerator from uuid import UUID from fastapi import Request -from sqlalchemy import ( - UUID as SQL_UUID, -) from sqlalchemy import ( Column, DateTime, Float, - ForeignKey, Integer, Select, String, @@ -40,7 +36,6 @@ from sqlalchemy import ( select, ) from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import relationship from openhands.agent_server.utils import utc_now from openhands.app_server.app_conversation.app_conversation_info_service import ( @@ -71,9 +66,6 @@ class StoredConversationMetadata(Base): # type: ignore conversation_id = Column( String, primary_key=True, default=lambda: str(uuid.uuid4()) ) - github_user_id = Column(String, nullable=True) # The GitHub user ID - user_id = Column(String, nullable=False) # The Keycloak User ID - org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=True) selected_repository = Column(String, nullable=True) selected_branch = Column(String, nullable=True) git_provider = Column( @@ -104,9 +96,6 @@ class StoredConversationMetadata(Base): # type: ignore conversation_version = Column(String, nullable=False, default='V0', index=True) sandbox_id = Column(String, nullable=True, index=True) - # Relationship back to org - org = relationship('Org', back_populates='conversation_metadata') - @dataclass class SQLAppConversationInfoService(AppConversationInfoService): @@ -196,9 +185,6 @@ class SQLAppConversationInfoService(AppConversationInfoService): ) -> int: """Count sandboxed conversations matching the given filters.""" query = select(func.count(StoredConversationMetadata.conversation_id)) - user_id = await self.user_context.get_user_id() - if user_id: - query = query.where(StoredConversationMetadata.user_id == user_id) query = self._apply_filters( query=query, @@ -288,22 +274,11 @@ class SQLAppConversationInfoService(AppConversationInfoService): async def save_app_conversation_info( self, info: AppConversationInfo ) -> AppConversationInfo: - user_id = await self.user_context.get_user_id() - if user_id: - query = select(StoredConversationMetadata).where( - StoredConversationMetadata.conversation_id == info.id - ) - result = await self.db_session.execute(query) - existing = result.scalar_one_or_none() - assert existing is None or existing.created_by_user_id == user_id - metrics = info.metrics or MetricsSnapshot() usage = metrics.accumulated_token_usage or TokenUsage() stored = StoredConversationMetadata( conversation_id=str(info.id), - github_user_id=None, # TODO: Should we add this to the conversation info? - user_id=info.created_by_user_id or '', selected_repository=info.selected_repository, selected_branch=info.selected_branch, git_provider=info.git_provider.value if info.git_provider else None, @@ -335,11 +310,6 @@ class SQLAppConversationInfoService(AppConversationInfoService): query = select(StoredConversationMetadata).where( StoredConversationMetadata.conversation_version == 'V1' ) - user_id = await self.user_context.get_user_id() - if user_id: - query = query.where( - StoredConversationMetadata.user_id == user_id, - ) return query def _to_info(self, stored: StoredConversationMetadata) -> AppConversationInfo: @@ -370,7 +340,7 @@ class SQLAppConversationInfoService(AppConversationInfoService): return AppConversationInfo( id=UUID(stored.conversation_id), - created_by_user_id=stored.user_id if stored.user_id else None, + created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas sandbox_id=stored.sandbox_id, selected_repository=stored.selected_repository, selected_branch=stored.selected_branch,