mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor user_store.py to use async database sessions (#13187)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -17,7 +17,7 @@ from openhands.server.user_auth import get_user_id
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -36,7 +36,7 @@ async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -55,7 +55,7 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
try:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
current_org_id = str(user.current_org_id)
|
||||
@@ -98,7 +98,7 @@ async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""
|
||||
try:
|
||||
# Get user to construct the key alias
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
key_alias = None
|
||||
if user and user.current_org_id:
|
||||
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'
|
||||
|
||||
@@ -204,7 +204,7 @@ async def keycloak_callback(
|
||||
email = user_info.email
|
||||
user_id = user_info.sub
|
||||
user_info_dict = user_info.model_dump(exclude_none=True)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info_dict)
|
||||
else:
|
||||
|
||||
@@ -90,7 +90,7 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found')
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
@@ -248,7 +248,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
user = await UserStore.get_user_by_id_async(billing_session.user_id)
|
||||
user = await UserStore.get_user_by_id(billing_session.user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found')
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
|
||||
@@ -197,7 +197,7 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info.sub
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
user = await UserStore.get_user_by_id(keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
|
||||
@@ -99,7 +99,7 @@ async def list_user_orgs(
|
||||
|
||||
try:
|
||||
# Fetch user to get current_org_id
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
current_org_id = (
|
||||
str(user.current_org_id) if user and user.current_org_id else None
|
||||
)
|
||||
|
||||
@@ -115,7 +115,7 @@ async def saas_get_user(
|
||||
email = user_info.email
|
||||
sub = user_info.sub
|
||||
if sub:
|
||||
db_user = await UserStore.get_user_by_id_async(sub)
|
||||
db_user = await UserStore.get_user_by_id(sub)
|
||||
if db_user and db_user.email is not None:
|
||||
email = db_user.email
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class OrgInvitationService:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 5: Check if user is already a member (by email)
|
||||
existing_user = await UserStore.get_user_by_email_async(email)
|
||||
existing_user = await UserStore.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_member = await OrgMemberStore.get_org_member(
|
||||
org_id, existing_user.id
|
||||
@@ -127,7 +127,7 @@ class OrgInvitationService:
|
||||
# Step 7: Send invitation email
|
||||
try:
|
||||
# Get inviter info for the email
|
||||
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_user = await UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_name = 'A team member'
|
||||
if inviter_user and inviter_user.email:
|
||||
inviter_name = inviter_user.email.split('@')[0]
|
||||
@@ -308,7 +308,7 @@ class OrgInvitationService:
|
||||
raise InvitationExpiredError('Invitation has expired')
|
||||
|
||||
# Step 2.5: Verify user email matches invitation email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
user = await UserStore.get_user_by_id(str(user_id))
|
||||
if not user:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class OrgMemberService:
|
||||
raise RoleNotFoundError(org_member.role_id)
|
||||
|
||||
# Get user email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
user = await UserStore.get_user_by_id(str(user_id))
|
||||
email = user.email if user and user.email else ''
|
||||
|
||||
return MeResponse.from_org_member(org_member, role, email)
|
||||
@@ -218,10 +218,10 @@ class OrgMemberService:
|
||||
return False, 'removal_failed'
|
||||
|
||||
# Update user's current_org_id if it points to the org they were removed from
|
||||
user = await UserStore.get_user_by_id_async(str(target_user_id))
|
||||
user = await UserStore.get_user_by_id(str(target_user_id))
|
||||
if user and user.current_org_id == org_id:
|
||||
# Set current_org_id to personal workspace (org.id == user.id)
|
||||
UserStore.update_current_org(str(target_user_id), target_user_id)
|
||||
await UserStore.update_current_org(str(target_user_id), target_user_id)
|
||||
|
||||
# If database removal succeeded, also remove from LiteLLM team
|
||||
try:
|
||||
@@ -308,7 +308,7 @@ class OrgMemberService:
|
||||
|
||||
# If no role change requested, return current state
|
||||
if new_role_name is None:
|
||||
user = await UserStore.get_user_by_id_async(str(target_user_id))
|
||||
user = await UserStore.get_user_by_id(str(target_user_id))
|
||||
return OrgMemberResponse(
|
||||
user_id=str(target_membership.user_id),
|
||||
email=user.email if user else None,
|
||||
@@ -347,7 +347,7 @@ class OrgMemberService:
|
||||
raise MemberUpdateError('Failed to update member')
|
||||
|
||||
# Get user email for response
|
||||
user = await UserStore.get_user_by_id_async(str(target_user_id))
|
||||
user = await UserStore.get_user_by_id(str(target_user_id))
|
||||
|
||||
return OrgMemberResponse(
|
||||
user_id=str(updated_member.user_id),
|
||||
|
||||
@@ -37,7 +37,7 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
org_id = user.current_org_id
|
||||
@@ -117,7 +117,7 @@ class ApiKeyStore:
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
org_id = user.current_org_id
|
||||
@@ -132,7 +132,7 @@ class ApiKeyStore:
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
org_id = user.current_org_id
|
||||
|
||||
@@ -1171,7 +1171,7 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
user = await UserStore.get_user_by_id(keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
|
||||
@@ -875,7 +875,7 @@ class OrgService:
|
||||
Returns:
|
||||
bool: True if BYOR export is enabled, False otherwise
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
@@ -929,7 +929,7 @@ class OrgService:
|
||||
|
||||
# Step 3: Update user's current_org_id
|
||||
try:
|
||||
updated_user = UserStore.update_current_org(user_id, org_id)
|
||||
updated_user = await UserStore.update_current_org(user_id, org_id)
|
||||
if not updated_user:
|
||||
raise OrgDatabaseError('User not found')
|
||||
|
||||
|
||||
@@ -236,6 +236,6 @@ class SaasConversationStore(ConversationStore):
|
||||
# user_id should not be None in SaaS, should we raise?
|
||||
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
|
||||
# to dispatch to the main event loop where asyncpg connections work properly.
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(str(user_id), org_id, session_maker)
|
||||
|
||||
@@ -24,7 +24,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
user = await UserStore.get_user_by_id(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
@@ -52,7 +52,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
user = await UserStore.get_user_by_id(self.user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {self.user_id}')
|
||||
org_id = user.current_org_id
|
||||
|
||||
@@ -68,7 +68,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
return result.scalars().first()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
user = await UserStore.get_user_by_id(self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
@@ -16,8 +16,8 @@ from server.constants import (
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from sqlalchemy.orm import selectinload
|
||||
from storage.database import a_session_maker
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
decrypt_legacy_value,
|
||||
@@ -30,8 +30,6 @@ from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
@@ -50,7 +48,7 @@ class UserStore:
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
@@ -105,9 +103,9 @@ class UserStore:
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(user, ['org_members']) # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -176,19 +174,17 @@ class UserStore:
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Check if user has completed billing sessions to enable BYOR export
|
||||
from storage.billing_session import BillingSession
|
||||
|
||||
has_completed_billing = (
|
||||
session.query(BillingSession)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(BillingSession).filter(
|
||||
BillingSession.user_id == user_id,
|
||||
BillingSession.status == 'completed',
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
has_completed_billing = result.scalars().first() is not None
|
||||
|
||||
# create personal org
|
||||
org = Org(
|
||||
@@ -297,15 +293,15 @@ class UserStore:
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
await session.merge(user_settings)
|
||||
await session.flush()
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_flush_complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
@@ -322,7 +318,7 @@ class UserStore:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
@@ -330,7 +326,7 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
@@ -338,7 +334,7 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
@@ -346,13 +342,13 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
@@ -360,16 +356,16 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(user, ['org_members']) # load org_members
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_committed',
|
||||
extra={'user_id': user_id},
|
||||
@@ -410,14 +406,14 @@ class UserStore:
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Get the user and their org_member
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:user_not_found',
|
||||
@@ -426,7 +422,10 @@ class UserStore:
|
||||
return None
|
||||
|
||||
# Get the user's personal org (org_id == user_id)
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
result = await session.execute(
|
||||
select(Org).filter(Org.id == uuid.UUID(user_id))
|
||||
)
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:org_not_found',
|
||||
@@ -435,9 +434,10 @@ class UserStore:
|
||||
return None
|
||||
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter(OrgMember.org_id == org.id)
|
||||
)
|
||||
org_members = result.scalars().all()
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
@@ -453,14 +453,13 @@ class UserStore:
|
||||
org_member = org_members[0]
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
@@ -491,7 +490,7 @@ class UserStore:
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -531,7 +530,7 @@ class UserStore:
|
||||
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
# This ensures any conversations created after migration have their user_id
|
||||
# preserved in the original table before we delete the saas entries
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = :user_id
|
||||
@@ -545,14 +544,14 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Step 4: Delete conversation_metadata_saas entries
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 5: Reset org_id columns in related tables
|
||||
# Reset stripe_customers
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
@@ -560,13 +559,13 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Reset slack_users
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_conversation
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
@@ -574,19 +573,19 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Reset api_keys
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset custom_secrets
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset billing_sessions
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
@@ -594,19 +593,19 @@ class UserStore:
|
||||
)
|
||||
|
||||
# Step 6: Delete org_member entries for this org
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 7: Delete the user entry
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Delete the org entry
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM org WHERE id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
@@ -626,9 +625,9 @@ class UserStore:
|
||||
if value is not None and not _is_legacy_value_encrypted(value):
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
await session.merge(user_settings)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_store:downgrade_user:complete',
|
||||
@@ -637,88 +636,12 @@ class UserStore:
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
|
||||
Note: This method uses call_async_from_sync internally which creates a new
|
||||
event loop. If you're already in an async context, use get_user_by_id_async
|
||||
instead to avoid event loop conflicts.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not call_async_from_sync(
|
||||
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
|
||||
)
|
||||
|
||||
try:
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
finally:
|
||||
call_async_from_sync(
|
||||
UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id_async(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (async version).
|
||||
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
async def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.options(selectinload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
@@ -729,7 +652,7 @@ class UserStore:
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:waiting_for_lock',
|
||||
'user_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
@@ -738,17 +661,13 @@ class UserStore:
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.options(selectinload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
@@ -759,10 +678,6 @@ class UserStore:
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
@@ -775,8 +690,8 @@ class UserStore:
|
||||
await UserStore._release_user_creation_lock(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_email_async(email: str) -> Optional[User]:
|
||||
"""Get user by email address (async version).
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
"""Get user by email address.
|
||||
|
||||
This method looks up a user by their email address. Note that email
|
||||
addresses may not be unique across all users in rare cases.
|
||||
@@ -793,19 +708,20 @@ class UserStore:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.options(selectinload(User.org_members))
|
||||
.filter(User.email == email.lower().strip())
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
async def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(User))
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
async def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
"""Update the user's current organization.
|
||||
|
||||
Args:
|
||||
@@ -815,19 +731,17 @@ class UserStore:
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.with_for_update()
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id)).with_for_update()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.current_org_id = org_id
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -374,7 +374,7 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id')
|
||||
async def test_delete_constructs_alias_from_user(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
@@ -400,7 +400,7 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id')
|
||||
async def test_delete_without_user_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
@@ -421,7 +421,7 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id')
|
||||
async def test_delete_without_org_id_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
@@ -444,7 +444,7 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id')
|
||||
async def test_delete_returns_false_on_exception(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
|
||||
@@ -514,7 +514,7 @@ async def test_list_user_orgs_success(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -568,7 +568,7 @@ async def test_list_user_orgs_returns_current_org_id(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -613,7 +613,7 @@ async def test_list_user_orgs_with_pagination(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -648,7 +648,7 @@ async def test_list_user_orgs_empty(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -715,7 +715,7 @@ async def test_list_user_orgs_service_error(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -781,7 +781,7 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -820,7 +820,7 @@ async def test_list_user_orgs_team_org_identified(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -869,7 +869,7 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -941,7 +941,7 @@ async def test_list_user_orgs_all_fields_present(mock_app_list):
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.UserStore.get_user_by_id_async',
|
||||
'server.routes.orgs.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
|
||||
@@ -718,7 +718,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -767,7 +767,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -815,7 +815,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -967,7 +967,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
@@ -1149,7 +1149,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -1251,11 +1251,12 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.update_current_org'
|
||||
'server.services.org_member_service.UserStore.update_current_org',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_org,
|
||||
):
|
||||
mock_get_member.side_effect = [
|
||||
@@ -1274,7 +1275,9 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert error is None
|
||||
mock_update_org.assert_called_once_with(str(target_user_id), target_user_id)
|
||||
mock_update_org.assert_awaited_once_with(
|
||||
str(target_user_id), target_user_id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_member_does_not_update_current_org_id_when_not_matching(
|
||||
@@ -1307,11 +1310,12 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.update_current_org'
|
||||
'server.services.org_member_service.UserStore.update_current_org',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_org,
|
||||
):
|
||||
mock_get_member.side_effect = [
|
||||
@@ -1330,7 +1334,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert error is None
|
||||
mock_update_org.assert_not_called()
|
||||
mock_update_org.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_member_succeeds_when_user_not_found_after_removal(
|
||||
@@ -1359,11 +1363,12 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.update_current_org'
|
||||
'server.services.org_member_service.UserStore.update_current_org',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_org,
|
||||
):
|
||||
mock_get_member.side_effect = [
|
||||
@@ -1382,7 +1387,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert error is None
|
||||
mock_update_org.assert_not_called()
|
||||
mock_update_org.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_removal_calls_litellm_remove_user_from_team(
|
||||
@@ -1411,7 +1416,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
@@ -1465,7 +1470,7 @@ class TestOrgMemberServiceRemoveOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_remove,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
@@ -1632,7 +1637,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -1693,7 +1698,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -1752,7 +1757,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -1815,7 +1820,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch.object(
|
||||
@@ -2036,7 +2041,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_role,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -2253,7 +2258,7 @@ class TestOrgMemberServiceGetMe:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_role,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -2340,7 +2345,7 @@ class TestOrgMemberServiceGetMe:
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_role,
|
||||
patch(
|
||||
'server.services.org_member_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_member_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_generate_api_key(api_key_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_create_api_key(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
@@ -324,7 +324,7 @@ async def test_delete_api_key_by_id(api_key_store, async_session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_list_api_keys(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
@@ -377,7 +377,7 @@ async def test_list_api_keys(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_retrieve_mcp_api_key(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
@@ -416,7 +416,7 @@ async def test_retrieve_mcp_api_key(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
|
||||
@@ -158,7 +158,7 @@ async def test_keycloak_callback_user_not_allowed(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
@@ -197,7 +197,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
@@ -273,7 +273,7 @@ async def test_keycloak_callback_email_not_verified(
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -324,7 +324,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -368,7 +368,7 @@ async def test_keycloak_callback_success_without_offline_token(
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
@@ -616,7 +616,7 @@ async def test_keycloak_callback_blocked_email_domain(
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -683,7 +683,7 @@ async def test_keycloak_callback_allowed_email_domain(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -750,7 +750,7 @@ async def test_keycloak_callback_domain_blocking_inactive(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -813,7 +813,7 @@ async def test_keycloak_callback_missing_email(mock_request, create_keycloak_use
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -862,7 +862,7 @@ async def test_keycloak_callback_duplicate_email_detected(
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -910,7 +910,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -971,7 +971,7 @@ async def test_keycloak_callback_duplicate_check_exception(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1032,7 +1032,7 @@ async def test_keycloak_callback_no_duplicate_email(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1096,7 +1096,7 @@ async def test_keycloak_callback_no_email_in_user_info(
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1254,7 +1254,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1322,7 +1322,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1408,7 +1408,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1497,7 +1497,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1670,7 +1670,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1752,7 +1752,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1822,7 +1822,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1896,7 +1896,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -1970,7 +1970,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
@@ -2020,7 +2020,7 @@ async def test_keycloak_callback_calls_backfill_user_email_for_existing_user(
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
@@ -103,7 +103,7 @@ async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
@@ -135,7 +135,7 @@ async def test_get_credits_success():
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
@@ -338,7 +338,7 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
@@ -410,7 +410,7 @@ async def test_success_callback_lite_llm_error(
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
@@ -464,7 +464,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
|
||||
@@ -1100,9 +1100,7 @@ class TestLiteLlmManager:
|
||||
mock_org_member.org_id = 'test-ord-id'
|
||||
mock_org_member.llm_api_key = 'test-api-key'
|
||||
mock_user.org_members = [mock_org_member]
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(
|
||||
return_value=mock_user
|
||||
)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
@@ -1118,7 +1116,7 @@ class TestLiteLlmManager:
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=None)
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=None)
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
@@ -106,7 +106,7 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
@@ -174,7 +174,7 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
@@ -220,7 +220,7 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
|
||||
@@ -1870,7 +1870,7 @@ async def test_check_byor_export_enabled_returns_true_when_enabled():
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
'storage.org_service.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -1905,7 +1905,7 @@ async def test_check_byor_export_enabled_returns_false_when_disabled():
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
'storage.org_service.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -1932,7 +1932,7 @@ async def test_check_byor_export_enabled_returns_false_when_user_not_found():
|
||||
user_id = 'nonexistent-user'
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
'storage.org_service.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
# Act
|
||||
@@ -1956,7 +1956,7 @@ async def test_check_byor_export_enabled_returns_false_when_no_current_org():
|
||||
mock_user.current_org_id = None
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
'storage.org_service.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
):
|
||||
# Act
|
||||
@@ -1982,7 +1982,7 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found():
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
'storage.org_service.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
@@ -2025,6 +2025,7 @@ async def test_switch_org_success():
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.UserStore.update_current_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_updated_user,
|
||||
),
|
||||
):
|
||||
@@ -2116,7 +2117,11 @@ async def test_switch_org_user_not_found():
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch('storage.org_service.UserStore.update_current_org', return_value=None),
|
||||
patch(
|
||||
'storage.org_service.UserStore.update_current_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
|
||||
@@ -169,14 +169,14 @@ async def test_exists(session_maker):
|
||||
class TestGetInstance:
|
||||
"""Tests for SaasConversationStore.get_instance method.
|
||||
|
||||
The get_instance method uses async UserStore.get_user_by_id_async because
|
||||
The get_instance method uses async UserStore.get_user_by_id because
|
||||
callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main
|
||||
event loop where asyncpg connections work properly.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_uses_async_get_user_by_id(self):
|
||||
"""Verify get_instance calls the async get_user_by_id_async for proper event loop handling."""
|
||||
"""Verify get_instance calls the async get_user_by_id for proper event loop handling."""
|
||||
# Arrange
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
mock_user = MagicMock(spec=User)
|
||||
@@ -184,7 +184,7 @@ class TestGetInstance:
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=mock_user),
|
||||
) as mock_async_get_user, patch(
|
||||
'storage.saas_conversation_store.session_maker'
|
||||
@@ -205,7 +205,7 @@ class TestGetInstance:
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
AsyncMock(return_value=None),
|
||||
), patch('storage.saas_conversation_store.session_maker'):
|
||||
# Act
|
||||
|
||||
@@ -44,7 +44,7 @@ def secrets_store(async_session_maker, mock_config):
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
@@ -84,7 +84,7 @@ class TestSaasSecretsStore:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
@@ -186,7 +186,7 @@ class TestSaasSecretsStore:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_update_existing_secrets(
|
||||
|
||||
@@ -35,9 +35,9 @@ def mock_check_idp():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id_async to return None by default."""
|
||||
"""Mock UserStore.get_user_by_id to return None by default."""
|
||||
with patch(
|
||||
'server.routes.user.UserStore.get_user_by_id_async',
|
||||
'server.routes.user.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
) as mock_fn:
|
||||
|
||||
@@ -101,11 +101,11 @@ async def test_create_default_settings_with_litellm(mock_litellm_api):
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
# --- Tests for get_user_by_id_async ---
|
||||
# --- Tests for get_user_by_id ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id_async_existing_user(async_session_maker):
|
||||
async def test_get_user_by_id_existing_user(async_session_maker):
|
||||
"""Test retrieving an existing user by ID."""
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
@@ -120,7 +120,7 @@ async def test_get_user_by_id_async_existing_user(async_session_maker):
|
||||
|
||||
# Test retrieval with patched session maker
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_user_by_id_async(str(user_id))
|
||||
result = await UserStore.get_user_by_id(str(user_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
@@ -128,8 +128,8 @@ async def test_get_user_by_id_async_existing_user(async_session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id_async_user_not_found(async_session_maker):
|
||||
"""Test that get_user_by_id_async returns None for non-existent user."""
|
||||
async def test_get_user_by_id_user_not_found(async_session_maker):
|
||||
"""Test that get_user_by_id returns None for non-existent user."""
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
@@ -138,16 +138,16 @@ async def test_get_user_by_id_async_user_not_found(async_session_maker):
|
||||
patch.object(UserStore, '_acquire_user_creation_lock', return_value=True),
|
||||
patch.object(UserStore, '_release_user_creation_lock', return_value=True),
|
||||
):
|
||||
result = await UserStore.get_user_by_id_async(non_existent_id)
|
||||
result = await UserStore.get_user_by_id(non_existent_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- Tests for get_user_by_email_async ---
|
||||
# --- Tests for get_user_by_email ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_async_existing_user(async_session_maker):
|
||||
async def test_get_user_by_email_existing_user(async_session_maker):
|
||||
"""Test retrieving a user by email."""
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
@@ -163,7 +163,7 @@ async def test_get_user_by_email_async_existing_user(async_session_maker):
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_user_by_email_async(email)
|
||||
result = await UserStore.get_user_by_email(email)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
@@ -171,28 +171,28 @@ async def test_get_user_by_email_async_existing_user(async_session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_async_not_found(async_session_maker):
|
||||
"""Test that get_user_by_email_async returns None for non-existent email."""
|
||||
async def test_get_user_by_email_not_found(async_session_maker):
|
||||
"""Test that get_user_by_email returns None for non-existent email."""
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_user_by_email_async('nonexistent@example.com')
|
||||
result = await UserStore.get_user_by_email('nonexistent@example.com')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_async_empty_email(async_session_maker):
|
||||
"""Test that get_user_by_email_async returns None for empty email."""
|
||||
async def test_get_user_by_email_empty_email(async_session_maker):
|
||||
"""Test that get_user_by_email returns None for empty email."""
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_user_by_email_async('')
|
||||
result = await UserStore.get_user_by_email('')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_async_none_email(async_session_maker):
|
||||
"""Test that get_user_by_email_async returns None for None email."""
|
||||
async def test_get_user_by_email_none_email(async_session_maker):
|
||||
"""Test that get_user_by_email returns None for None email."""
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_user_by_email_async(None)
|
||||
result = await UserStore.get_user_by_email(None)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -543,47 +543,50 @@ async def test_backfill_contact_name_no_real_name(async_session_maker):
|
||||
assert org.contact_name == 'jdoe'
|
||||
|
||||
|
||||
# --- Tests for update_current_org (sync) ---
|
||||
# --- Tests for update_current_org ---
|
||||
|
||||
|
||||
def test_update_current_org_success(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_current_org_success(async_session_maker):
|
||||
"""Test updating a user's current organization."""
|
||||
user_id = uuid.uuid4()
|
||||
initial_org_id = uuid.uuid4()
|
||||
new_org_id = uuid.uuid4()
|
||||
|
||||
# Create test data
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
org1 = Org(id=initial_org_id, name='org1')
|
||||
org2 = Org(id=new_org_id, name='org2')
|
||||
session.add_all([org1, org2])
|
||||
user = User(id=user_id, current_org_id=initial_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Update current org
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(str(user_id), new_org_id)
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.update_current_org(str(user_id), new_org_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.current_org_id == new_org_id
|
||||
|
||||
|
||||
def test_update_current_org_user_not_found(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_current_org_user_not_found(async_session_maker):
|
||||
"""Test that update_current_org returns None for non-existent user."""
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(user_id, org_id)
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.update_current_org(user_id, org_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- Tests for list_users (sync) ---
|
||||
# --- Tests for list_users ---
|
||||
|
||||
|
||||
def test_list_users(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(async_session_maker):
|
||||
"""Test listing all users."""
|
||||
user_id1 = uuid.uuid4()
|
||||
user_id2 = uuid.uuid4()
|
||||
@@ -591,18 +594,18 @@ def test_list_users(session_maker):
|
||||
org_id2 = uuid.uuid4()
|
||||
|
||||
# Create test data
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
org1 = Org(id=org_id1, name='org1')
|
||||
org2 = Org(id=org_id2, name='org2')
|
||||
session.add_all([org1, org2])
|
||||
user1 = User(id=user_id1, current_org_id=org_id1)
|
||||
user2 = User(id=user_id2, current_org_id=org_id2)
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# List users
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
users = UserStore.list_users()
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
users = await UserStore.list_users()
|
||||
|
||||
assert len(users) >= 2
|
||||
user_ids = [user.id for user in users]
|
||||
|
||||
Reference in New Issue
Block a user