mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Separate SaaS-specific fields from StoredConversationMetadata
- Create new ConversationMetadataSaas model with conversation_id, user_id, org_id - Remove github_user_id, user_id, org_id from StoredConversationMetadata - Update all enterprise clients to use ConversationMetadataSaas for user/org lookups - Add database migration to create new table and migrate existing data - Maintain backward compatibility in OpenHands core components Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
f89e41ac30
commit
4646439108
@ -0,0 +1,81 @@
|
||||
"""separate saas fields from conversation metadata
|
||||
|
||||
Revision ID: 081
|
||||
Revises: 080
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '081'
|
||||
down_revision: Union[str, None] = '080'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# Migrate existing data from conversation_metadata to conversation_metadata_saas
|
||||
# First, we need to handle the case where user_id might be a string that needs to be converted to UUID
|
||||
op.execute("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
CASE
|
||||
WHEN user_id ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
|
||||
THEN user_id::uuid
|
||||
ELSE gen_random_uuid() -- Generate a new UUID for invalid user_id values
|
||||
END as user_id,
|
||||
COALESCE(org_id, gen_random_uuid()) as org_id -- Use existing org_id or generate new one
|
||||
FROM conversation_metadata
|
||||
WHERE user_id IS NOT NULL
|
||||
""")
|
||||
|
||||
# Remove columns from conversation_metadata table
|
||||
op.drop_constraint('conversation_metadata_org_fkey', 'conversation_metadata', type_='foreignkey')
|
||||
op.drop_column('conversation_metadata', 'github_user_id')
|
||||
op.drop_column('conversation_metadata', 'user_id')
|
||||
op.drop_column('conversation_metadata', 'org_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add columns back to conversation_metadata table
|
||||
op.add_column('conversation_metadata', sa.Column('github_user_id', sa.String(), nullable=True))
|
||||
op.add_column('conversation_metadata', sa.Column('user_id', sa.String(), nullable=False))
|
||||
op.add_column('conversation_metadata', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True))
|
||||
|
||||
# Recreate foreign key constraint
|
||||
op.create_foreign_key(
|
||||
'conversation_metadata_org_fkey',
|
||||
'conversation_metadata',
|
||||
'org',
|
||||
['org_id'],
|
||||
['id'],
|
||||
)
|
||||
|
||||
# Migrate data back from conversation_metadata_saas to conversation_metadata
|
||||
op.execute("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = cms.user_id::text, org_id = cms.org_id
|
||||
FROM conversation_metadata_saas cms
|
||||
WHERE conversation_metadata.conversation_id = cms.conversation_id
|
||||
""")
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
@ -525,16 +526,16 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(ConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
ConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id) if conversation_metadata_saas else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_metadata,
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
@ -226,12 +227,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(ConversationMetadataSaas)
|
||||
.filter(ConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
@ -33,10 +34,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
session.query(ConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
ConversationMetadataSaas.conversation_id == conversation_id,
|
||||
ConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ from server.utils.conversation_callback_utils import (
|
||||
from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
@ -522,16 +523,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(ConversationMetadataSaas)
|
||||
.filter(ConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata:
|
||||
if not conversation_metadata_saas:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@ -853,8 +854,11 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
query = session.query(StoredConversationMetadata.conversation_id).join(
|
||||
ConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
|
||||
).filter(
|
||||
ConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
|
||||
@ -3,6 +3,7 @@ from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
@ -45,6 +46,7 @@ __all__ = [
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'ConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
|
||||
29
enterprise/storage/conversation_metadata_saas.py
Normal file
29
enterprise/storage/conversation_metadata_saas.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID, Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class ConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['ConversationMetadataSaas']
|
||||
@ -51,6 +51,7 @@ class Org(Base): # type: ignore
|
||||
conversation_metadata = relationship(
|
||||
'StoredConversationMetadata', back_populates='org'
|
||||
)
|
||||
conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='org')
|
||||
user_secrets = relationship('StoredUserSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
|
||||
@ -7,6 +7,7 @@ from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.conversation_metadata_saas import ConversationMetadataSaas
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.user_store import UserStore
|
||||
@ -43,10 +44,15 @@ class SaasConversationStore(ConversationStore):
|
||||
self.org_id = user.current_org_id
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
return (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.org_id == self.org_id)
|
||||
.join(
|
||||
ConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
|
||||
)
|
||||
.filter(ConversationMetadataSaas.user_id == self.user_id)
|
||||
.filter(ConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
)
|
||||
|
||||
@ -54,8 +60,6 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name
|
||||
not in ['github_user_id', 'org_id'] # Skip github_user_id and org_id fields
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@ -78,8 +82,10 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
kwargs['user_id'] = self.user_id
|
||||
kwargs['org_id'] = self.org_id
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@ -93,7 +99,26 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = session.query(ConversationMetadataSaas).filter(
|
||||
ConversationMetadataSaas.conversation_id == stored_metadata.conversation_id
|
||||
).first()
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = ConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=self.user_id,
|
||||
org_id=self.org_id
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Update existing record
|
||||
saas_metadata.user_id = self.user_id
|
||||
saas_metadata.org_id = self.org_id
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@ -113,7 +138,16 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Delete the main conversation metadata
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
|
||||
# Delete the SaaS metadata record
|
||||
session.query(ConversationMetadataSaas).filter(
|
||||
ConversationMetadataSaas.conversation_id == conversation_id,
|
||||
ConversationMetadataSaas.user_id == self.user_id,
|
||||
ConversationMetadataSaas.org_id == self.org_id
|
||||
).delete()
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
@ -137,8 +171,12 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.org_id == self.org_id)
|
||||
.join(
|
||||
ConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id == ConversationMetadataSaas.conversation_id
|
||||
)
|
||||
.filter(ConversationMetadataSaas.user_id == self.user_id)
|
||||
.filter(ConversationMetadataSaas.org_id == self.org_id)
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit + 1)
|
||||
|
||||
@ -36,3 +36,4 @@ class User(Base): # type: ignore
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
conversation_metadata_saas = relationship('ConversationMetadataSaas', back_populates='user')
|
||||
|
||||
@ -25,14 +25,10 @@ from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import (
|
||||
UUID as SQL_UUID,
|
||||
)
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Select,
|
||||
String,
|
||||
@ -40,7 +36,6 @@ from sqlalchemy import (
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from openhands.agent_server.utils import utc_now
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
@ -71,9 +66,6 @@ class StoredConversationMetadata(Base): # type: ignore
|
||||
conversation_id = Column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
github_user_id = Column(String, nullable=True) # The GitHub user ID
|
||||
user_id = Column(String, nullable=False) # The Keycloak User ID
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
selected_repository = Column(String, nullable=True)
|
||||
selected_branch = Column(String, nullable=True)
|
||||
git_provider = Column(
|
||||
@ -104,9 +96,6 @@ class StoredConversationMetadata(Base): # type: ignore
|
||||
conversation_version = Column(String, nullable=False, default='V0', index=True)
|
||||
sandbox_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Relationship back to org
|
||||
org = relationship('Org', back_populates='conversation_metadata')
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
@ -196,9 +185,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
) -> int:
|
||||
"""Count sandboxed conversations matching the given filters."""
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(StoredConversationMetadata.user_id == user_id)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
@ -288,22 +274,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_id == info.id
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
existing = result.scalar_one_or_none()
|
||||
assert existing is None or existing.created_by_user_id == user_id
|
||||
|
||||
metrics = info.metrics or MetricsSnapshot()
|
||||
usage = metrics.accumulated_token_usage or TokenUsage()
|
||||
|
||||
stored = StoredConversationMetadata(
|
||||
conversation_id=str(info.id),
|
||||
github_user_id=None, # TODO: Should we add this to the conversation info?
|
||||
user_id=info.created_by_user_id or '',
|
||||
selected_repository=info.selected_repository,
|
||||
selected_branch=info.selected_branch,
|
||||
git_provider=info.git_provider.value if info.git_provider else None,
|
||||
@ -335,11 +310,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
return query
|
||||
|
||||
def _to_info(self, stored: StoredConversationMetadata) -> AppConversationInfo:
|
||||
@ -370,7 +340,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
return AppConversationInfo(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user