From 8dac1095d73af2f3b8eeb0a77dccb0c2ee1ef8f2 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 3 Mar 2026 17:51:53 -0700 Subject: [PATCH] Refactor user_store.py to use async database sessions (#13187) Co-authored-by: openhands --- enterprise/server/routes/api_keys.py | 8 +- enterprise/server/routes/auth.py | 2 +- enterprise/server/routes/billing.py | 4 +- enterprise/server/routes/integration/slack.py | 2 +- enterprise/server/routes/orgs.py | 2 +- enterprise/server/routes/user.py | 2 +- .../server/services/org_invitation_service.py | 6 +- .../server/services/org_member_service.py | 10 +- enterprise/storage/api_key_store.py | 6 +- enterprise/storage/lite_llm_manager.py | 2 +- enterprise/storage/org_service.py | 4 +- enterprise/storage/saas_conversation_store.py | 2 +- enterprise/storage/saas_secrets_store.py | 4 +- enterprise/storage/saas_settings_store.py | 2 +- enterprise/storage/user_store.py | 226 ++++++------------ .../tests/unit/server/routes/test_api_keys.py | 8 +- .../tests/unit/server/routes/test_orgs.py | 18 +- .../services/test_org_member_service.py | 51 ++-- enterprise/tests/unit/test_api_key_store.py | 8 +- enterprise/tests/unit/test_auth_routes.py | 50 ++-- enterprise/tests/unit/test_billing.py | 10 +- .../tests/unit/test_lite_llm_manager.py | 6 +- .../tests/unit/test_org_invitation_service.py | 8 +- enterprise/tests/unit/test_org_service.py | 17 +- .../unit/test_saas_conversation_store.py | 8 +- .../tests/unit/test_saas_secrets_store.py | 6 +- .../tests/unit/test_user_route_fallback.py | 4 +- enterprise/tests/unit/test_user_store.py | 69 +++--- 28 files changed, 235 insertions(+), 310 deletions(-) diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index 5b433aef98..1e3f8a0d51 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -17,7 +17,7 @@ from openhands.server.user_auth import get_user_id # Helper functions for BYOR API key management async def get_byor_key_from_db(user_id: str) -> str | None: """Get the BYOR key from the database for a user.""" - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if not user: return None @@ -36,7 +36,7 @@ async def get_byor_key_from_db(user_id: str) -> str | None: async def store_byor_key_in_db(user_id: str, key: str) -> None: """Store the BYOR key in the database for a user.""" - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if not user: return None @@ -55,7 +55,7 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None: async def generate_byor_key(user_id: str) -> str | None: """Generate a new BYOR key for a user.""" try: - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if not user: return None current_org_id = str(user.current_org_id) @@ -98,7 +98,7 @@ async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool: """ try: # Get user to construct the key alias - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) key_alias = None if user and user.current_org_id: key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}' diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index fa298f73e4..7c596cd558 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -204,7 +204,7 @@ async def keycloak_callback( email = user_info.email user_id = user_info.sub user_info_dict = user_info.model_dump(exclude_none=True) - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if not user: user = await UserStore.create_user(user_id, user_info_dict) else: diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 942b843cb5..51e5ee3fb1 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -90,7 +90,7 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float: async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse: if not stripe_service.STRIPE_API_KEY: return GetCreditsResponse() - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if user is None: raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found') user_team_info = await LiteLlmManager.get_user_team_info( @@ -248,7 +248,7 @@ async def success_callback(session_id: str, request: Request): ) raise HTTPException(status.HTTP_400_BAD_REQUEST) - user = await UserStore.get_user_by_id_async(billing_session.user_id) + user = await UserStore.get_user_by_id(billing_session.user_id) if user is None: raise HTTPException(status.HTTP_404_NOT_FOUND, detail='User not found') user_team_info = await LiteLlmManager.get_user_team_info( diff --git a/enterprise/server/routes/integration/slack.py b/enterprise/server/routes/integration/slack.py index 221c966eb1..dc98552bc3 100644 --- a/enterprise/server/routes/integration/slack.py +++ b/enterprise/server/routes/integration/slack.py @@ -197,7 +197,7 @@ async def keycloak_callback( user_info = await token_manager.get_user_info(keycloak_access_token) keycloak_user_id = user_info.sub - user = await UserStore.get_user_by_id_async(keycloak_user_id) + user = await UserStore.get_user_by_id(keycloak_user_id) if not user: return _html_response( title='Failed to authenticate.', diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index 8c1abb18ed..f67b1f45f2 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -99,7 +99,7 @@ async def list_user_orgs( try: # Fetch user to get current_org_id - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) current_org_id = ( str(user.current_org_id) if user and user.current_org_id else None ) diff --git a/enterprise/server/routes/user.py b/enterprise/server/routes/user.py index b0e9a6de7b..908f96281b 100644 --- a/enterprise/server/routes/user.py +++ b/enterprise/server/routes/user.py @@ -115,7 +115,7 @@ async def saas_get_user( email = user_info.email sub = user_info.sub if sub: - db_user = await UserStore.get_user_by_id_async(sub) + db_user = await UserStore.get_user_by_id(sub) if db_user and db_user.email is not None: email = db_user.email diff --git a/enterprise/server/services/org_invitation_service.py b/enterprise/server/services/org_invitation_service.py index 3ef0d9d8fa..bae78f1a7b 100644 --- a/enterprise/server/services/org_invitation_service.py +++ b/enterprise/server/services/org_invitation_service.py @@ -106,7 +106,7 @@ class OrgInvitationService: raise ValueError(f'Invalid role: {role_name}') # Step 5: Check if user is already a member (by email) - existing_user = await UserStore.get_user_by_email_async(email) + existing_user = await UserStore.get_user_by_email(email) if existing_user: existing_member = await OrgMemberStore.get_org_member( org_id, existing_user.id @@ -127,7 +127,7 @@ class OrgInvitationService: # Step 7: Send invitation email try: # Get inviter info for the email - inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id)) + inviter_user = await UserStore.get_user_by_id(str(inviter_member.user_id)) inviter_name = 'A team member' if inviter_user and inviter_user.email: inviter_name = inviter_user.email.split('@')[0] @@ -308,7 +308,7 @@ class OrgInvitationService: raise InvitationExpiredError('Invitation has expired') # Step 2.5: Verify user email matches invitation email - user = await UserStore.get_user_by_id_async(str(user_id)) + user = await UserStore.get_user_by_id(str(user_id)) if not user: raise InvitationInvalidError('User not found') diff --git a/enterprise/server/services/org_member_service.py b/enterprise/server/services/org_member_service.py index 7168d0954e..c20292b93e 100644 --- a/enterprise/server/services/org_member_service.py +++ b/enterprise/server/services/org_member_service.py @@ -56,7 +56,7 @@ class OrgMemberService: raise RoleNotFoundError(org_member.role_id) # Get user email - user = await UserStore.get_user_by_id_async(str(user_id)) + user = await UserStore.get_user_by_id(str(user_id)) email = user.email if user and user.email else '' return MeResponse.from_org_member(org_member, role, email) @@ -218,10 +218,10 @@ class OrgMemberService: return False, 'removal_failed' # Update user's current_org_id if it points to the org they were removed from - user = await UserStore.get_user_by_id_async(str(target_user_id)) + user = await UserStore.get_user_by_id(str(target_user_id)) if user and user.current_org_id == org_id: # Set current_org_id to personal workspace (org.id == user.id) - UserStore.update_current_org(str(target_user_id), target_user_id) + await UserStore.update_current_org(str(target_user_id), target_user_id) # If database removal succeeded, also remove from LiteLLM team try: @@ -308,7 +308,7 @@ class OrgMemberService: # If no role change requested, return current state if new_role_name is None: - user = await UserStore.get_user_by_id_async(str(target_user_id)) + user = await UserStore.get_user_by_id(str(target_user_id)) return OrgMemberResponse( user_id=str(target_membership.user_id), email=user.email if user else None, @@ -347,7 +347,7 @@ class OrgMemberService: raise MemberUpdateError('Failed to update member') # Get user email for response - user = await UserStore.get_user_by_id_async(str(target_user_id)) + user = await UserStore.get_user_by_id(str(target_user_id)) return OrgMemberResponse( user_id=str(updated_member.user_id), diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index c6a4cbd05d..d514b70693 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -37,7 +37,7 @@ class ApiKeyStore: The generated API key """ api_key = self.generate_api_key() - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if user is None: raise ValueError(f'User not found: {user_id}') org_id = user.current_org_id @@ -117,7 +117,7 @@ class ApiKeyStore: async def list_api_keys(self, user_id: str) -> list[ApiKey]: """List all API keys for a user.""" - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if user is None: raise ValueError(f'User not found: {user_id}') org_id = user.current_org_id @@ -132,7 +132,7 @@ class ApiKeyStore: return [key for key in keys if key.name != 'MCP_API_KEY'] async def retrieve_mcp_api_key(self, user_id: str) -> str | None: - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if user is None: raise ValueError(f'User not found: {user_id}') org_id = user.current_org_id diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index 8cf8b4e998..fc45fb2271 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -1171,7 +1171,7 @@ class LiteLlmManager: if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') return None - user = await UserStore.get_user_by_id_async(keycloak_user_id) + user = await UserStore.get_user_by_id(keycloak_user_id) if not user: return {} diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index a5108137dc..bc021c525b 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -875,7 +875,7 @@ class OrgService: Returns: bool: True if BYOR export is enabled, False otherwise """ - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) if not user or not user.current_org_id: return False @@ -929,7 +929,7 @@ class OrgService: # Step 3: Update user's current_org_id try: - updated_user = UserStore.update_current_org(user_id, org_id) + updated_user = await UserStore.update_current_org(user_id, org_id) if not updated_user: raise OrgDatabaseError('User not found') diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index eec6961d02..b8ac843e13 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -236,6 +236,6 @@ class SaasConversationStore(ConversationStore): # user_id should not be None in SaaS, should we raise? # Use async version since callers now use asyncio.run_coroutine_threadsafe() # to dispatch to the main event loop where asyncpg connections work properly. - user = await UserStore.get_user_by_id_async(user_id) + user = await UserStore.get_user_by_id(user_id) org_id = user.current_org_id if user else None return SaasConversationStore(str(user_id), org_id, session_maker) diff --git a/enterprise/storage/saas_secrets_store.py b/enterprise/storage/saas_secrets_store.py index ccde502cc6..3b2820485b 100644 --- a/enterprise/storage/saas_secrets_store.py +++ b/enterprise/storage/saas_secrets_store.py @@ -24,7 +24,7 @@ class SaasSecretsStore(SecretsStore): async def load(self) -> Secrets | None: if not self.user_id: return None - user = await UserStore.get_user_by_id_async(self.user_id) + user = await UserStore.get_user_by_id(self.user_id) org_id = user.current_org_id if user else None async with a_session_maker() as session: @@ -52,7 +52,7 @@ class SaasSecretsStore(SecretsStore): return Secrets(custom_secrets=kwargs) # type: ignore[arg-type] async def store(self, item: Secrets): - user = await UserStore.get_user_by_id_async(self.user_id) + user = await UserStore.get_user_by_id(self.user_id) if user is None: raise ValueError(f'User not found: {self.user_id}') org_id = user.current_org_id diff --git a/enterprise/storage/saas_settings_store.py b/enterprise/storage/saas_settings_store.py index 3653f83574..bd43fa1a7a 100644 --- a/enterprise/storage/saas_settings_store.py +++ b/enterprise/storage/saas_settings_store.py @@ -68,7 +68,7 @@ class SaasSettingsStore(SettingsStore): return result.scalars().first() async def load(self) -> Settings | None: - user = await UserStore.get_user_by_id_async(self.user_id) + user = await UserStore.get_user_by_id(self.user_id) if not user: logger.error(f'User not found for ID {self.user_id}') return None diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 8c20bd013c..4f55d8650c 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -16,8 +16,8 @@ from server.constants import ( ) from server.logger import logger from sqlalchemy import select, text -from sqlalchemy.orm import joinedload -from storage.database import a_session_maker, session_maker +from sqlalchemy.orm import selectinload +from storage.database import a_session_maker from storage.encrypt_utils import ( decrypt_legacy_model, decrypt_legacy_value, @@ -30,8 +30,6 @@ from storage.user import User from storage.user_settings import UserSettings from utils.identity import resolve_display_name -from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync - # The max possible time to wait for another process to finish creating a user before retrying _REDIS_CREATE_TIMEOUT_SECONDS = 30 # The delay to wait for another process to finish creating a user before trying to load again @@ -50,7 +48,7 @@ class UserStore: role_id: Optional[int] = None, ) -> User | None: """Create a new user.""" - with session_maker() as session: + async with a_session_maker() as session: # create personal org org = Org( id=uuid.UUID(user_id), @@ -105,9 +103,9 @@ class UserStore: **org_member_kwargs, ) session.add(org_member) - session.commit() - session.refresh(user) - user.org_members # load org_members + await session.commit() + await session.refresh(user) + await session.refresh(user, ['org_members']) # load org_members return user @staticmethod @@ -176,19 +174,17 @@ class UserStore: user_settings, ) decrypted_user_settings = UserSettings(**kwargs) - with session_maker() as session: + async with a_session_maker() as session: # Check if user has completed billing sessions to enable BYOR export from storage.billing_session import BillingSession - has_completed_billing = ( - session.query(BillingSession) - .filter( + result = await session.execute( + select(BillingSession).filter( BillingSession.user_id == user_id, BillingSession.status == 'completed', ) - .first() - is not None ) + has_completed_billing = result.scalars().first() is not None # create personal org org = Org( @@ -297,15 +293,15 @@ class UserStore: # Mark the old user_settings as migrated instead of deleting user_settings.already_migrated = True - session.merge(user_settings) - session.flush() + await session.merge(user_settings) + await session.flush() logger.debug( 'user_store:migrate_user:session_flush_complete', extra={'user_id': user_id}, ) # need to migrate conversation metadata - session.execute( + await session.execute( text(""" INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id) SELECT @@ -322,7 +318,7 @@ class UserStore: user_uuid = uuid.UUID(user_id) # Update stripe_customers - session.execute( + await session.execute( text( 'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id' ), @@ -330,7 +326,7 @@ class UserStore: ) # Update slack_users - session.execute( + await session.execute( text( 'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id' ), @@ -338,7 +334,7 @@ class UserStore: ) # Update slack_conversation - session.execute( + await session.execute( text( 'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id' ), @@ -346,13 +342,13 @@ class UserStore: ) # Update api_keys - session.execute( + await session.execute( text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'), {'org_id': user_uuid, 'user_id': user_uuid}, ) # Update custom_secrets - session.execute( + await session.execute( text( 'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id' ), @@ -360,16 +356,16 @@ class UserStore: ) # Update billing_sessions - session.execute( + await session.execute( text( 'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id' ), {'org_id': user_uuid, 'user_id': user_uuid}, ) - session.commit() - session.refresh(user) - user.org_members # load org_members + await session.commit() + await session.refresh(user) + await session.refresh(user, ['org_members']) # load org_members logger.debug( 'user_store:migrate_user:session_committed', extra={'user_id': user_id}, @@ -410,14 +406,14 @@ class UserStore: extra={'user_id': user_id}, ) - with session_maker() as session: + async with a_session_maker() as session: # Get the user and their org_member - user = ( - session.query(User) - .options(joinedload(User.org_members)) + result = await session.execute( + select(User) + .options(selectinload(User.org_members)) .filter(User.id == uuid.UUID(user_id)) - .first() ) + user = result.scalars().first() if not user: logger.warning( 'user_store:downgrade_user:user_not_found', @@ -426,7 +422,10 @@ class UserStore: return None # Get the user's personal org (org_id == user_id) - org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first() + result = await session.execute( + select(Org).filter(Org.id == uuid.UUID(user_id)) + ) + org = result.scalars().first() if not org: logger.warning( 'user_store:downgrade_user:org_not_found', @@ -435,9 +434,10 @@ class UserStore: return None # Get org_members for this org - should only be one for personal orgs - org_members = ( - session.query(OrgMember).filter(OrgMember.org_id == org.id).all() + result = await session.execute( + select(OrgMember).filter(OrgMember.org_id == org.id) ) + org_members = result.scalars().all() if len(org_members) != 1: logger.error( @@ -453,14 +453,13 @@ class UserStore: org_member = org_members[0] # Get the user_settings (for migrated users) - user_settings = ( - session.query(UserSettings) - .filter( + result = await session.execute( + select(UserSettings).filter( UserSettings.keycloak_user_id == user_id, UserSettings.already_migrated.is_(True), ) - .first() ) + user_settings = result.scalars().first() # For new sign-ups after migration, user_settings won't exist # Fall back to getting data from org_members @@ -491,7 +490,7 @@ class UserStore: 'user_store:downgrade_user:created_user_settings_from_org_member', extra={'user_id': user_id}, ) - session.flush() + await session.flush() # Call LiteLLM downgrade from storage.lite_llm_manager import LiteLlmManager @@ -531,7 +530,7 @@ class UserStore: # Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata # This ensures any conversations created after migration have their user_id # preserved in the original table before we delete the saas entries - session.execute( + await session.execute( text(""" UPDATE conversation_metadata SET user_id = :user_id @@ -545,14 +544,14 @@ class UserStore: ) # Step 4: Delete conversation_metadata_saas entries - session.execute( + await session.execute( text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'), {'user_id': user_uuid}, ) # Step 5: Reset org_id columns in related tables # Reset stripe_customers - session.execute( + await session.execute( text( 'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id' ), @@ -560,13 +559,13 @@ class UserStore: ) # Reset slack_users - session.execute( + await session.execute( text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'), {'org_id': user_uuid}, ) # Reset slack_conversation - session.execute( + await session.execute( text( 'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id' ), @@ -574,19 +573,19 @@ class UserStore: ) # Reset api_keys - session.execute( + await session.execute( text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'), {'org_id': user_uuid}, ) # Reset custom_secrets - session.execute( + await session.execute( text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'), {'org_id': user_uuid}, ) # Reset billing_sessions - session.execute( + await session.execute( text( 'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id' ), @@ -594,19 +593,19 @@ class UserStore: ) # Step 6: Delete org_member entries for this org - session.execute( + await session.execute( text('DELETE FROM org_member WHERE org_id = :org_id'), {'org_id': user_uuid}, ) # Step 7: Delete the user entry - session.execute( + await session.execute( text('DELETE FROM "user" WHERE id = :user_id'), {'user_id': user_uuid}, ) # Delete the org entry - session.execute( + await session.execute( text('DELETE FROM org WHERE id = :org_id'), {'org_id': user_uuid}, ) @@ -626,9 +625,9 @@ class UserStore: if value is not None and not _is_legacy_value_encrypted(value): setattr(user_settings, key, encrypt_legacy_value(value)) - session.merge(user_settings) + await session.merge(user_settings) - session.commit() + await session.commit() logger.info( 'user_store:downgrade_user:complete', @@ -637,88 +636,12 @@ class UserStore: return user_settings @staticmethod - def get_user_by_id(user_id: str) -> Optional[User]: - """Get user by Keycloak user ID (sync version). - - Note: This method uses call_async_from_sync internally which creates a new - event loop. If you're already in an async context, use get_user_by_id_async - instead to avoid event loop conflicts. - """ - with session_maker() as session: - user = ( - session.query(User) - .options(joinedload(User.org_members)) - .filter(User.id == uuid.UUID(user_id)) - .first() - ) - if user: - return user - - # Check if we need to migrate from user_settings - while not call_async_from_sync( - UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id - ): - # The user is already being created in another thread / process - logger.info( - 'user_store:create_default_settings:waiting_for_lock', - extra={'user_id': user_id}, - ) - call_async_from_sync( - asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS - ) - - try: - # Check for user again as migration could have happened while trying to get the lock. - user = ( - session.query(User) - .options(joinedload(User.org_members)) - .filter(User.id == uuid.UUID(user_id)) - .first() - ) - if user: - return user - - user_settings = ( - session.query(UserSettings) - .filter( - UserSettings.keycloak_user_id == user_id, - UserSettings.already_migrated.is_(False), - ) - .first() - ) - if user_settings: - token_manager = TokenManager() - user_info = call_async_from_sync( - token_manager.get_user_info_from_user_id, - GENERAL_TIMEOUT, - user_id, - ) - user = call_async_from_sync( - UserStore.migrate_user, - GENERAL_TIMEOUT, - user_id, - user_settings, - user_info, - ) - return user - else: - return None - finally: - call_async_from_sync( - UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id - ) - - @staticmethod - async def get_user_by_id_async(user_id: str) -> Optional[User]: - """Get user by Keycloak user ID (async version). - - This is the preferred method when calling from an async context as it - avoids event loop conflicts that can occur with the sync version. - """ + async def get_user_by_id(user_id: str) -> Optional[User]: + """Get user by Keycloak user ID.""" async with a_session_maker() as session: result = await session.execute( select(User) - .options(joinedload(User.org_members)) + .options(selectinload(User.org_members)) .filter(User.id == uuid.UUID(user_id)) ) user = result.scalars().first() @@ -729,7 +652,7 @@ class UserStore: while not await UserStore._acquire_user_creation_lock(user_id): # The user is already being created in another thread / process logger.info( - 'user_store:get_user_by_id_async:waiting_for_lock', + 'user_store:create_default_settings:waiting_for_lock', extra={'user_id': user_id}, ) await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS) @@ -738,17 +661,13 @@ class UserStore: # Check for user again as migration could have happened while trying to get the lock. result = await session.execute( select(User) - .options(joinedload(User.org_members)) + .options(selectinload(User.org_members)) .filter(User.id == uuid.UUID(user_id)) ) user = result.scalars().first() if user: return user - logger.info( - 'user_store:get_user_by_id_async:start_migration', - extra={'user_id': user_id}, - ) result = await session.execute( select(UserSettings).filter( UserSettings.keycloak_user_id == user_id, @@ -759,10 +678,6 @@ class UserStore: if user_settings: token_manager = TokenManager() user_info = await token_manager.get_user_info_from_user_id(user_id) - logger.info( - 'user_store:get_user_by_id_async:calling_migrate_user', - extra={'user_id': user_id}, - ) user = await UserStore.migrate_user( user_id, user_settings, @@ -775,8 +690,8 @@ class UserStore: await UserStore._release_user_creation_lock(user_id) @staticmethod - async def get_user_by_email_async(email: str) -> Optional[User]: - """Get user by email address (async version). + async def get_user_by_email(email: str) -> Optional[User]: + """Get user by email address. This method looks up a user by their email address. Note that email addresses may not be unique across all users in rare cases. @@ -793,19 +708,20 @@ class UserStore: async with a_session_maker() as session: result = await session.execute( select(User) - .options(joinedload(User.org_members)) + .options(selectinload(User.org_members)) .filter(User.email == email.lower().strip()) ) return result.scalars().first() @staticmethod - def list_users() -> list[User]: + async def list_users() -> list[User]: """List all users.""" - with session_maker() as session: - return session.query(User).all() + async with a_session_maker() as session: + result = await session.execute(select(User)) + return list(result.scalars().all()) @staticmethod - def update_current_org(user_id: str, org_id: UUID) -> Optional[User]: + async def update_current_org(user_id: str, org_id: UUID) -> Optional[User]: """Update the user's current organization. Args: @@ -815,19 +731,17 @@ class UserStore: Returns: User: The updated user object, or None if user not found """ - with session_maker() as session: - user = ( - session.query(User) - .filter(User.id == uuid.UUID(user_id)) - .with_for_update() - .first() + async with a_session_maker() as session: + result = await session.execute( + select(User).filter(User.id == uuid.UUID(user_id)).with_for_update() ) + user = result.scalars().first() if not user: return None user.current_org_id = org_id - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) return user @staticmethod diff --git a/enterprise/tests/unit/server/routes/test_api_keys.py b/enterprise/tests/unit/server/routes/test_api_keys.py index 734db4e692..57a9cb465d 100644 --- a/enterprise/tests/unit/server/routes/test_api_keys.py +++ b/enterprise/tests/unit/server/routes/test_api_keys.py @@ -374,7 +374,7 @@ class TestDeleteByorKeyFromLitellm: @pytest.mark.asyncio @patch('storage.lite_llm_manager.LiteLlmManager.delete_key') - @patch('storage.user_store.UserStore.get_user_by_id_async') + @patch('storage.user_store.UserStore.get_user_by_id') async def test_delete_constructs_alias_from_user( self, mock_get_user, mock_delete_key ): @@ -400,7 +400,7 @@ class TestDeleteByorKeyFromLitellm: @pytest.mark.asyncio @patch('storage.lite_llm_manager.LiteLlmManager.delete_key') - @patch('storage.user_store.UserStore.get_user_by_id_async') + @patch('storage.user_store.UserStore.get_user_by_id') async def test_delete_without_user_passes_no_alias( self, mock_get_user, mock_delete_key ): @@ -421,7 +421,7 @@ class TestDeleteByorKeyFromLitellm: @pytest.mark.asyncio @patch('storage.lite_llm_manager.LiteLlmManager.delete_key') - @patch('storage.user_store.UserStore.get_user_by_id_async') + @patch('storage.user_store.UserStore.get_user_by_id') async def test_delete_without_org_id_passes_no_alias( self, mock_get_user, mock_delete_key ): @@ -444,7 +444,7 @@ class TestDeleteByorKeyFromLitellm: @pytest.mark.asyncio @patch('storage.lite_llm_manager.LiteLlmManager.delete_key') - @patch('storage.user_store.UserStore.get_user_by_id_async') + @patch('storage.user_store.UserStore.get_user_by_id') async def test_delete_returns_false_on_exception( self, mock_get_user, mock_delete_key ): diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py index 7aec94c847..069896b2cf 100644 --- a/enterprise/tests/unit/server/routes/test_orgs.py +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -514,7 +514,7 @@ async def test_list_user_orgs_success(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -568,7 +568,7 @@ async def test_list_user_orgs_returns_current_org_id(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -613,7 +613,7 @@ async def test_list_user_orgs_with_pagination(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -648,7 +648,7 @@ async def test_list_user_orgs_empty(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -715,7 +715,7 @@ async def test_list_user_orgs_service_error(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -781,7 +781,7 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -820,7 +820,7 @@ async def test_list_user_orgs_team_org_identified(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -869,7 +869,7 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -941,7 +941,7 @@ async def test_list_user_orgs_all_fields_present(mock_app_list): with ( patch( - 'server.routes.orgs.UserStore.get_user_by_id_async', + 'server.routes.orgs.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( diff --git a/enterprise/tests/unit/server/services/test_org_member_service.py b/enterprise/tests/unit/server/services/test_org_member_service.py index 001b958c03..440ecde4ba 100644 --- a/enterprise/tests/unit/server/services/test_org_member_service.py +++ b/enterprise/tests/unit/server/services/test_org_member_service.py @@ -718,7 +718,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -767,7 +767,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -815,7 +815,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -967,7 +967,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( @@ -1149,7 +1149,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -1251,11 +1251,12 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( - 'server.services.org_member_service.UserStore.update_current_org' + 'server.services.org_member_service.UserStore.update_current_org', + new_callable=AsyncMock, ) as mock_update_org, ): mock_get_member.side_effect = [ @@ -1274,7 +1275,9 @@ class TestOrgMemberServiceRemoveOrgMember: # Assert assert success is True assert error is None - mock_update_org.assert_called_once_with(str(target_user_id), target_user_id) + mock_update_org.assert_awaited_once_with( + str(target_user_id), target_user_id + ) @pytest.mark.asyncio async def test_remove_member_does_not_update_current_org_id_when_not_matching( @@ -1307,11 +1310,12 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( - 'server.services.org_member_service.UserStore.update_current_org' + 'server.services.org_member_service.UserStore.update_current_org', + new_callable=AsyncMock, ) as mock_update_org, ): mock_get_member.side_effect = [ @@ -1330,7 +1334,7 @@ class TestOrgMemberServiceRemoveOrgMember: # Assert assert success is True assert error is None - mock_update_org.assert_not_called() + mock_update_org.assert_not_awaited() @pytest.mark.asyncio async def test_remove_member_succeeds_when_user_not_found_after_removal( @@ -1359,11 +1363,12 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( - 'server.services.org_member_service.UserStore.update_current_org' + 'server.services.org_member_service.UserStore.update_current_org', + new_callable=AsyncMock, ) as mock_update_org, ): mock_get_member.side_effect = [ @@ -1382,7 +1387,7 @@ class TestOrgMemberServiceRemoveOrgMember: # Assert assert success is True assert error is None - mock_update_org.assert_not_called() + mock_update_org.assert_not_awaited() @pytest.mark.asyncio async def test_successful_removal_calls_litellm_remove_user_from_team( @@ -1411,7 +1416,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( @@ -1465,7 +1470,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( @@ -1632,7 +1637,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -1693,7 +1698,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -1752,7 +1757,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -1815,7 +1820,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch.object( @@ -2036,7 +2041,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -2253,7 +2258,7 @@ class TestOrgMemberServiceGetMe: new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -2340,7 +2345,7 @@ class TestOrgMemberServiceGetMe: new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id_async', + 'server.services.org_member_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index fb163f978a..26f96d3f03 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -56,7 +56,7 @@ def test_generate_api_key(api_key_store): @pytest.mark.asyncio -@patch('storage.api_key_store.UserStore.get_user_by_id_async') +@patch('storage.api_key_store.UserStore.get_user_by_id') async def test_create_api_key( mock_get_user, api_key_store, async_session_maker, mock_user ): @@ -324,7 +324,7 @@ async def test_delete_api_key_by_id(api_key_store, async_session_maker): @pytest.mark.asyncio -@patch('storage.api_key_store.UserStore.get_user_by_id_async') +@patch('storage.api_key_store.UserStore.get_user_by_id') async def test_list_api_keys( mock_get_user, api_key_store, async_session_maker, mock_user ): @@ -377,7 +377,7 @@ async def test_list_api_keys( @pytest.mark.asyncio -@patch('storage.api_key_store.UserStore.get_user_by_id_async') +@patch('storage.api_key_store.UserStore.get_user_by_id') async def test_retrieve_mcp_api_key( mock_get_user, api_key_store, async_session_maker, mock_user ): @@ -416,7 +416,7 @@ async def test_retrieve_mcp_api_key( @pytest.mark.asyncio -@patch('storage.api_key_store.UserStore.get_user_by_id_async') +@patch('storage.api_key_store.UserStore.get_user_by_id') async def test_retrieve_mcp_api_key_not_found( mock_get_user, api_key_store, async_session_maker, mock_user ): diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index ce3d142ec4..0d1ed3760c 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -158,7 +158,7 @@ async def test_keycloak_callback_user_not_allowed( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = None - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() @@ -197,7 +197,7 @@ async def test_keycloak_callback_success_with_valid_offline_token( mock_user.accepted_tos = '2025-01-01' # Setup UserStore mocks - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() @@ -273,7 +273,7 @@ async def test_keycloak_callback_email_not_verified( mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -324,7 +324,7 @@ async def test_keycloak_callback_email_not_verified_missing_field( mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -368,7 +368,7 @@ async def test_keycloak_callback_success_without_offline_token( mock_user.accepted_tos = '2025-01-01' # Setup UserStore mocks - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() @@ -616,7 +616,7 @@ async def test_keycloak_callback_blocked_email_domain( mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -683,7 +683,7 @@ async def test_keycloak_callback_allowed_email_domain( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -750,7 +750,7 @@ async def test_keycloak_callback_domain_blocking_inactive( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -813,7 +813,7 @@ async def test_keycloak_callback_missing_email(mock_request, create_keycloak_use mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -862,7 +862,7 @@ async def test_keycloak_callback_duplicate_email_detected( mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -910,7 +910,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails( mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -971,7 +971,7 @@ async def test_keycloak_callback_duplicate_check_exception( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1032,7 +1032,7 @@ async def test_keycloak_callback_no_duplicate_email( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1096,7 +1096,7 @@ async def test_keycloak_callback_no_email_in_user_info( mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1254,7 +1254,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1322,7 +1322,7 @@ class TestKeycloakCallbackRecaptcha: mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1408,7 +1408,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1497,7 +1497,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1670,7 +1670,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1752,7 +1752,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1822,7 +1822,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1896,7 +1896,7 @@ class TestKeycloakCallbackRecaptcha: mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -1970,7 +1970,7 @@ class TestKeycloakCallbackRecaptcha: mock_user = MagicMock() mock_user.id = 'test_user_id' mock_user.current_org_id = 'test_org_id' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() @@ -2020,7 +2020,7 @@ async def test_keycloak_callback_calls_backfill_user_email_for_existing_user( mock_user.current_org_id = 'test_org_id' mock_user.accepted_tos = '2025-01-01' - mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() mock_user_store.backfill_user_email = AsyncMock() diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index b419faed36..995e277acd 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -103,7 +103,7 @@ async def test_get_credits_lite_llm_error(): with ( patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'), patch( - 'storage.user_store.UserStore.get_user_by_id_async', + 'storage.user_store.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=MagicMock(current_org_id='mock_org_id'), ), @@ -135,7 +135,7 @@ async def test_get_credits_success(): patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'), patch('httpx.AsyncClient', return_value=mock_client), patch( - 'storage.user_store.UserStore.get_user_by_id_async', + 'storage.user_store.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=MagicMock(current_org_id='mock_org_id'), ), @@ -338,7 +338,7 @@ async def test_success_callback_success(async_session_maker, test_org, test_user patch('server.routes.billing.a_session_maker', async_session_maker), patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, patch( - 'storage.user_store.UserStore.get_user_by_id_async', + 'storage.user_store.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=MagicMock(current_org_id=test_org.id), ), @@ -410,7 +410,7 @@ async def test_success_callback_lite_llm_error( patch('server.routes.billing.a_session_maker', async_session_maker), patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, patch( - 'storage.user_store.UserStore.get_user_by_id_async', + 'storage.user_store.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=MagicMock(current_org_id=test_org.id), ), @@ -464,7 +464,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback( patch('server.routes.billing.a_session_maker', async_session_maker), patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, patch( - 'storage.user_store.UserStore.get_user_by_id_async', + 'storage.user_store.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=MagicMock(current_org_id=test_org.id), ), diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index cac0b37e23..9b1f53a6b0 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -1100,9 +1100,7 @@ class TestLiteLlmManager: mock_org_member.org_id = 'test-ord-id' mock_org_member.llm_api_key = 'test-api-key' mock_user.org_members = [mock_org_member] - mock_user_store.get_user_by_id_async = AsyncMock( - return_value=mock_user - ) + mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user) result = await LiteLlmManager._get_key_info( mock_http_client, 'test-ord-id', 'test-user-id' @@ -1118,7 +1116,7 @@ class TestLiteLlmManager: with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'): with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'): with patch('storage.user_store.UserStore') as mock_user_store: - mock_user_store.get_user_by_id_async = AsyncMock(return_value=None) + mock_user_store.get_user_by_id = AsyncMock(return_value=None) result = await LiteLlmManager._get_key_info( mock_http_client, 'test-ord-id', 'test-user-id' diff --git a/enterprise/tests/unit/test_org_invitation_service.py b/enterprise/tests/unit/test_org_invitation_service.py index 5f797dedde..487243327e 100644 --- a/enterprise/tests/unit/test_org_invitation_service.py +++ b/enterprise/tests/unit/test_org_invitation_service.py @@ -70,7 +70,7 @@ class TestAcceptInvitationEmailValidation: 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' ) as mock_is_expired, patch( - 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + 'server.services.org_invitation_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, ): @@ -106,7 +106,7 @@ class TestAcceptInvitationEmailValidation: 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' ) as mock_is_expired, patch( - 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + 'server.services.org_invitation_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( @@ -174,7 +174,7 @@ class TestAcceptInvitationEmailValidation: 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' ) as mock_is_expired, patch( - 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + 'server.services.org_invitation_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( @@ -220,7 +220,7 @@ class TestAcceptInvitationEmailValidation: 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' ) as mock_is_expired, patch( - 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + 'server.services.org_invitation_service.UserStore.get_user_by_id', new_callable=AsyncMock, ) as mock_get_user, patch( diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index 94edcbff3f..0ddd225f5e 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -1870,7 +1870,7 @@ async def test_check_byor_export_enabled_returns_true_when_enabled(): with ( patch( - 'storage.org_service.UserStore.get_user_by_id_async', + 'storage.org_service.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -1905,7 +1905,7 @@ async def test_check_byor_export_enabled_returns_false_when_disabled(): with ( patch( - 'storage.org_service.UserStore.get_user_by_id_async', + 'storage.org_service.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -1932,7 +1932,7 @@ async def test_check_byor_export_enabled_returns_false_when_user_not_found(): user_id = 'nonexistent-user' with patch( - 'storage.org_service.UserStore.get_user_by_id_async', + 'storage.org_service.UserStore.get_user_by_id', AsyncMock(return_value=None), ): # Act @@ -1956,7 +1956,7 @@ async def test_check_byor_export_enabled_returns_false_when_no_current_org(): mock_user.current_org_id = None with patch( - 'storage.org_service.UserStore.get_user_by_id_async', + 'storage.org_service.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ): # Act @@ -1982,7 +1982,7 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found(): with ( patch( - 'storage.org_service.UserStore.get_user_by_id_async', + 'storage.org_service.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ), patch( @@ -2025,6 +2025,7 @@ async def test_switch_org_success(): patch('storage.org_service.OrgService.is_org_member', return_value=True), patch( 'storage.org_service.UserStore.update_current_org', + new_callable=AsyncMock, return_value=mock_updated_user, ), ): @@ -2116,7 +2117,11 @@ async def test_switch_org_user_not_found(): return_value=mock_org, ), patch('storage.org_service.OrgService.is_org_member', return_value=True), - patch('storage.org_service.UserStore.update_current_org', return_value=None), + patch( + 'storage.org_service.UserStore.update_current_org', + new_callable=AsyncMock, + return_value=None, + ), ): # Act & Assert with pytest.raises(OrgDatabaseError) as exc_info: diff --git a/enterprise/tests/unit/test_saas_conversation_store.py b/enterprise/tests/unit/test_saas_conversation_store.py index 4d59c1227f..6492be3f7f 100644 --- a/enterprise/tests/unit/test_saas_conversation_store.py +++ b/enterprise/tests/unit/test_saas_conversation_store.py @@ -169,14 +169,14 @@ async def test_exists(session_maker): class TestGetInstance: """Tests for SaasConversationStore.get_instance method. - The get_instance method uses async UserStore.get_user_by_id_async because + The get_instance method uses async UserStore.get_user_by_id because callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main event loop where asyncpg connections work properly. """ @pytest.mark.asyncio async def test_get_instance_uses_async_get_user_by_id(self): - """Verify get_instance calls the async get_user_by_id_async for proper event loop handling.""" + """Verify get_instance calls the async get_user_by_id for proper event loop handling.""" # Arrange user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' mock_user = MagicMock(spec=User) @@ -184,7 +184,7 @@ class TestGetInstance: mock_config = MagicMock(spec=OpenHandsConfig) with patch( - 'storage.saas_conversation_store.UserStore.get_user_by_id_async', + 'storage.saas_conversation_store.UserStore.get_user_by_id', AsyncMock(return_value=mock_user), ) as mock_async_get_user, patch( 'storage.saas_conversation_store.session_maker' @@ -205,7 +205,7 @@ class TestGetInstance: mock_config = MagicMock(spec=OpenHandsConfig) with patch( - 'storage.saas_conversation_store.UserStore.get_user_by_id_async', + 'storage.saas_conversation_store.UserStore.get_user_by_id', AsyncMock(return_value=None), ), patch('storage.saas_conversation_store.session_maker'): # Act diff --git a/enterprise/tests/unit/test_saas_secrets_store.py b/enterprise/tests/unit/test_saas_secrets_store.py index 5cd42cfb71..f9a560d11c 100644 --- a/enterprise/tests/unit/test_saas_secrets_store.py +++ b/enterprise/tests/unit/test_saas_secrets_store.py @@ -44,7 +44,7 @@ def secrets_store(async_session_maker, mock_config): class TestSaasSecretsStore: @pytest.mark.asyncio @patch( - 'storage.saas_secrets_store.UserStore.get_user_by_id_async', + 'storage.saas_secrets_store.UserStore.get_user_by_id', new_callable=AsyncMock, ) async def test_store_and_load(self, mock_get_user, secrets_store, mock_user): @@ -84,7 +84,7 @@ class TestSaasSecretsStore: @pytest.mark.asyncio @patch( - 'storage.saas_secrets_store.UserStore.get_user_by_id_async', + 'storage.saas_secrets_store.UserStore.get_user_by_id', new_callable=AsyncMock, ) async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user): @@ -186,7 +186,7 @@ class TestSaasSecretsStore: @pytest.mark.asyncio @patch( - 'storage.saas_secrets_store.UserStore.get_user_by_id_async', + 'storage.saas_secrets_store.UserStore.get_user_by_id', new_callable=AsyncMock, ) async def test_update_existing_secrets( diff --git a/enterprise/tests/unit/test_user_route_fallback.py b/enterprise/tests/unit/test_user_route_fallback.py index 23ae43dd7b..cb43301f80 100644 --- a/enterprise/tests/unit/test_user_route_fallback.py +++ b/enterprise/tests/unit/test_user_route_fallback.py @@ -35,9 +35,9 @@ def mock_check_idp(): @pytest.fixture def mock_user_store(): - """Mock UserStore.get_user_by_id_async to return None by default.""" + """Mock UserStore.get_user_by_id to return None by default.""" with patch( - 'server.routes.user.UserStore.get_user_by_id_async', + 'server.routes.user.UserStore.get_user_by_id', new_callable=AsyncMock, return_value=None, ) as mock_fn: diff --git a/enterprise/tests/unit/test_user_store.py b/enterprise/tests/unit/test_user_store.py index 32bfacb1e9..6a2ecb41ac 100644 --- a/enterprise/tests/unit/test_user_store.py +++ b/enterprise/tests/unit/test_user_store.py @@ -101,11 +101,11 @@ async def test_create_default_settings_with_litellm(mock_litellm_api): assert settings.llm_base_url == 'http://test.url' -# --- Tests for get_user_by_id_async --- +# --- Tests for get_user_by_id --- @pytest.mark.asyncio -async def test_get_user_by_id_async_existing_user(async_session_maker): +async def test_get_user_by_id_existing_user(async_session_maker): """Test retrieving an existing user by ID.""" user_id = uuid.uuid4() org_id = uuid.uuid4() @@ -120,7 +120,7 @@ async def test_get_user_by_id_async_existing_user(async_session_maker): # Test retrieval with patched session maker with patch('storage.user_store.a_session_maker', async_session_maker): - result = await UserStore.get_user_by_id_async(str(user_id)) + result = await UserStore.get_user_by_id(str(user_id)) assert result is not None assert result.id == user_id @@ -128,8 +128,8 @@ async def test_get_user_by_id_async_existing_user(async_session_maker): @pytest.mark.asyncio -async def test_get_user_by_id_async_user_not_found(async_session_maker): - """Test that get_user_by_id_async returns None for non-existent user.""" +async def test_get_user_by_id_user_not_found(async_session_maker): + """Test that get_user_by_id returns None for non-existent user.""" non_existent_id = str(uuid.uuid4()) with patch('storage.user_store.a_session_maker', async_session_maker): @@ -138,16 +138,16 @@ async def test_get_user_by_id_async_user_not_found(async_session_maker): patch.object(UserStore, '_acquire_user_creation_lock', return_value=True), patch.object(UserStore, '_release_user_creation_lock', return_value=True), ): - result = await UserStore.get_user_by_id_async(non_existent_id) + result = await UserStore.get_user_by_id(non_existent_id) assert result is None -# --- Tests for get_user_by_email_async --- +# --- Tests for get_user_by_email --- @pytest.mark.asyncio -async def test_get_user_by_email_async_existing_user(async_session_maker): +async def test_get_user_by_email_existing_user(async_session_maker): """Test retrieving a user by email.""" user_id = uuid.uuid4() org_id = uuid.uuid4() @@ -163,7 +163,7 @@ async def test_get_user_by_email_async_existing_user(async_session_maker): # Test retrieval with patch('storage.user_store.a_session_maker', async_session_maker): - result = await UserStore.get_user_by_email_async(email) + result = await UserStore.get_user_by_email(email) assert result is not None assert result.id == user_id @@ -171,28 +171,28 @@ async def test_get_user_by_email_async_existing_user(async_session_maker): @pytest.mark.asyncio -async def test_get_user_by_email_async_not_found(async_session_maker): - """Test that get_user_by_email_async returns None for non-existent email.""" +async def test_get_user_by_email_not_found(async_session_maker): + """Test that get_user_by_email returns None for non-existent email.""" with patch('storage.user_store.a_session_maker', async_session_maker): - result = await UserStore.get_user_by_email_async('nonexistent@example.com') + result = await UserStore.get_user_by_email('nonexistent@example.com') assert result is None @pytest.mark.asyncio -async def test_get_user_by_email_async_empty_email(async_session_maker): - """Test that get_user_by_email_async returns None for empty email.""" +async def test_get_user_by_email_empty_email(async_session_maker): + """Test that get_user_by_email returns None for empty email.""" with patch('storage.user_store.a_session_maker', async_session_maker): - result = await UserStore.get_user_by_email_async('') + result = await UserStore.get_user_by_email('') assert result is None @pytest.mark.asyncio -async def test_get_user_by_email_async_none_email(async_session_maker): - """Test that get_user_by_email_async returns None for None email.""" +async def test_get_user_by_email_none_email(async_session_maker): + """Test that get_user_by_email returns None for None email.""" with patch('storage.user_store.a_session_maker', async_session_maker): - result = await UserStore.get_user_by_email_async(None) + result = await UserStore.get_user_by_email(None) assert result is None @@ -543,47 +543,50 @@ async def test_backfill_contact_name_no_real_name(async_session_maker): assert org.contact_name == 'jdoe' -# --- Tests for update_current_org (sync) --- +# --- Tests for update_current_org --- -def test_update_current_org_success(session_maker): +@pytest.mark.asyncio +async def test_update_current_org_success(async_session_maker): """Test updating a user's current organization.""" user_id = uuid.uuid4() initial_org_id = uuid.uuid4() new_org_id = uuid.uuid4() # Create test data - with session_maker() as session: + async with async_session_maker() as session: org1 = Org(id=initial_org_id, name='org1') org2 = Org(id=new_org_id, name='org2') session.add_all([org1, org2]) user = User(id=user_id, current_org_id=initial_org_id) session.add(user) - session.commit() + await session.commit() # Update current org - with patch('storage.user_store.session_maker', session_maker): - result = UserStore.update_current_org(str(user_id), new_org_id) + with patch('storage.user_store.a_session_maker', async_session_maker): + result = await UserStore.update_current_org(str(user_id), new_org_id) assert result is not None assert result.current_org_id == new_org_id -def test_update_current_org_user_not_found(session_maker): +@pytest.mark.asyncio +async def test_update_current_org_user_not_found(async_session_maker): """Test that update_current_org returns None for non-existent user.""" user_id = str(uuid.uuid4()) org_id = uuid.uuid4() - with patch('storage.user_store.session_maker', session_maker): - result = UserStore.update_current_org(user_id, org_id) + with patch('storage.user_store.a_session_maker', async_session_maker): + result = await UserStore.update_current_org(user_id, org_id) assert result is None -# --- Tests for list_users (sync) --- +# --- Tests for list_users --- -def test_list_users(session_maker): +@pytest.mark.asyncio +async def test_list_users(async_session_maker): """Test listing all users.""" user_id1 = uuid.uuid4() user_id2 = uuid.uuid4() @@ -591,18 +594,18 @@ def test_list_users(session_maker): org_id2 = uuid.uuid4() # Create test data - with session_maker() as session: + async with async_session_maker() as session: org1 = Org(id=org_id1, name='org1') org2 = Org(id=org_id2, name='org2') session.add_all([org1, org2]) user1 = User(id=user_id1, current_org_id=org_id1) user2 = User(id=user_id2, current_org_id=org_id2) session.add_all([user1, user2]) - session.commit() + await session.commit() # List users - with patch('storage.user_store.session_maker', session_maker): - users = UserStore.list_users() + with patch('storage.user_store.a_session_maker', async_session_maker): + users = await UserStore.list_users() assert len(users) >= 2 user_ids = [user.id for user in users]