mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Fix LiteLLM key management and user migration SQL queries (#13279)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -137,11 +137,23 @@ class LiteLlmManager:
|
|||||||
client, keycloak_user_id, org_id, team_budget
|
client, keycloak_user_id, org_id, team_budget
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We delete the key if it already exists. In environments where multiple
|
||||||
|
# installations are using the same keycloak and litellm instance, this
|
||||||
|
# will mean other installations will have their key invalidated.
|
||||||
|
key_alias = get_openhands_cloud_key_alias(keycloak_user_id, org_id)
|
||||||
|
try:
|
||||||
|
await LiteLlmManager._delete_key_by_alias(client, key_alias)
|
||||||
|
except httpx.HTTPStatusError as ex:
|
||||||
|
if ex.status_code == 404:
|
||||||
|
logger.debug(f'Key "{key_alias}" did not exist - continuing')
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
key = await LiteLlmManager._generate_key(
|
key = await LiteLlmManager._generate_key(
|
||||||
client,
|
client,
|
||||||
keycloak_user_id,
|
keycloak_user_id,
|
||||||
org_id,
|
org_id,
|
||||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
key_alias,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -295,29 +295,28 @@ class UserStore:
|
|||||||
extra={'user_id': user_id},
|
extra={'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_uuid = uuid.UUID(user_id)
|
||||||
|
|
||||||
# need to migrate conversation metadata
|
# need to migrate conversation metadata
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text("""
|
text("""
|
||||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||||
SELECT
|
SELECT
|
||||||
conversation_id,
|
conversation_id,
|
||||||
:user_id,
|
:user_uuid,
|
||||||
:user_id
|
:user_uuid
|
||||||
FROM conversation_metadata
|
FROM conversation_metadata
|
||||||
WHERE user_id = :user_id
|
WHERE user_id = :user_id_text
|
||||||
"""),
|
"""),
|
||||||
{'user_id': user_id},
|
{'user_uuid': user_uuid, 'user_id_text': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update org_id for tables that had org_id added
|
|
||||||
user_uuid = uuid.UUID(user_id)
|
|
||||||
|
|
||||||
# Update stripe_customers
|
# Update stripe_customers
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
),
|
),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update slack_users
|
# Update slack_users
|
||||||
@@ -325,7 +324,7 @@ class UserStore:
|
|||||||
text(
|
text(
|
||||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
),
|
),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update slack_conversation
|
# Update slack_conversation
|
||||||
@@ -333,13 +332,13 @@ class UserStore:
|
|||||||
text(
|
text(
|
||||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
),
|
),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update api_keys
|
# Update api_keys
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update custom_secrets
|
# Update custom_secrets
|
||||||
@@ -347,7 +346,7 @@ class UserStore:
|
|||||||
text(
|
text(
|
||||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
),
|
),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update billing_sessions
|
# Update billing_sessions
|
||||||
@@ -355,7 +354,7 @@ class UserStore:
|
|||||||
text(
|
text(
|
||||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||||
),
|
),
|
||||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
{'org_id': user_uuid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -180,11 +180,11 @@ class TestLiteLlmManager:
|
|||||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||||
assert result.llm_base_url == 'http://test.com'
|
assert result.llm_base_url == 'http://test.com'
|
||||||
|
|
||||||
# Verify API calls were made (get_team + 3 posts)
|
# Verify API calls were made (get_team + 4 posts)
|
||||||
assert mock_client.get.call_count == 1 # get_team
|
assert mock_client.get.call_count == 1 # get_team
|
||||||
assert (
|
assert (
|
||||||
mock_client.post.call_count == 3
|
mock_client.post.call_count == 4
|
||||||
) # create_team, add_user_to_team, generate_key
|
) # create_team, add_user_to_team, delete_key_by_alias, generate_key
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_entries_inherits_existing_team_budget(
|
async def test_create_entries_inherits_existing_team_budget(
|
||||||
|
|||||||
@@ -833,3 +833,413 @@ async def test_release_user_creation_lock_released():
|
|||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_redis.delete.assert_called_once()
|
mock_redis.delete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tests for migrate_user SQL parameter type handling ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_migrate_user_sql_type_handling(async_session_maker):
|
||||||
|
"""Test that migrate_user correctly handles UUID vs string types in SQL queries.
|
||||||
|
|
||||||
|
This test verifies the fixes for SQL parameter binding issues in _migrate_personal_data
|
||||||
|
where UUID and string parameters need to be correctly matched to their column types.
|
||||||
|
|
||||||
|
Note: SQLite doesn't natively support UUID types, so we use string representations.
|
||||||
|
The key verification is that:
|
||||||
|
1. String user_ids in WHERE clauses match source tables correctly
|
||||||
|
2. UUID values are inserted into target UUID columns correctly
|
||||||
|
3. The migration queries don't fail due to type mismatches
|
||||||
|
"""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
user_uuid = uuid.UUID(user_id)
|
||||||
|
# For SQLite raw SQL, use string representation of UUID
|
||||||
|
user_uuid_str = str(user_uuid)
|
||||||
|
|
||||||
|
# Set up legacy data with string user_ids (as in the old schema)
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
# First, add conversation_metadata with user_id as string column
|
||||||
|
# The current model doesn't have user_id, but the real DB did before migration
|
||||||
|
# We use raw SQL to add the column and insert test data
|
||||||
|
await session.execute(
|
||||||
|
text('ALTER TABLE conversation_metadata ADD COLUMN user_id VARCHAR')
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata (conversation_id, user_id, conversation_version, created_at, last_updated_at)
|
||||||
|
VALUES (:conv_id, :user_id, 'V0', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'conv_id': 'test-conv-1', 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create org first (needed for foreign keys)
|
||||||
|
org = Org(id=user_uuid, name=f'user_{user_id}_org')
|
||||||
|
session.add(org)
|
||||||
|
|
||||||
|
# Create user (needed for foreign keys)
|
||||||
|
user = User(id=user_uuid, current_org_id=user_uuid)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Add stripe_customers with keycloak_user_id as string
|
||||||
|
from storage.stripe_customer import StripeCustomer
|
||||||
|
|
||||||
|
stripe_customer = StripeCustomer(
|
||||||
|
keycloak_user_id=user_id, stripe_customer_id='stripe_123'
|
||||||
|
)
|
||||||
|
session.add(stripe_customer)
|
||||||
|
|
||||||
|
# Add slack_users with keycloak_user_id as string
|
||||||
|
from storage.slack_user import SlackUser
|
||||||
|
|
||||||
|
slack_user = SlackUser(
|
||||||
|
keycloak_user_id=user_id,
|
||||||
|
slack_user_id='slack_user_123',
|
||||||
|
slack_display_name='Test User',
|
||||||
|
)
|
||||||
|
session.add(slack_user)
|
||||||
|
|
||||||
|
# Add slack_conversation with keycloak_user_id as string
|
||||||
|
from storage.slack_conversation import SlackConversation
|
||||||
|
|
||||||
|
slack_conv = SlackConversation(
|
||||||
|
conversation_id='slack-conv-1',
|
||||||
|
channel_id='channel_123',
|
||||||
|
keycloak_user_id=user_id,
|
||||||
|
)
|
||||||
|
session.add(slack_conv)
|
||||||
|
|
||||||
|
# Add api_keys with user_id as string
|
||||||
|
from storage.api_key import ApiKey
|
||||||
|
|
||||||
|
api_key = ApiKey(key='api_key_123', user_id=user_id, name='Test API Key')
|
||||||
|
session.add(api_key)
|
||||||
|
|
||||||
|
# Add custom_secrets with keycloak_user_id as string
|
||||||
|
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||||
|
|
||||||
|
custom_secret = StoredCustomSecrets(
|
||||||
|
keycloak_user_id=user_id,
|
||||||
|
secret_name='test_secret',
|
||||||
|
secret_value='secret_value',
|
||||||
|
)
|
||||||
|
session.add(custom_secret)
|
||||||
|
|
||||||
|
# Add billing_sessions with user_id as string
|
||||||
|
from storage.billing_session import BillingSession
|
||||||
|
|
||||||
|
billing_session = BillingSession(
|
||||||
|
id='billing-session-1',
|
||||||
|
user_id=user_id,
|
||||||
|
status='completed',
|
||||||
|
price=10,
|
||||||
|
price_code='USD',
|
||||||
|
)
|
||||||
|
session.add(billing_session)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Now execute the migration SQL statements with the correct parameter types
|
||||||
|
# This tests the fix: using user_uuid for UUID columns and user_id for string columns
|
||||||
|
# Note: For SQLite, we use string representation of UUID
|
||||||
|
|
||||||
|
# Test 1: conversation_metadata to conversation_metadata_saas migration
|
||||||
|
# The fix uses user_uuid (UUID) for inserting into user_id/org_id (UUID columns)
|
||||||
|
# and user_id_text (string) for comparing with user_id in conversation_metadata (string column)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||||
|
SELECT
|
||||||
|
conversation_id,
|
||||||
|
:user_uuid,
|
||||||
|
:user_uuid
|
||||||
|
FROM conversation_metadata
|
||||||
|
WHERE user_id = :user_id_text
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'user_uuid': user_uuid_str, 'user_id_text': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 2: Update stripe_customers - org_id is UUID, keycloak_user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
|
),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 3: Update slack_users - org_id is UUID, keycloak_user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
|
),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 4: Update slack_conversation - org_id is UUID, keycloak_user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
|
),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 5: Update api_keys - org_id is UUID, user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 6: Update custom_secrets - org_id is UUID, keycloak_user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||||
|
),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 7: Update billing_sessions - org_id is UUID, user_id is string
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||||
|
),
|
||||||
|
{'org_id': user_uuid_str, 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify the data was migrated correctly
|
||||||
|
from storage.stored_conversation_metadata_saas import (
|
||||||
|
StoredConversationMetadataSaas,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify conversation_metadata_saas
|
||||||
|
result = await session.execute(
|
||||||
|
select(StoredConversationMetadataSaas).filter(
|
||||||
|
StoredConversationMetadataSaas.conversation_id == 'test-conv-1'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
saas_metadata = result.scalars().first()
|
||||||
|
assert (
|
||||||
|
saas_metadata is not None
|
||||||
|
), 'conversation_metadata_saas record should exist'
|
||||||
|
assert saas_metadata.user_id == user_uuid, 'user_id should be UUID type'
|
||||||
|
assert saas_metadata.org_id == user_uuid, 'org_id should be UUID type'
|
||||||
|
|
||||||
|
# Verify stripe_customers org_id was set
|
||||||
|
result = await session.execute(
|
||||||
|
select(StripeCustomer).filter(StripeCustomer.keycloak_user_id == user_id)
|
||||||
|
)
|
||||||
|
stripe_record = result.scalars().first()
|
||||||
|
assert stripe_record is not None
|
||||||
|
assert (
|
||||||
|
stripe_record.org_id == user_uuid
|
||||||
|
), 'stripe_customers.org_id should be UUID'
|
||||||
|
|
||||||
|
# Verify slack_users org_id was set
|
||||||
|
result = await session.execute(
|
||||||
|
select(SlackUser).filter(SlackUser.keycloak_user_id == user_id)
|
||||||
|
)
|
||||||
|
slack_user_record = result.scalars().first()
|
||||||
|
assert slack_user_record is not None
|
||||||
|
assert (
|
||||||
|
slack_user_record.org_id == user_uuid
|
||||||
|
), 'slack_users.org_id should be UUID'
|
||||||
|
|
||||||
|
# Verify slack_conversation org_id was set
|
||||||
|
result = await session.execute(
|
||||||
|
select(SlackConversation).filter(
|
||||||
|
SlackConversation.keycloak_user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
slack_conv_record = result.scalars().first()
|
||||||
|
assert slack_conv_record is not None
|
||||||
|
assert (
|
||||||
|
slack_conv_record.org_id == user_uuid
|
||||||
|
), 'slack_conversation.org_id should be UUID'
|
||||||
|
|
||||||
|
# Verify api_keys org_id was set
|
||||||
|
result = await session.execute(select(ApiKey).filter(ApiKey.user_id == user_id))
|
||||||
|
api_key_record = result.scalars().first()
|
||||||
|
assert api_key_record is not None
|
||||||
|
assert api_key_record.org_id == user_uuid, 'api_keys.org_id should be UUID'
|
||||||
|
|
||||||
|
# Verify custom_secrets org_id was set
|
||||||
|
result = await session.execute(
|
||||||
|
select(StoredCustomSecrets).filter(
|
||||||
|
StoredCustomSecrets.keycloak_user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
custom_secret_record = result.scalars().first()
|
||||||
|
assert custom_secret_record is not None
|
||||||
|
assert (
|
||||||
|
custom_secret_record.org_id == user_uuid
|
||||||
|
), 'custom_secrets.org_id should be UUID'
|
||||||
|
|
||||||
|
# Verify billing_sessions org_id was set
|
||||||
|
result = await session.execute(
|
||||||
|
select(BillingSession).filter(BillingSession.user_id == user_id)
|
||||||
|
)
|
||||||
|
billing_record = result.scalars().first()
|
||||||
|
assert billing_record is not None
|
||||||
|
assert (
|
||||||
|
billing_record.org_id == user_uuid
|
||||||
|
), 'billing_sessions.org_id should be UUID'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_migrate_user_sql_no_matching_records(async_session_maker):
|
||||||
|
"""Test that migration SQL handles the case where no records match the user_id.
|
||||||
|
|
||||||
|
This verifies that the SQL queries don't fail when there are no matching records.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
user_uuid = uuid.UUID(user_id)
|
||||||
|
user_uuid_str = str(user_uuid)
|
||||||
|
other_user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Set up data for a different user
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
# Add conversation_metadata with user_id column for a different user
|
||||||
|
await session.execute(
|
||||||
|
text('ALTER TABLE conversation_metadata ADD COLUMN user_id VARCHAR')
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata (conversation_id, user_id, conversation_version, created_at, last_updated_at)
|
||||||
|
VALUES (:conv_id, :user_id, 'V0', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'conv_id': 'other-conv-1', 'user_id': other_user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create org and user for our test user
|
||||||
|
org = Org(id=user_uuid, name=f'user_{user_id}_org')
|
||||||
|
session.add(org)
|
||||||
|
user = User(id=user_uuid, current_org_id=user_uuid)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Execute migration SQL for our user (no data should match)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||||
|
SELECT
|
||||||
|
conversation_id,
|
||||||
|
:user_uuid,
|
||||||
|
:user_uuid
|
||||||
|
FROM conversation_metadata
|
||||||
|
WHERE user_id = :user_id_text
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'user_uuid': user_uuid_str, 'user_id_text': user_id},
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify no records were created for our user
|
||||||
|
from storage.stored_conversation_metadata_saas import (
|
||||||
|
StoredConversationMetadataSaas,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(StoredConversationMetadataSaas).filter(
|
||||||
|
StoredConversationMetadataSaas.user_id == user_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
records = result.scalars().all()
|
||||||
|
assert (
|
||||||
|
len(records) == 0
|
||||||
|
), 'No records should be created for non-matching user_id'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_migrate_user_sql_multiple_conversations(async_session_maker):
|
||||||
|
"""Test that migration SQL correctly handles multiple conversations for a user."""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
user_uuid = uuid.UUID(user_id)
|
||||||
|
user_uuid_str = str(user_uuid)
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
# Create org and user FIRST (needed for foreign keys)
|
||||||
|
org = Org(id=user_uuid, name=f'user_{user_id}_org')
|
||||||
|
session.add(org)
|
||||||
|
user = User(id=user_uuid, current_org_id=user_uuid)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Add conversation_metadata with user_id column
|
||||||
|
await session.execute(
|
||||||
|
text('ALTER TABLE conversation_metadata ADD COLUMN user_id VARCHAR')
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Insert multiple conversations for the same user
|
||||||
|
for i in range(3):
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata (conversation_id, user_id, conversation_version, created_at, last_updated_at)
|
||||||
|
VALUES (:conv_id, :user_id, 'V0', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'conv_id': f'test-conv-{i}', 'user_id': user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify that conversation_metadata was inserted
|
||||||
|
result = await session.execute(
|
||||||
|
text('SELECT conversation_id, user_id FROM conversation_metadata')
|
||||||
|
)
|
||||||
|
conv_rows = result.fetchall()
|
||||||
|
assert (
|
||||||
|
len(conv_rows) == 3
|
||||||
|
), f'Expected 3 conversation_metadata rows, got {len(conv_rows)}'
|
||||||
|
|
||||||
|
# Execute migration SQL
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||||
|
SELECT
|
||||||
|
conversation_id,
|
||||||
|
:user_uuid,
|
||||||
|
:user_uuid
|
||||||
|
FROM conversation_metadata
|
||||||
|
WHERE user_id = :user_id_text
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{'user_uuid': user_uuid_str, 'user_id_text': user_id},
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify all conversations were migrated using raw SQL
|
||||||
|
# (SQLite stores UUIDs as strings, ORM comparison may differ)
|
||||||
|
result = await session.execute(
|
||||||
|
text(
|
||||||
|
'SELECT conversation_id, user_id, org_id FROM conversation_metadata_saas WHERE user_id = :user_uuid'
|
||||||
|
),
|
||||||
|
{'user_uuid': user_uuid_str},
|
||||||
|
)
|
||||||
|
saas_rows = result.fetchall()
|
||||||
|
assert len(saas_rows) == 3, 'All 3 conversations should be migrated'
|
||||||
|
|
||||||
|
# Verify the user_id and org_id values
|
||||||
|
for row in saas_rows:
|
||||||
|
assert (
|
||||||
|
row.user_id == user_uuid_str
|
||||||
|
), f'user_id should match: {row.user_id} vs {user_uuid_str}'
|
||||||
|
assert (
|
||||||
|
row.org_id == user_uuid_str
|
||||||
|
), f'org_id should match: {row.org_id} vs {user_uuid_str}'
|
||||||
|
|||||||
Reference in New Issue
Block a user