Separate SaaS-specific fields from StoredConversationMetadata

- Create new ConversationMetadataSaas model with conversation_id, user_id, org_id
- Remove github_user_id, user_id, org_id from StoredConversationMetadata
- Update all enterprise clients to use ConversationMetadataSaas for user/org lookups
- Add database migration to create new table and migrate existing data
- Maintain backward compatibility in OpenHands core components

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
openhands 2025-10-27 23:46:27 +00:00
parent f89e41ac30
commit 4646439108
11 changed files with 186 additions and 57 deletions

View File

@ -0,0 +1,81 @@
"""separate saas fields from conversation metadata
Revision ID: 081
Revises: 080
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '081'
down_revision: Union[str, None] = '080'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create conversation_metadata_saas table
op.create_table(
'conversation_metadata_saas',
sa.Column('conversation_id', sa.String(), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'),
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'),
sa.PrimaryKeyConstraint('conversation_id'),
)
# Migrate existing data from conversation_metadata to conversation_metadata_saas
# First, we need to handle the case where user_id might be a string that needs to be converted to UUID
op.execute("""
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
SELECT
conversation_id,
CASE
WHEN user_id ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
THEN user_id::uuid
ELSE gen_random_uuid() -- Generate a new UUID for invalid user_id values
END as user_id,
COALESCE(org_id, gen_random_uuid()) as org_id -- Use existing org_id or generate new one
FROM conversation_metadata
WHERE user_id IS NOT NULL
""")
# Remove columns from conversation_metadata table
op.drop_constraint('conversation_metadata_org_fkey', 'conversation_metadata', type_='foreignkey')
op.drop_column('conversation_metadata', 'github_user_id')
op.drop_column('conversation_metadata', 'user_id')
op.drop_column('conversation_metadata', 'org_id')
def downgrade() -> None:
# Add columns back to conversation_metadata table
op.add_column('conversation_metadata', sa.Column('github_user_id', sa.String(), nullable=True))
op.add_column('conversation_metadata', sa.Column('user_id', sa.String(), nullable=False))
op.add_column('conversation_metadata', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True))
# Recreate foreign key constraint
op.create_foreign_key(
'conversation_metadata_org_fkey',
'conversation_metadata',
'org',
['org_id'],
['id'],
)
# Migrate data back from conversation_metadata_saas to conversation_metadata
op.execute("""
UPDATE conversation_metadata
SET user_id = cms.user_id::text, org_id = cms.org_id
FROM conversation_metadata_saas cms
WHERE conversation_metadata.conversation_id = cms.conversation_id
""")
# Drop conversation_metadata_saas table
op.drop_table('conversation_metadata_saas')

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import socketio
from server.logger import logger
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.stored_conversation_metadata import StoredConversationMetadata
@ -525,16 +526,16 @@ class ClusteredConversationManager(StandaloneConversationManager):
)
# Look up the user_id from the database
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
conversation_metadata_saas = (
session.query(ConversationMetadataSaas)
.filter(
StoredConversationMetadata.conversation_id
ConversationMetadataSaas.conversation_id
== conversation_id
)
.first()
)
user_id = (
conversation_metadata.user_id if conversation_metadata else None
str(conversation_metadata_saas.user_id) if conversation_metadata_saas else None
)
# Handle the stopped conversation asynchronously
asyncio.create_task(

View File

@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import (
update_conversation_metadata,
update_conversation_stats,
)
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
@ -226,12 +227,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
def _get_user_id(conversation_id: str) -> str:
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
conversation_metadata_saas = (
session.query(ConversationMetadataSaas)
.filter(ConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
return conversation_metadata.user_id
return str(conversation_metadata_saas.user_id)
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.database import session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata import StoredConversationMetadata
@ -33,10 +34,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadata)
session.query(ConversationMetadataSaas)
.filter(
StoredConversationMetadata.conversation_id == conversation_id,
StoredConversationMetadata.user_id == user_id,
ConversationMetadataSaas.conversation_id == conversation_id,
ConversationMetadataSaas.user_id == user_id,
)
.first()
)

View File

@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import (
from sqlalchemy import orm
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.controller.agent import Agent
@ -522,16 +523,16 @@ class SaasNestedConversationManager(ConversationManager):
"""
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
conversation_metadata_saas = (
session.query(ConversationMetadataSaas)
.filter(ConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
if not conversation_metadata:
if not conversation_metadata_saas:
raise ValueError(f'No conversation found {conversation_id}')
return conversation_metadata.user_id
return str(conversation_metadata_saas.user_id)
async def _get_runtime_status_from_nested_runtime(
self, session_api_key: Any | None, nested_url: str, conversation_id: str
@ -853,8 +854,11 @@ class SaasNestedConversationManager(ConversationManager):
with session_maker() as session:
# Only include conversations updated in the past week
one_week_ago = datetime.now(UTC) - timedelta(days=7)
query = session.query(StoredConversationMetadata.conversation_id).filter(
StoredConversationMetadata.user_id == user_id,
query = session.query(StoredConversationMetadata.conversation_id).join(
ConversationMetadataSaas,
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
).filter(
ConversationMetadataSaas.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
)
user_conversation_ids = set(query)

View File

@ -3,6 +3,7 @@ from storage.auth_tokens import AuthTokens
from storage.billing_session import BillingSession
from storage.billing_session_type import BillingSessionType
from storage.conversation_callback import CallbackStatus, ConversationCallback
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.conversation_work import ConversationWork
from storage.experiment_assignment import ExperimentAssignment
from storage.feedback import ConversationFeedback, Feedback
@ -45,6 +46,7 @@ __all__ = [
'CallbackStatus',
'ConversationCallback',
'ConversationFeedback',
'ConversationMetadataSaas',
'ConversationWork',
'ExperimentAssignment',
'Feedback',

View File

@ -0,0 +1,29 @@
"""
SQLAlchemy model for ConversationMetadataSaas.
This model stores the SaaS-specific metadata for conversations,
containing only the conversation_id, user_id, and org_id.
"""
from uuid import UUID
from sqlalchemy import UUID as SQL_UUID, Column, ForeignKey, String
from sqlalchemy.orm import relationship
from storage.base import Base
class ConversationMetadataSaas(Base): # type: ignore
"""SaaS conversation metadata model containing user and org associations."""
__tablename__ = 'conversation_metadata_saas'
conversation_id = Column(String, primary_key=True)
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
# Relationships
user = relationship('User', back_populates='conversation_metadata_saas')
org = relationship('Org', back_populates='conversation_metadata_saas')
__all__ = ['ConversationMetadataSaas']

View File

@ -51,6 +51,7 @@ class Org(Base): # type: ignore
conversation_metadata = relationship(
'StoredConversationMetadata', back_populates='org'
)
conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='org')
user_secrets = relationship('StoredUserSecrets', back_populates='org')
api_keys = relationship('ApiKey', back_populates='org')
slack_conversations = relationship('SlackConversation', back_populates='org')

View File

@ -7,6 +7,7 @@ from datetime import UTC
from uuid import UUID
from sqlalchemy.orm import sessionmaker
from storage.conversation_metadata_saas import ConversationMetadataSaas
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.user_store import UserStore
@ -43,10 +44,15 @@ class SaasConversationStore(ConversationStore):
self.org_id = user.current_org_id
def _select_by_id(self, session, conversation_id: str):
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
return (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.user_id == self.user_id)
.filter(StoredConversationMetadata.org_id == self.org_id)
.join(
ConversationMetadataSaas,
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
)
.filter(ConversationMetadataSaas.user_id == self.user_id)
.filter(ConversationMetadataSaas.org_id == self.org_id)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
)
@ -54,8 +60,6 @@ class SaasConversationStore(ConversationStore):
kwargs = {
c.name: getattr(conversation_metadata, c.name)
for c in StoredConversationMetadata.__table__.columns
if c.name
not in ['github_user_id', 'org_id'] # Skip github_user_id and org_id fields
}
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
@ -78,8 +82,10 @@ class SaasConversationStore(ConversationStore):
async def save_metadata(self, metadata: ConversationMetadata):
kwargs = dataclasses.asdict(metadata)
kwargs['user_id'] = self.user_id
kwargs['org_id'] = self.org_id
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
kwargs.pop('user_id', None)
kwargs.pop('org_id', None)
# Convert ProviderType enum to string for storage
if kwargs.get('git_provider') is not None:
@ -93,7 +99,26 @@ class SaasConversationStore(ConversationStore):
def _save_metadata():
with self.session_maker() as session:
# Save the main conversation metadata
session.merge(stored_metadata)
# Create or update the SaaS metadata record
saas_metadata = session.query(ConversationMetadataSaas).filter(
ConversationMetadataSaas.conversation_id == stored_metadata.conversation_id
).first()
if not saas_metadata:
saas_metadata = ConversationMetadataSaas(
conversation_id=stored_metadata.conversation_id,
user_id=self.user_id,
org_id=self.org_id
)
session.add(saas_metadata)
else:
# Update existing record
saas_metadata.user_id = self.user_id
saas_metadata.org_id = self.org_id
session.commit()
await call_sync_from_async(_save_metadata)
@ -113,7 +138,16 @@ class SaasConversationStore(ConversationStore):
async def delete_metadata(self, conversation_id: str) -> None:
def _delete_metadata():
with self.session_maker() as session:
# Delete the main conversation metadata
self._select_by_id(session, conversation_id).delete()
# Delete the SaaS metadata record
session.query(ConversationMetadataSaas).filter(
ConversationMetadataSaas.conversation_id == conversation_id,
ConversationMetadataSaas.user_id == self.user_id,
ConversationMetadataSaas.org_id == self.org_id
).delete()
session.commit()
await call_sync_from_async(_delete_metadata)
@ -137,8 +171,12 @@ class SaasConversationStore(ConversationStore):
with self.session_maker() as session:
conversations = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.user_id == self.user_id)
.filter(StoredConversationMetadata.org_id == self.org_id)
.join(
ConversationMetadataSaas,
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
)
.filter(ConversationMetadataSaas.user_id == self.user_id)
.filter(ConversationMetadataSaas.org_id == self.org_id)
.order_by(StoredConversationMetadata.created_at.desc())
.offset(offset)
.limit(limit + 1)

View File

@ -36,3 +36,4 @@ class User(Base): # type: ignore
role = relationship('Role', back_populates='users')
org_members = relationship('OrgMember', back_populates='user')
current_org = relationship('Org', back_populates='current_users')
conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='user')

View File

@ -25,14 +25,10 @@ from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import (
UUID as SQL_UUID,
)
from sqlalchemy import (
Column,
DateTime,
Float,
ForeignKey,
Integer,
Select,
String,
@ -40,7 +36,6 @@ from sqlalchemy import (
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import relationship
from openhands.agent_server.utils import utc_now
from openhands.app_server.app_conversation.app_conversation_info_service import (
@ -71,9 +66,6 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_id = Column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
github_user_id = Column(String, nullable=True) # The GitHub user ID
user_id = Column(String, nullable=False) # The Keycloak User ID
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
selected_repository = Column(String, nullable=True)
selected_branch = Column(String, nullable=True)
git_provider = Column(
@ -104,9 +96,6 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_version = Column(String, nullable=False, default='V0', index=True)
sandbox_id = Column(String, nullable=True, index=True)
# Relationship back to org
org = relationship('Org', back_populates='conversation_metadata')
@dataclass
class SQLAppConversationInfoService(AppConversationInfoService):
@ -196,9 +185,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
) -> int:
"""Count sandboxed conversations matching the given filters."""
query = select(func.count(StoredConversationMetadata.conversation_id))
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(StoredConversationMetadata.user_id == user_id)
query = self._apply_filters(
query=query,
@ -288,22 +274,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
async def save_app_conversation_info(
self, info: AppConversationInfo
) -> AppConversationInfo:
user_id = await self.user_context.get_user_id()
if user_id:
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_id == info.id
)
result = await self.db_session.execute(query)
existing = result.scalar_one_or_none()
assert existing is None or existing.created_by_user_id == user_id
metrics = info.metrics or MetricsSnapshot()
usage = metrics.accumulated_token_usage or TokenUsage()
stored = StoredConversationMetadata(
conversation_id=str(info.id),
github_user_id=None, # TODO: Should we add this to the conversation info?
user_id=info.created_by_user_id or '',
selected_repository=info.selected_repository,
selected_branch=info.selected_branch,
git_provider=info.git_provider.value if info.git_provider else None,
@ -335,11 +310,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(
StoredConversationMetadata.user_id == user_id,
)
return query
def _to_info(self, stored: StoredConversationMetadata) -> AppConversationInfo:
@ -370,7 +340,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
return AppConversationInfo(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,