Refactor user_store.py to use async database sessions (#13187)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-03 17:51:53 -07:00
committed by GitHub
parent 222e8bd03d
commit 8dac1095d7
28 changed files with 235 additions and 310 deletions

View File

@@ -17,7 +17,7 @@ from openhands.server.user_auth import get_user_id
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if not user:
return None
@@ -36,7 +36,7 @@ async def get_byor_key_from_db(user_id: str) -> str | None:
async def store_byor_key_in_db(user_id: str, key: str) -> None:
"""Store the BYOR key in the database for a user."""
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if not user:
return None
@@ -55,7 +55,7 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
async def generate_byor_key(user_id: str) -> str | None:
"""Generate a new BYOR key for a user."""
try:
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = str(user.current_org_id)
@@ -98,7 +98,7 @@ async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
"""
try:
# Get user to construct the key alias
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
key_alias = None
if user and user.current_org_id:
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'

View File

@@ -204,7 +204,7 @@ async def keycloak_callback(
email = user_info.email
user_id = user_info.sub
user_info_dict = user_info.model_dump(exclude_none=True)
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if not user:
user = await UserStore.create_user(user_id, user_info_dict)
else:

View File

@@ -90,7 +90,7 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
if not stripe_service.STRIPE_API_KEY:
return GetCreditsResponse()
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found')
user_team_info = await LiteLlmManager.get_user_team_info(
@@ -248,7 +248,7 @@ async def success_callback(session_id: str, request: Request):
)
raise HTTPException(status.HTTP_400_BAD_REQUEST)
user = await UserStore.get_user_by_id_async(billing_session.user_id)
user = await UserStore.get_user_by_id(billing_session.user_id)
if user is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found')
user_team_info = await LiteLlmManager.get_user_team_info(

View File

@@ -197,7 +197,7 @@ async def keycloak_callback(
user_info = await token_manager.get_user_info(keycloak_access_token)
keycloak_user_id = user_info.sub
user = await UserStore.get_user_by_id_async(keycloak_user_id)
user = await UserStore.get_user_by_id(keycloak_user_id)
if not user:
return _html_response(
title='Failed to authenticate.',

View File

@@ -99,7 +99,7 @@ async def list_user_orgs(
try:
# Fetch user to get current_org_id
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
current_org_id = (
str(user.current_org_id) if user and user.current_org_id else None
)

View File

@@ -115,7 +115,7 @@ async def saas_get_user(
email = user_info.email
sub = user_info.sub
if sub:
db_user = await UserStore.get_user_by_id_async(sub)
db_user = await UserStore.get_user_by_id(sub)
if db_user and db_user.email is not None:
email = db_user.email

View File

@@ -106,7 +106,7 @@ class OrgInvitationService:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
existing_user = await UserStore.get_user_by_email(email)
if existing_user:
existing_member = await OrgMemberStore.get_org_member(
org_id, existing_user.id
@@ -127,7 +127,7 @@ class OrgInvitationService:
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_user = await UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
@@ -308,7 +308,7 @@ class OrgInvitationService:
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
user = await UserStore.get_user_by_id(str(user_id))
if not user:
raise InvitationInvalidError('User not found')

View File

@@ -56,7 +56,7 @@ class OrgMemberService:
raise RoleNotFoundError(org_member.role_id)
# Get user email
user = await UserStore.get_user_by_id_async(str(user_id))
user = await UserStore.get_user_by_id(str(user_id))
email = user.email if user and user.email else ''
return MeResponse.from_org_member(org_member, role, email)
@@ -218,10 +218,10 @@ class OrgMemberService:
return False, 'removal_failed'
# Update user's current_org_id if it points to the org they were removed from
user = await UserStore.get_user_by_id_async(str(target_user_id))
user = await UserStore.get_user_by_id(str(target_user_id))
if user and user.current_org_id == org_id:
# Set current_org_id to personal workspace (org.id == user.id)
UserStore.update_current_org(str(target_user_id), target_user_id)
await UserStore.update_current_org(str(target_user_id), target_user_id)
# If database removal succeeded, also remove from LiteLLM team
try:
@@ -308,7 +308,7 @@ class OrgMemberService:
# If no role change requested, return current state
if new_role_name is None:
user = await UserStore.get_user_by_id_async(str(target_user_id))
user = await UserStore.get_user_by_id(str(target_user_id))
return OrgMemberResponse(
user_id=str(target_membership.user_id),
email=user.email if user else None,
@@ -347,7 +347,7 @@ class OrgMemberService:
raise MemberUpdateError('Failed to update member')
# Get user email for response
user = await UserStore.get_user_by_id_async(str(target_user_id))
user = await UserStore.get_user_by_id(str(target_user_id))
return OrgMemberResponse(
user_id=str(updated_member.user_id),

View File

@@ -37,7 +37,7 @@ class ApiKeyStore:
The generated API key
"""
api_key = self.generate_api_key()
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id
@@ -117,7 +117,7 @@ class ApiKeyStore:
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id
@@ -132,7 +132,7 @@ class ApiKeyStore:
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id

View File

@@ -1171,7 +1171,7 @@ class LiteLlmManager:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
user = await UserStore.get_user_by_id_async(keycloak_user_id)
user = await UserStore.get_user_by_id(keycloak_user_id)
if not user:
return {}

View File

@@ -875,7 +875,7 @@ class OrgService:
Returns:
bool: True if BYOR export is enabled, False otherwise
"""
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
if not user or not user.current_org_id:
return False
@@ -929,7 +929,7 @@ class OrgService:
# Step 3: Update user's current_org_id
try:
updated_user = UserStore.update_current_org(user_id, org_id)
updated_user = await UserStore.update_current_org(user_id, org_id)
if not updated_user:
raise OrgDatabaseError('User not found')

View File

@@ -236,6 +236,6 @@ class SaasConversationStore(ConversationStore):
# user_id should not be None in SaaS, should we raise?
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
# to dispatch to the main event loop where asyncpg connections work properly.
user = await UserStore.get_user_by_id_async(user_id)
user = await UserStore.get_user_by_id(user_id)
org_id = user.current_org_id if user else None
return SaasConversationStore(str(user_id), org_id, session_maker)

View File

@@ -24,7 +24,7 @@ class SaasSecretsStore(SecretsStore):
async def load(self) -> Secrets | None:
if not self.user_id:
return None
user = await UserStore.get_user_by_id_async(self.user_id)
user = await UserStore.get_user_by_id(self.user_id)
org_id = user.current_org_id if user else None
async with a_session_maker() as session:
@@ -52,7 +52,7 @@ class SaasSecretsStore(SecretsStore):
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
async def store(self, item: Secrets):
user = await UserStore.get_user_by_id_async(self.user_id)
user = await UserStore.get_user_by_id(self.user_id)
if user is None:
raise ValueError(f'User not found: {self.user_id}')
org_id = user.current_org_id

View File

@@ -68,7 +68,7 @@ class SaasSettingsStore(SettingsStore):
return result.scalars().first()
async def load(self) -> Settings | None:
user = await UserStore.get_user_by_id_async(self.user_id)
user = await UserStore.get_user_by_id(self.user_id)
if not user:
logger.error(f'User not found for ID {self.user_id}')
return None

View File

@@ -16,8 +16,8 @@ from server.constants import (
)
from server.logger import logger
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from sqlalchemy.orm import selectinload
from storage.database import a_session_maker
from storage.encrypt_utils import (
decrypt_legacy_model,
decrypt_legacy_value,
@@ -30,8 +30,6 @@ from storage.user import User
from storage.user_settings import UserSettings
from utils.identity import resolve_display_name
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
# The max possible time to wait for another process to finish creating a user before retrying
_REDIS_CREATE_TIMEOUT_SECONDS = 30
# The delay to wait for another process to finish creating a user before trying to load again
@@ -50,7 +48,7 @@ class UserStore:
role_id: Optional[int] = None,
) -> User | None:
"""Create a new user."""
with session_maker() as session:
async with a_session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(user_id),
@@ -105,9 +103,9 @@ class UserStore:
**org_member_kwargs,
)
session.add(org_member)
session.commit()
session.refresh(user)
user.org_members # load org_members
await session.commit()
await session.refresh(user)
await session.refresh(user, ['org_members']) # load org_members
return user
@staticmethod
@@ -176,19 +174,17 @@ class UserStore:
user_settings,
)
decrypted_user_settings = UserSettings(**kwargs)
with session_maker() as session:
async with a_session_maker() as session:
# Check if user has completed billing sessions to enable BYOR export
from storage.billing_session import BillingSession
has_completed_billing = (
session.query(BillingSession)
.filter(
result = await session.execute(
select(BillingSession).filter(
BillingSession.user_id == user_id,
BillingSession.status == 'completed',
)
.first()
is not None
)
has_completed_billing = result.scalars().first() is not None
# create personal org
org = Org(
@@ -297,15 +293,15 @@ class UserStore:
# Mark the old user_settings as migrated instead of deleting
user_settings.already_migrated = True
session.merge(user_settings)
session.flush()
await session.merge(user_settings)
await session.flush()
logger.debug(
'user_store:migrate_user:session_flush_complete',
extra={'user_id': user_id},
)
# need to migrate conversation metadata
session.execute(
await session.execute(
text("""
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
SELECT
@@ -322,7 +318,7 @@ class UserStore:
user_uuid = uuid.UUID(user_id)
# Update stripe_customers
session.execute(
await session.execute(
text(
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
@@ -330,7 +326,7 @@ class UserStore:
)
# Update slack_users
session.execute(
await session.execute(
text(
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
@@ -338,7 +334,7 @@ class UserStore:
)
# Update slack_conversation
session.execute(
await session.execute(
text(
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
@@ -346,13 +342,13 @@ class UserStore:
)
# Update api_keys
session.execute(
await session.execute(
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update custom_secrets
session.execute(
await session.execute(
text(
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
@@ -360,16 +356,16 @@ class UserStore:
)
# Update billing_sessions
session.execute(
await session.execute(
text(
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
session.commit()
session.refresh(user)
user.org_members # load org_members
await session.commit()
await session.refresh(user)
await session.refresh(user, ['org_members']) # load org_members
logger.debug(
'user_store:migrate_user:session_committed',
extra={'user_id': user_id},
@@ -410,14 +406,14 @@ class UserStore:
extra={'user_id': user_id},
)
with session_maker() as session:
async with a_session_maker() as session:
# Get the user and their org_member
user = (
session.query(User)
.options(joinedload(User.org_members))
result = await session.execute(
select(User)
.options(selectinload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if not user:
logger.warning(
'user_store:downgrade_user:user_not_found',
@@ -426,7 +422,10 @@ class UserStore:
return None
# Get the user's personal org (org_id == user_id)
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
result = await session.execute(
select(Org).filter(Org.id == uuid.UUID(user_id))
)
org = result.scalars().first()
if not org:
logger.warning(
'user_store:downgrade_user:org_not_found',
@@ -435,9 +434,10 @@ class UserStore:
return None
# Get org_members for this org - should only be one for personal orgs
org_members = (
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
result = await session.execute(
select(OrgMember).filter(OrgMember.org_id == org.id)
)
org_members = result.scalars().all()
if len(org_members) != 1:
logger.error(
@@ -453,14 +453,13 @@ class UserStore:
org_member = org_members[0]
# Get the user_settings (for migrated users)
user_settings = (
session.query(UserSettings)
.filter(
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(True),
)
.first()
)
user_settings = result.scalars().first()
# For new sign-ups after migration, user_settings won't exist
# Fall back to getting data from org_members
@@ -491,7 +490,7 @@ class UserStore:
'user_store:downgrade_user:created_user_settings_from_org_member',
extra={'user_id': user_id},
)
session.flush()
await session.flush()
# Call LiteLLM downgrade
from storage.lite_llm_manager import LiteLlmManager
@@ -531,7 +530,7 @@ class UserStore:
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
# This ensures any conversations created after migration have their user_id
# preserved in the original table before we delete the saas entries
session.execute(
await session.execute(
text("""
UPDATE conversation_metadata
SET user_id = :user_id
@@ -545,14 +544,14 @@ class UserStore:
)
# Step 4: Delete conversation_metadata_saas entries
session.execute(
await session.execute(
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
{'user_id': user_uuid},
)
# Step 5: Reset org_id columns in related tables
# Reset stripe_customers
session.execute(
await session.execute(
text(
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
),
@@ -560,13 +559,13 @@ class UserStore:
)
# Reset slack_users
session.execute(
await session.execute(
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset slack_conversation
session.execute(
await session.execute(
text(
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
),
@@ -574,19 +573,19 @@ class UserStore:
)
# Reset api_keys
session.execute(
await session.execute(
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset custom_secrets
session.execute(
await session.execute(
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset billing_sessions
session.execute(
await session.execute(
text(
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
),
@@ -594,19 +593,19 @@ class UserStore:
)
# Step 6: Delete org_member entries for this org
session.execute(
await session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Step 7: Delete the user entry
session.execute(
await session.execute(
text('DELETE FROM "user" WHERE id = :user_id'),
{'user_id': user_uuid},
)
# Delete the org entry
session.execute(
await session.execute(
text('DELETE FROM org WHERE id = :org_id'),
{'org_id': user_uuid},
)
@@ -626,9 +625,9 @@ class UserStore:
if value is not None and not _is_legacy_value_encrypted(value):
setattr(user_settings, key, encrypt_legacy_value(value))
session.merge(user_settings)
await session.merge(user_settings)
session.commit()
await session.commit()
logger.info(
'user_store:downgrade_user:complete',
@@ -637,88 +636,12 @@ class UserStore:
return user_settings
@staticmethod
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (sync version).
Note: This method uses call_async_from_sync internally which creates a new
event loop. If you're already in an async context, use get_user_by_id_async
instead to avoid event loop conflicts.
"""
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
# Check if we need to migrate from user_settings
while not call_async_from_sync(
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
):
# The user is already being created in another thread / process
logger.info(
'user_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
call_async_from_sync(
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
)
try:
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
if user_settings:
token_manager = TokenManager()
user_info = call_async_from_sync(
token_manager.get_user_info_from_user_id,
GENERAL_TIMEOUT,
user_id,
)
user = call_async_from_sync(
UserStore.migrate_user,
GENERAL_TIMEOUT,
user_id,
user_settings,
user_info,
)
return user
else:
return None
finally:
call_async_from_sync(
UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id
)
@staticmethod
async def get_user_by_id_async(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (async version).
This is the preferred method when calling from an async context as it
avoids event loop conflicts that can occur with the sync version.
"""
async def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.options(selectinload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
@@ -729,7 +652,7 @@ class UserStore:
while not await UserStore._acquire_user_creation_lock(user_id):
# The user is already being created in another thread / process
logger.info(
'user_store:get_user_by_id_async:waiting_for_lock',
'user_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
@@ -738,17 +661,13 @@ class UserStore:
# Check for user again as migration could have happened while trying to get the lock.
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.options(selectinload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if user:
return user
logger.info(
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
@@ -759,10 +678,6 @@ class UserStore:
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
@@ -775,8 +690,8 @@ class UserStore:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
async def get_user_by_email(email: str) -> Optional[User]:
"""Get user by email address.
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
@@ -793,19 +708,20 @@ class UserStore:
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.options(selectinload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod
def list_users() -> list[User]:
async def list_users() -> list[User]:
"""List all users."""
with session_maker() as session:
return session.query(User).all()
async with a_session_maker() as session:
result = await session.execute(select(User))
return list(result.scalars().all())
@staticmethod
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
async def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
"""Update the user's current organization.
Args:
@@ -815,19 +731,17 @@ class UserStore:
Returns:
User: The updated user object, or None if user not found
"""
with session_maker() as session:
user = (
session.query(User)
.filter(User.id == uuid.UUID(user_id))
.with_for_update()
.first()
async with a_session_maker() as session:
result = await session.execute(
select(User).filter(User.id == uuid.UUID(user_id)).with_for_update()
)
user = result.scalars().first()
if not user:
return None
user.current_org_id = org_id
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user
@staticmethod

View File

@@ -374,7 +374,7 @@ class TestDeleteByorKeyFromLitellm:
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
@patch('storage.user_store.UserStore.get_user_by_id')
async def test_delete_constructs_alias_from_user(
self, mock_get_user, mock_delete_key
):
@@ -400,7 +400,7 @@ class TestDeleteByorKeyFromLitellm:
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
@patch('storage.user_store.UserStore.get_user_by_id')
async def test_delete_without_user_passes_no_alias(
self, mock_get_user, mock_delete_key
):
@@ -421,7 +421,7 @@ class TestDeleteByorKeyFromLitellm:
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
@patch('storage.user_store.UserStore.get_user_by_id')
async def test_delete_without_org_id_passes_no_alias(
self, mock_get_user, mock_delete_key
):
@@ -444,7 +444,7 @@ class TestDeleteByorKeyFromLitellm:
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
@patch('storage.user_store.UserStore.get_user_by_id')
async def test_delete_returns_false_on_exception(
self, mock_get_user, mock_delete_key
):

View File

@@ -514,7 +514,7 @@ async def test_list_user_orgs_success(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -568,7 +568,7 @@ async def test_list_user_orgs_returns_current_org_id(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -613,7 +613,7 @@ async def test_list_user_orgs_with_pagination(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -648,7 +648,7 @@ async def test_list_user_orgs_empty(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -715,7 +715,7 @@ async def test_list_user_orgs_service_error(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -781,7 +781,7 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -820,7 +820,7 @@ async def test_list_user_orgs_team_org_identified(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -869,7 +869,7 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -941,7 +941,7 @@ async def test_list_user_orgs_all_fields_present(mock_app_list):
with (
patch(
'server.routes.orgs.UserStore.get_user_by_id_async',
'server.routes.orgs.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(

View File

@@ -718,7 +718,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -767,7 +767,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -815,7 +815,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -967,7 +967,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
@@ -1149,7 +1149,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -1251,11 +1251,12 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
'server.services.org_member_service.UserStore.update_current_org',
new_callable=AsyncMock,
) as mock_update_org,
):
mock_get_member.side_effect = [
@@ -1274,7 +1275,9 @@ class TestOrgMemberServiceRemoveOrgMember:
# Assert
assert success is True
assert error is None
mock_update_org.assert_called_once_with(str(target_user_id), target_user_id)
mock_update_org.assert_awaited_once_with(
str(target_user_id), target_user_id
)
@pytest.mark.asyncio
async def test_remove_member_does_not_update_current_org_id_when_not_matching(
@@ -1307,11 +1310,12 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
'server.services.org_member_service.UserStore.update_current_org',
new_callable=AsyncMock,
) as mock_update_org,
):
mock_get_member.side_effect = [
@@ -1330,7 +1334,7 @@ class TestOrgMemberServiceRemoveOrgMember:
# Assert
assert success is True
assert error is None
mock_update_org.assert_not_called()
mock_update_org.assert_not_awaited()
@pytest.mark.asyncio
async def test_remove_member_succeeds_when_user_not_found_after_removal(
@@ -1359,11 +1363,12 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
'server.services.org_member_service.UserStore.update_current_org',
new_callable=AsyncMock,
) as mock_update_org,
):
mock_get_member.side_effect = [
@@ -1382,7 +1387,7 @@ class TestOrgMemberServiceRemoveOrgMember:
# Assert
assert success is True
assert error is None
mock_update_org.assert_not_called()
mock_update_org.assert_not_awaited()
@pytest.mark.asyncio
async def test_successful_removal_calls_litellm_remove_user_from_team(
@@ -1411,7 +1416,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
@@ -1465,7 +1470,7 @@ class TestOrgMemberServiceRemoveOrgMember:
new_callable=AsyncMock,
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
@@ -1632,7 +1637,7 @@ class TestOrgMemberServiceUpdateOrgMember:
new_callable=AsyncMock,
) as mock_update,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -1693,7 +1698,7 @@ class TestOrgMemberServiceUpdateOrgMember:
new_callable=AsyncMock,
) as mock_update,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -1752,7 +1757,7 @@ class TestOrgMemberServiceUpdateOrgMember:
new_callable=AsyncMock,
) as mock_update,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -1815,7 +1820,7 @@ class TestOrgMemberServiceUpdateOrgMember:
new_callable=AsyncMock,
) as mock_update,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch.object(
@@ -2036,7 +2041,7 @@ class TestOrgMemberServiceUpdateOrgMember:
new_callable=AsyncMock,
) as mock_get_role,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -2253,7 +2258,7 @@ class TestOrgMemberServiceGetMe:
new_callable=AsyncMock,
) as mock_get_role,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -2340,7 +2345,7 @@ class TestOrgMemberServiceGetMe:
new_callable=AsyncMock,
) as mock_get_role,
patch(
'server.services.org_member_service.UserStore.get_user_by_id_async',
'server.services.org_member_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):

View File

@@ -56,7 +56,7 @@ def test_generate_api_key(api_key_store):
@pytest.mark.asyncio
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
@patch('storage.api_key_store.UserStore.get_user_by_id')
async def test_create_api_key(
mock_get_user, api_key_store, async_session_maker, mock_user
):
@@ -324,7 +324,7 @@ async def test_delete_api_key_by_id(api_key_store, async_session_maker):
@pytest.mark.asyncio
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
@patch('storage.api_key_store.UserStore.get_user_by_id')
async def test_list_api_keys(
mock_get_user, api_key_store, async_session_maker, mock_user
):
@@ -377,7 +377,7 @@ async def test_list_api_keys(
@pytest.mark.asyncio
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
@patch('storage.api_key_store.UserStore.get_user_by_id')
async def test_retrieve_mcp_api_key(
mock_get_user, api_key_store, async_session_maker, mock_user
):
@@ -416,7 +416,7 @@ async def test_retrieve_mcp_api_key(
@pytest.mark.asyncio
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
@patch('storage.api_key_store.UserStore.get_user_by_id')
async def test_retrieve_mcp_api_key_not_found(
mock_get_user, api_key_store, async_session_maker, mock_user
):

View File

@@ -158,7 +158,7 @@ async def test_keycloak_callback_user_not_allowed(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = None
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
@@ -197,7 +197,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(
mock_user.accepted_tos = '2025-01-01'
# Setup UserStore mocks
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
@@ -273,7 +273,7 @@ async def test_keycloak_callback_email_not_verified(
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -324,7 +324,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -368,7 +368,7 @@ async def test_keycloak_callback_success_without_offline_token(
mock_user.accepted_tos = '2025-01-01'
# Setup UserStore mocks
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
@@ -616,7 +616,7 @@ async def test_keycloak_callback_blocked_email_domain(
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -683,7 +683,7 @@ async def test_keycloak_callback_allowed_email_domain(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -750,7 +750,7 @@ async def test_keycloak_callback_domain_blocking_inactive(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -813,7 +813,7 @@ async def test_keycloak_callback_missing_email(mock_request, create_keycloak_use
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -862,7 +862,7 @@ async def test_keycloak_callback_duplicate_email_detected(
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -910,7 +910,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -971,7 +971,7 @@ async def test_keycloak_callback_duplicate_check_exception(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1032,7 +1032,7 @@ async def test_keycloak_callback_no_duplicate_email(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1096,7 +1096,7 @@ async def test_keycloak_callback_no_email_in_user_info(
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1254,7 +1254,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1322,7 +1322,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1408,7 +1408,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1497,7 +1497,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1670,7 +1670,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1752,7 +1752,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1822,7 +1822,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1896,7 +1896,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -1970,7 +1970,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
@@ -2020,7 +2020,7 @@ async def test_keycloak_callback_calls_backfill_user_email_for_existing_user(
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()

View File

@@ -103,7 +103,7 @@ async def test_get_credits_lite_llm_error():
with (
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch(
'storage.user_store.UserStore.get_user_by_id_async',
'storage.user_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
@@ -135,7 +135,7 @@ async def test_get_credits_success():
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch('httpx.AsyncClient', return_value=mock_client),
patch(
'storage.user_store.UserStore.get_user_by_id_async',
'storage.user_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
@@ -338,7 +338,7 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
'storage.user_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
),
@@ -410,7 +410,7 @@ async def test_success_callback_lite_llm_error(
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
'storage.user_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
),
@@ -464,7 +464,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
'storage.user_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
),

View File

@@ -1100,9 +1100,7 @@ class TestLiteLlmManager:
mock_org_member.org_id = 'test-ord-id'
mock_org_member.llm_api_key = 'test-api-key'
mock_user.org_members = [mock_org_member]
mock_user_store.get_user_by_id_async = AsyncMock(
return_value=mock_user
)
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
result = await LiteLlmManager._get_key_info(
mock_http_client, 'test-ord-id', 'test-user-id'
@@ -1118,7 +1116,7 @@ class TestLiteLlmManager:
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with patch('storage.user_store.UserStore') as mock_user_store:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=None)
mock_user_store.get_user_by_id = AsyncMock(return_value=None)
result = await LiteLlmManager._get_key_info(
mock_http_client, 'test-ord-id', 'test-user-id'

View File

@@ -70,7 +70,7 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
'server.services.org_invitation_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
):
@@ -106,7 +106,7 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
'server.services.org_invitation_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
@@ -174,7 +174,7 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
'server.services.org_invitation_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
@@ -220,7 +220,7 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
'server.services.org_invitation_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(

View File

@@ -1870,7 +1870,7 @@ async def test_check_byor_export_enabled_returns_true_when_enabled():
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
'storage.org_service.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -1905,7 +1905,7 @@ async def test_check_byor_export_enabled_returns_false_when_disabled():
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
'storage.org_service.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -1932,7 +1932,7 @@ async def test_check_byor_export_enabled_returns_false_when_user_not_found():
user_id = 'nonexistent-user'
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
'storage.org_service.UserStore.get_user_by_id',
AsyncMock(return_value=None),
):
# Act
@@ -1956,7 +1956,7 @@ async def test_check_byor_export_enabled_returns_false_when_no_current_org():
mock_user.current_org_id = None
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
'storage.org_service.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
):
# Act
@@ -1982,7 +1982,7 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found():
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
'storage.org_service.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
),
patch(
@@ -2025,6 +2025,7 @@ async def test_switch_org_success():
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch(
'storage.org_service.UserStore.update_current_org',
new_callable=AsyncMock,
return_value=mock_updated_user,
),
):
@@ -2116,7 +2117,11 @@ async def test_switch_org_user_not_found():
return_value=mock_org,
),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch('storage.org_service.UserStore.update_current_org', return_value=None),
patch(
'storage.org_service.UserStore.update_current_org',
new_callable=AsyncMock,
return_value=None,
),
):
# Act & Assert
with pytest.raises(OrgDatabaseError) as exc_info:

View File

@@ -169,14 +169,14 @@ async def test_exists(session_maker):
class TestGetInstance:
"""Tests for SaasConversationStore.get_instance method.
The get_instance method uses async UserStore.get_user_by_id_async because
The get_instance method uses async UserStore.get_user_by_id because
callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main
event loop where asyncpg connections work properly.
"""
@pytest.mark.asyncio
async def test_get_instance_uses_async_get_user_by_id(self):
"""Verify get_instance calls the async get_user_by_id_async for proper event loop handling."""
"""Verify get_instance calls the async get_user_by_id for proper event loop handling."""
# Arrange
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
mock_user = MagicMock(spec=User)
@@ -184,7 +184,7 @@ class TestGetInstance:
mock_config = MagicMock(spec=OpenHandsConfig)
with patch(
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
'storage.saas_conversation_store.UserStore.get_user_by_id',
AsyncMock(return_value=mock_user),
) as mock_async_get_user, patch(
'storage.saas_conversation_store.session_maker'
@@ -205,7 +205,7 @@ class TestGetInstance:
mock_config = MagicMock(spec=OpenHandsConfig)
with patch(
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
'storage.saas_conversation_store.UserStore.get_user_by_id',
AsyncMock(return_value=None),
), patch('storage.saas_conversation_store.session_maker'):
# Act

View File

@@ -44,7 +44,7 @@ def secrets_store(async_session_maker, mock_config):
class TestSaasSecretsStore:
@pytest.mark.asyncio
@patch(
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
'storage.saas_secrets_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
)
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
@@ -84,7 +84,7 @@ class TestSaasSecretsStore:
@pytest.mark.asyncio
@patch(
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
'storage.saas_secrets_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
)
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
@@ -186,7 +186,7 @@ class TestSaasSecretsStore:
@pytest.mark.asyncio
@patch(
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
'storage.saas_secrets_store.UserStore.get_user_by_id',
new_callable=AsyncMock,
)
async def test_update_existing_secrets(

View File

@@ -35,9 +35,9 @@ def mock_check_idp():
@pytest.fixture
def mock_user_store():
"""Mock UserStore.get_user_by_id_async to return None by default."""
"""Mock UserStore.get_user_by_id to return None by default."""
with patch(
'server.routes.user.UserStore.get_user_by_id_async',
'server.routes.user.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=None,
) as mock_fn:

View File

@@ -101,11 +101,11 @@ async def test_create_default_settings_with_litellm(mock_litellm_api):
assert settings.llm_base_url == 'http://test.url'
# --- Tests for get_user_by_id_async ---
# --- Tests for get_user_by_id ---
@pytest.mark.asyncio
async def test_get_user_by_id_async_existing_user(async_session_maker):
async def test_get_user_by_id_existing_user(async_session_maker):
"""Test retrieving an existing user by ID."""
user_id = uuid.uuid4()
org_id = uuid.uuid4()
@@ -120,7 +120,7 @@ async def test_get_user_by_id_async_existing_user(async_session_maker):
# Test retrieval with patched session maker
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_user_by_id_async(str(user_id))
result = await UserStore.get_user_by_id(str(user_id))
assert result is not None
assert result.id == user_id
@@ -128,8 +128,8 @@ async def test_get_user_by_id_async_existing_user(async_session_maker):
@pytest.mark.asyncio
async def test_get_user_by_id_async_user_not_found(async_session_maker):
"""Test that get_user_by_id_async returns None for non-existent user."""
async def test_get_user_by_id_user_not_found(async_session_maker):
"""Test that get_user_by_id returns None for non-existent user."""
non_existent_id = str(uuid.uuid4())
with patch('storage.user_store.a_session_maker', async_session_maker):
@@ -138,16 +138,16 @@ async def test_get_user_by_id_async_user_not_found(async_session_maker):
patch.object(UserStore, '_acquire_user_creation_lock', return_value=True),
patch.object(UserStore, '_release_user_creation_lock', return_value=True),
):
result = await UserStore.get_user_by_id_async(non_existent_id)
result = await UserStore.get_user_by_id(non_existent_id)
assert result is None
# --- Tests for get_user_by_email_async ---
# --- Tests for get_user_by_email ---
@pytest.mark.asyncio
async def test_get_user_by_email_async_existing_user(async_session_maker):
async def test_get_user_by_email_existing_user(async_session_maker):
"""Test retrieving a user by email."""
user_id = uuid.uuid4()
org_id = uuid.uuid4()
@@ -163,7 +163,7 @@ async def test_get_user_by_email_async_existing_user(async_session_maker):
# Test retrieval
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_user_by_email_async(email)
result = await UserStore.get_user_by_email(email)
assert result is not None
assert result.id == user_id
@@ -171,28 +171,28 @@ async def test_get_user_by_email_async_existing_user(async_session_maker):
@pytest.mark.asyncio
async def test_get_user_by_email_async_not_found(async_session_maker):
"""Test that get_user_by_email_async returns None for non-existent email."""
async def test_get_user_by_email_not_found(async_session_maker):
"""Test that get_user_by_email returns None for non-existent email."""
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_user_by_email_async('nonexistent@example.com')
result = await UserStore.get_user_by_email('nonexistent@example.com')
assert result is None
@pytest.mark.asyncio
async def test_get_user_by_email_async_empty_email(async_session_maker):
"""Test that get_user_by_email_async returns None for empty email."""
async def test_get_user_by_email_empty_email(async_session_maker):
"""Test that get_user_by_email returns None for empty email."""
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_user_by_email_async('')
result = await UserStore.get_user_by_email('')
assert result is None
@pytest.mark.asyncio
async def test_get_user_by_email_async_none_email(async_session_maker):
"""Test that get_user_by_email_async returns None for None email."""
async def test_get_user_by_email_none_email(async_session_maker):
"""Test that get_user_by_email returns None for None email."""
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_user_by_email_async(None)
result = await UserStore.get_user_by_email(None)
assert result is None
@@ -543,47 +543,50 @@ async def test_backfill_contact_name_no_real_name(async_session_maker):
assert org.contact_name == 'jdoe'
# --- Tests for update_current_org (sync) ---
# --- Tests for update_current_org ---
def test_update_current_org_success(session_maker):
@pytest.mark.asyncio
async def test_update_current_org_success(async_session_maker):
"""Test updating a user's current organization."""
user_id = uuid.uuid4()
initial_org_id = uuid.uuid4()
new_org_id = uuid.uuid4()
# Create test data
with session_maker() as session:
async with async_session_maker() as session:
org1 = Org(id=initial_org_id, name='org1')
org2 = Org(id=new_org_id, name='org2')
session.add_all([org1, org2])
user = User(id=user_id, current_org_id=initial_org_id)
session.add(user)
session.commit()
await session.commit()
# Update current org
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(str(user_id), new_org_id)
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.update_current_org(str(user_id), new_org_id)
assert result is not None
assert result.current_org_id == new_org_id
def test_update_current_org_user_not_found(session_maker):
@pytest.mark.asyncio
async def test_update_current_org_user_not_found(async_session_maker):
"""Test that update_current_org returns None for non-existent user."""
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(user_id, org_id)
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.update_current_org(user_id, org_id)
assert result is None
# --- Tests for list_users (sync) ---
# --- Tests for list_users ---
def test_list_users(session_maker):
@pytest.mark.asyncio
async def test_list_users(async_session_maker):
"""Test listing all users."""
user_id1 = uuid.uuid4()
user_id2 = uuid.uuid4()
@@ -591,18 +594,18 @@ def test_list_users(session_maker):
org_id2 = uuid.uuid4()
# Create test data
with session_maker() as session:
async with async_session_maker() as session:
org1 = Org(id=org_id1, name='org1')
org2 = Org(id=org_id2, name='org2')
session.add_all([org1, org2])
user1 = User(id=user_id1, current_org_id=org_id1)
user2 = User(id=user_id2, current_org_id=org_id2)
session.add_all([user1, user2])
session.commit()
await session.commit()
# List users
with patch('storage.user_store.session_maker', session_maker):
users = UserStore.list_users()
with patch('storage.user_store.a_session_maker', async_session_maker):
users = await UserStore.list_users()
assert len(users) >= 2
user_ids = [user.id for user in users]