diff --git a/enterprise/server/auth/authorization.py b/enterprise/server/auth/authorization.py index 522ef47631..4595ea07a3 100644 --- a/enterprise/server/auth/authorization.py +++ b/enterprise/server/auth/authorization.py @@ -157,9 +157,9 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = { } -def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None: +async def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None: """ - Get the user's role in an organization (synchronous version). + Get the user's role in an organization. Args: user_id: User ID (string that will be converted to UUID) @@ -171,36 +171,11 @@ def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None: from uuid import UUID as parse_uuid if org_id is None: - org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id)) - else: - org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) - if not org_member: - return None - - return RoleStore.get_role_by_id(org_member.role_id) - - -async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None: - """ - Get the user's role in an organization (async version). - - Args: - user_id: User ID (string that will be converted to UUID) - org_id: Organization ID, or None to use the user's current organization - - Returns: - Role object if user is a member, None otherwise - """ - from uuid import UUID as parse_uuid - - if org_id is None: - org_member = await OrgMemberStore.get_org_member_for_current_org_async( + org_member = await OrgMemberStore.get_org_member_for_current_org( parse_uuid(user_id) ) else: - org_member = await OrgMemberStore.get_org_member_async( - org_id, parse_uuid(user_id) - ) + org_member = await OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) if not org_member: return None @@ -274,7 +249,7 @@ def require_permission(permission: Permission): detail='User not authenticated', ) - user_role = await get_user_org_role_async(user_id, org_id) + user_role = await get_user_org_role(user_id, org_id) if not user_role: logger.warning( diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index 57394850ac..5b433aef98 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -49,7 +49,7 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None: if not current_org_member: return None current_org_member.llm_api_key_for_byor = key - OrgMemberStore.update_org_member(current_org_member) + await OrgMemberStore.update_org_member(current_org_member) async def generate_byor_key(user_id: str) -> str | None: diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index b9c39a0925..8c1abb18ed 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -497,7 +497,7 @@ async def get_me( try: user_uuid = UUID(user_id) - return OrgMemberService.get_me(org_id, user_uuid) + return await OrgMemberService.get_me(org_id, user_uuid) except OrgMemberNotFoundError: raise HTTPException( diff --git a/enterprise/server/services/org_invitation_service.py b/enterprise/server/services/org_invitation_service.py index 2c1ba62f9b..5518ab5dd0 100644 --- a/enterprise/server/services/org_invitation_service.py +++ b/enterprise/server/services/org_invitation_service.py @@ -85,13 +85,13 @@ class OrgInvitationService: ) # Step 3: Check inviter is a member and has permission - inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id) + inviter_member = await OrgMemberStore.get_org_member(org_id, inviter_id) if not inviter_member: raise InsufficientPermissionError( 'You are not a member of this organization' ) - inviter_role = RoleStore.get_role_by_id(inviter_member.role_id) + inviter_role = await RoleStore.get_role_by_id_async(inviter_member.role_id) if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: raise InsufficientPermissionError('Only owners and admins can invite users') @@ -101,14 +101,16 @@ class OrgInvitationService: raise InsufficientPermissionError('Only owners can invite with owner role') # Get the target role - target_role = RoleStore.get_role_by_name(role_name_lower) + target_role = await RoleStore.get_role_by_name_async(role_name_lower) if not target_role: raise ValueError(f'Invalid role: {role_name}') # Step 5: Check if user is already a member (by email) existing_user = await UserStore.get_user_by_email_async(email) if existing_user: - existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id) + existing_member = await OrgMemberStore.get_org_member( + org_id, existing_user.id + ) if existing_member: raise UserAlreadyMemberError( 'User is already a member of this organization' @@ -196,13 +198,13 @@ class OrgInvitationService: 'Cannot invite users to a personal workspace' ) - inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id) + inviter_member = await OrgMemberStore.get_org_member(org_id, inviter_id) if not inviter_member: raise InsufficientPermissionError( 'You are not a member of this organization' ) - inviter_role = RoleStore.get_role_by_id(inviter_member.role_id) + inviter_role = await RoleStore.get_role_by_id_async(inviter_member.role_id) if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: raise InsufficientPermissionError('Only owners and admins can invite users') @@ -210,7 +212,7 @@ class OrgInvitationService: if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER: raise InsufficientPermissionError('Only owners can invite with owner role') - target_role = RoleStore.get_role_by_name(role_name_lower) + target_role = await RoleStore.get_role_by_name_async(role_name_lower) if not target_role: raise ValueError(f'Invalid role: {role_name}') @@ -336,7 +338,9 @@ class OrgInvitationService: raise EmailMismatchError() # Step 3: Check if user is already a member - existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id) + existing_member = await OrgMemberStore.get_org_member( + invitation.org_id, user_id + ) if existing_member: raise UserAlreadyMemberError( 'You are already a member of this organization' @@ -369,7 +373,7 @@ class OrgInvitationService: org_member_kwargs.pop('llm_model', None) org_member_kwargs.pop('llm_base_url', None) - OrgMemberStore.add_user_to_org( + await OrgMemberStore.add_user_to_org( org_id=invitation.org_id, user_id=user_id, role_id=invitation.role_id, diff --git a/enterprise/server/services/org_member_service.py b/enterprise/server/services/org_member_service.py index 5777ab0d5a..264d8fa135 100644 --- a/enterprise/server/services/org_member_service.py +++ b/enterprise/server/services/org_member_service.py @@ -22,14 +22,13 @@ from storage.role_store import RoleStore from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger -from openhands.utils.async_utils import call_sync_from_async class OrgMemberService: """Service for organization member operations.""" @staticmethod - def get_me(org_id: UUID, user_id: UUID) -> MeResponse: + async def get_me(org_id: UUID, user_id: UUID) -> MeResponse: """Get the current user's membership record for an organization. Retrieves the authenticated user's role, status, email, and LLM override @@ -47,17 +46,17 @@ class OrgMemberService: RoleNotFoundError: If the role associated with the member is not found """ # Look up the user's membership in this org - org_member = OrgMemberStore.get_org_member(org_id, user_id) + org_member = await OrgMemberStore.get_org_member(org_id, user_id) if org_member is None: raise OrgMemberNotFoundError(str(org_id), str(user_id)) # Resolve role name from role_id - role = RoleStore.get_role_by_id(org_member.role_id) + role = await RoleStore.get_role_by_id_async(org_member.role_id) if role is None: raise RoleNotFoundError(org_member.role_id) # Get user email - user = UserStore.get_user_by_id(str(user_id)) + user = await UserStore.get_user_by_id_async(str(user_id)) email = user.email if user and user.email else '' return MeResponse.from_org_member(org_member, role, email) @@ -83,7 +82,9 @@ class OrgMemberService: Tuple of (success, error_code, data). If success is True, error_code is None. """ # Verify current user is a member of the organization - requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id) + requester_membership = await OrgMemberStore.get_org_member( + org_id, current_user_id + ) if not requester_membership: return False, 'not_a_member', None @@ -156,7 +157,9 @@ class OrgMemberService: OrgMemberNotFoundError: If requesting user is not a member of the organization. """ # Verify current user is a member of the organization - requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id) + requester_membership = await OrgMemberStore.get_org_member( + org_id, current_user_id + ) if not requester_membership: raise OrgMemberNotFoundError(str(org_id), str(current_user_id)) @@ -176,82 +179,75 @@ class OrgMemberService: Returns: Tuple of (success, error_message). If success is True, error_message is None. """ + # Get current user's membership in the org + requester_membership = await OrgMemberStore.get_org_member( + org_id, current_user_id + ) + if not requester_membership: + return False, 'not_a_member' - def _remove_member(): - # Get current user's membership in the org - requester_membership = OrgMemberStore.get_org_member( - org_id, current_user_id - ) - if not requester_membership: - return False, 'not_a_member' + # Check if trying to remove self + if str(current_user_id) == str(target_user_id): + return False, 'cannot_remove_self' - # Check if trying to remove self - if str(current_user_id) == str(target_user_id): - return False, 'cannot_remove_self' + # Get target user's membership + target_membership = await OrgMemberStore.get_org_member(org_id, target_user_id) + if not target_membership: + return False, 'member_not_found' - # Get target user's membership - target_membership = OrgMemberStore.get_org_member(org_id, target_user_id) - if not target_membership: - return False, 'member_not_found' + requester_role = await RoleStore.get_role_by_id_async( + requester_membership.role_id + ) + target_role = await RoleStore.get_role_by_id_async(target_membership.role_id) - requester_role = RoleStore.get_role_by_id(requester_membership.role_id) - target_role = RoleStore.get_role_by_id(target_membership.role_id) + if not requester_role or not target_role: + return False, 'role_not_found' - if not requester_role or not target_role: - return False, 'role_not_found' + # Check permission based on roles + if not OrgMemberService._can_remove_member( + requester_role.name, target_role.name + ): + return False, 'insufficient_permission' - # Check permission based on roles - if not OrgMemberService._can_remove_member( - requester_role.name, target_role.name - ): - return False, 'insufficient_permission' + # Check if removing the last owner + if target_role.name == ROLE_OWNER: + if await OrgMemberService._is_last_owner(org_id, target_user_id): + return False, 'cannot_remove_last_owner' - # Check if removing the last owner - if target_role.name == ROLE_OWNER: - if OrgMemberService._is_last_owner(org_id, target_user_id): - return False, 'cannot_remove_last_owner' + # Perform the removal + success = await OrgMemberStore.remove_user_from_org(org_id, target_user_id) + if not success: + return False, 'removal_failed' - # Perform the removal - success = OrgMemberStore.remove_user_from_org(org_id, target_user_id) - if not success: - return False, 'removal_failed' - - # Update user's current_org_id if it points to the org they were removed from - user = UserStore.get_user_by_id(str(target_user_id)) - if user and user.current_org_id == org_id: - # Set current_org_id to personal workspace (org.id == user.id) - UserStore.update_current_org(str(target_user_id), target_user_id) - - return True, None - - success, error = await call_sync_from_async(_remove_member) + # Update user's current_org_id if it points to the org they were removed from + user = await UserStore.get_user_by_id_async(str(target_user_id)) + if user and user.current_org_id == org_id: + # Set current_org_id to personal workspace (org.id == user.id) + UserStore.update_current_org(str(target_user_id), target_user_id) # If database removal succeeded, also remove from LiteLLM team - if success: - try: - await LiteLlmManager.remove_user_from_team( - str(target_user_id), str(org_id) - ) - logger.info( - 'Successfully removed user from LiteLLM team', - extra={ - 'user_id': str(target_user_id), - 'org_id': str(org_id), - }, - ) - except Exception as e: - # Log but don't fail the operation - database removal already succeeded - # LiteLLM state will be eventually consistent - logger.warning( - 'Failed to remove user from LiteLLM team', - extra={ - 'user_id': str(target_user_id), - 'org_id': str(org_id), - 'error': str(e), - }, - ) + try: + await LiteLlmManager.remove_user_from_team(str(target_user_id), str(org_id)) + logger.info( + 'Successfully removed user from LiteLLM team', + extra={ + 'user_id': str(target_user_id), + 'org_id': str(org_id), + }, + ) + except Exception as e: + # Log but don't fail the operation - database removal already succeeded + # LiteLLM state will be eventually consistent + logger.warning( + 'Failed to remove user from LiteLLM team', + extra={ + 'user_id': str(target_user_id), + 'org_id': str(org_id), + 'error': str(e), + }, + ) - return success, error + return True, None @staticmethod async def update_org_member( @@ -287,85 +283,84 @@ class OrgMemberService: """ new_role_name = update_data.role - def _update_member(): - # Get current user's membership in the org - requester_membership = OrgMemberStore.get_org_member( - org_id, current_user_id - ) - if not requester_membership: - raise OrgMemberNotFoundError(str(org_id), str(current_user_id)) + # Get current user's membership in the org + requester_membership = await OrgMemberStore.get_org_member( + org_id, current_user_id + ) + if not requester_membership: + raise OrgMemberNotFoundError(str(org_id), str(current_user_id)) - # Check if trying to modify self - if str(current_user_id) == str(target_user_id): - raise CannotModifySelfError('modify') + # Check if trying to modify self + if str(current_user_id) == str(target_user_id): + raise CannotModifySelfError('modify') - # Get target user's membership - target_membership = OrgMemberStore.get_org_member(org_id, target_user_id) - if not target_membership: - raise OrgMemberNotFoundError(str(org_id), str(target_user_id)) + # Get target user's membership + target_membership = await OrgMemberStore.get_org_member(org_id, target_user_id) + if not target_membership: + raise OrgMemberNotFoundError(str(org_id), str(target_user_id)) - # Get roles - requester_role = RoleStore.get_role_by_id(requester_membership.role_id) - target_role = RoleStore.get_role_by_id(target_membership.role_id) + # Get roles + requester_role = await RoleStore.get_role_by_id_async( + requester_membership.role_id + ) + target_role = await RoleStore.get_role_by_id_async(target_membership.role_id) - if not requester_role: - raise RoleNotFoundError(requester_membership.role_id) - if not target_role: - raise RoleNotFoundError(target_membership.role_id) - - # If no role change requested, return current state - if new_role_name is None: - user = UserStore.get_user_by_id(str(target_user_id)) - return OrgMemberResponse( - user_id=str(target_membership.user_id), - email=user.email if user else None, - role_id=target_membership.role_id, - role=target_role.name, - role_rank=target_role.rank, - status=target_membership.status, - ) - - # Validate new role exists - new_role = RoleStore.get_role_by_name(new_role_name.lower()) - if not new_role: - raise InvalidRoleError(new_role_name) - - # Check permission to modify target - if not OrgMemberService._can_update_member_role( - requester_role.name, target_role.name, new_role.name - ): - raise InsufficientPermissionError( - 'You do not have permission to modify this member' - ) - - # Check if demoting the last owner - if ( - target_role.name == ROLE_OWNER - and new_role.name != ROLE_OWNER - and OrgMemberService._is_last_owner(org_id, target_user_id) - ): - raise LastOwnerError('demote') - - # Perform the update - updated_member = OrgMemberStore.update_user_role_in_org( - org_id, target_user_id, new_role.id - ) - if not updated_member: - raise MemberUpdateError('Failed to update member') - - # Get user email for response - user = UserStore.get_user_by_id(str(target_user_id)) + if not requester_role: + raise RoleNotFoundError(requester_membership.role_id) + if not target_role: + raise RoleNotFoundError(target_membership.role_id) + # If no role change requested, return current state + if new_role_name is None: + user = await UserStore.get_user_by_id_async(str(target_user_id)) return OrgMemberResponse( - user_id=str(updated_member.user_id), + user_id=str(target_membership.user_id), email=user.email if user else None, - role_id=updated_member.role_id, - role=new_role.name, - role_rank=new_role.rank, - status=updated_member.status, + role_id=target_membership.role_id, + role=target_role.name, + role_rank=target_role.rank, + status=target_membership.status, ) - return await call_sync_from_async(_update_member) + # Validate new role exists + new_role = await RoleStore.get_role_by_name_async(new_role_name.lower()) + if not new_role: + raise InvalidRoleError(new_role_name) + + # Check permission to modify target + if not OrgMemberService._can_update_member_role( + requester_role.name, target_role.name, new_role.name + ): + raise InsufficientPermissionError( + 'You do not have permission to modify this member' + ) + + # Check if demoting the last owner + if ( + target_role.name == ROLE_OWNER + and new_role.name != ROLE_OWNER + and await OrgMemberService._is_last_owner(org_id, target_user_id) + ): + raise LastOwnerError('demote') + + # Perform the update + updated_member = await OrgMemberStore.update_user_role_in_org( + org_id, target_user_id, new_role.id + ) + if not updated_member: + raise MemberUpdateError('Failed to update member') + + # Get user email for response + user = await UserStore.get_user_by_id_async(str(target_user_id)) + + return OrgMemberResponse( + user_id=str(updated_member.user_id), + email=user.email if user else None, + role_id=updated_member.role_id, + role=new_role.name, + role_rank=new_role.rank, + status=updated_member.status, + ) @staticmethod def _can_update_member_role( @@ -405,13 +400,13 @@ class OrgMemberService: return False @staticmethod - def _is_last_owner(org_id: UUID, user_id: UUID) -> bool: + async def _is_last_owner(org_id: UUID, user_id: UUID) -> bool: """Check if user is the last owner of the organization.""" - members = OrgMemberStore.get_org_members(org_id) + members = await OrgMemberStore.get_org_members(org_id) owners = [] for m in members: # Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError - role = RoleStore.get_role_by_id(m.role_id) + role = await RoleStore.get_role_by_id_async(m.role_id) if role and role.name == ROLE_OWNER: owners.append(m) return len(owners) == 1 and str(owners[0].user_id) == str(user_id) diff --git a/enterprise/storage/org_member_store.py b/enterprise/storage/org_member_store.py index c92d7ba867..1f32c2e8b5 100644 --- a/enterprise/storage/org_member_store.py +++ b/enterprise/storage/org_member_store.py @@ -9,7 +9,7 @@ from server.routes.org_models import OrgMemberLLMSettings from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from storage.database import a_session_maker, session_maker +from storage.database import a_session_maker from storage.encrypt_utils import encrypt_value from storage.org_member import OrgMember from storage.user import User @@ -22,7 +22,7 @@ class OrgMemberStore: """Store for managing organization-member relationships.""" @staticmethod - def add_user_to_org( + async def add_user_to_org( org_id: UUID, user_id: UUID, role_id: int, @@ -30,7 +30,7 @@ class OrgMemberStore: status: Optional[str] = None, ) -> OrgMember: """Add a user to an organization with a specific role.""" - with session_maker() as session: + async with a_session_maker() as session: org_member = OrgMember( org_id=org_id, user_id=user_id, @@ -39,22 +39,12 @@ class OrgMemberStore: status=status, ) session.add(org_member) - session.commit() - session.refresh(org_member) + await session.commit() + await session.refresh(org_member) return org_member @staticmethod - def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]: - """Get organization-user relationship.""" - with session_maker() as session: - return ( - session.query(OrgMember) - .filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id) - .first() - ) - - @staticmethod - async def get_org_member_async(org_id: UUID, user_id: UUID) -> Optional[OrgMember]: + async def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]: """Get organization-user relationship.""" async with a_session_maker() as session: result = await session.execute( @@ -65,33 +55,9 @@ class OrgMemberStore: return result.scalars().first() @staticmethod - def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]: + async def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]: """Get the org member for a user's current organization. - Args: - user_id: The user's UUID. - - Returns: - The OrgMember for the user's current organization, or None if not found. - """ - with session_maker() as session: - result = ( - session.query(OrgMember) - .join(User, User.id == OrgMember.user_id) - .filter( - User.id == user_id, - OrgMember.org_id == User.current_org_id, - ) - .first() - ) - return result - - @staticmethod - async def get_org_member_for_current_org_async( - user_id: UUID, - ) -> Optional[OrgMember]: - """Get the org member for a user's current organization (async version). - Args: user_id: The user's UUID. @@ -110,35 +76,42 @@ class OrgMemberStore: return result.scalars().first() @staticmethod - def get_user_orgs(user_id: UUID) -> list[OrgMember]: + async def get_user_orgs(user_id: UUID) -> list[OrgMember]: """Get all organizations for a user.""" - with session_maker() as session: - return session.query(OrgMember).filter(OrgMember.user_id == user_id).all() + async with a_session_maker() as session: + result = await session.execute( + select(OrgMember).filter(OrgMember.user_id == user_id) + ) + return list(result.scalars().all()) @staticmethod - def get_org_members(org_id: UUID) -> list[OrgMember]: + async def get_org_members(org_id: UUID) -> list[OrgMember]: """Get all users in an organization.""" - with session_maker() as session: - return session.query(OrgMember).filter(OrgMember.org_id == org_id).all() + async with a_session_maker() as session: + result = await session.execute( + select(OrgMember).filter(OrgMember.org_id == org_id) + ) + return list(result.scalars().all()) @staticmethod - def update_org_member(org_member: OrgMember) -> None: + async def update_org_member(org_member: OrgMember) -> None: """Update an organization-member relationship.""" - with session_maker() as session: - session.merge(org_member) - session.commit() + async with a_session_maker() as session: + await session.merge(org_member) + await session.commit() @staticmethod - def update_user_role_in_org( + async def update_user_role_in_org( org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None ) -> Optional[OrgMember]: """Update user's role in an organization.""" - with session_maker() as session: - org_member = ( - session.query(OrgMember) - .filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(OrgMember).filter( + OrgMember.org_id == org_id, OrgMember.user_id == user_id + ) ) + org_member = result.scalars().first() if not org_member: return None @@ -147,25 +120,26 @@ class OrgMemberStore: if status is not None: org_member.status = status - session.commit() - session.refresh(org_member) + await session.commit() + await session.refresh(org_member) return org_member @staticmethod - def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool: + async def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool: """Remove a user from an organization.""" - with session_maker() as session: - org_member = ( - session.query(OrgMember) - .filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(OrgMember).filter( + OrgMember.org_id == org_id, OrgMember.user_id == user_id + ) ) + org_member = result.scalars().first() if not org_member: return False - session.delete(org_member) - session.commit() + await session.delete(org_member) + await session.commit() return True @staticmethod diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 3d328b3ff6..780fca890e 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -398,7 +398,7 @@ class OrgService: return e @staticmethod - def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool: + async def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool: """ Check if user has admin or owner role in the specified organization. @@ -415,12 +415,12 @@ class OrgService: # Get the user's membership in this organization # Note: The type annotation says int but the actual column is UUID - org_member = OrgMemberStore.get_org_member(org_id, user_uuid) + org_member = await OrgMemberStore.get_org_member(org_id, user_uuid) if not org_member: return False # Get the role details - role = RoleStore.get_role_by_id(org_member.role_id) + role = await RoleStore.get_role_by_id_async(org_member.role_id) if not role: return False @@ -440,7 +440,7 @@ class OrgService: return False @staticmethod - def is_org_member(user_id: str, org_id: UUID) -> bool: + async def is_org_member(user_id: str, org_id: UUID) -> bool: """ Check if user is a member of the specified organization. @@ -453,7 +453,7 @@ class OrgService: """ try: user_uuid = parse_uuid(user_id) - org_member = OrgMemberStore.get_org_member(org_id, user_uuid) + org_member = await OrgMemberStore.get_org_member(org_id, user_uuid) return org_member is not None except Exception as e: logger.warning( @@ -540,7 +540,7 @@ class OrgService: raise ValueError(f'Organization with ID {org_id} not found') # Check if user is a member of this organization - if not OrgService.is_org_member(user_id, org_id): + if not await OrgService.is_org_member(user_id, org_id): logger.warning( 'Non-member attempted to update organization', extra={ @@ -574,7 +574,7 @@ class OrgService: llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data) if llm_fields_being_updated: # Verify user has admin or owner role - has_permission = OrgService.has_admin_or_owner_role(user_id, org_id) + has_permission = await OrgService.has_admin_or_owner_role(user_id, org_id) if not has_permission: logger.warning( 'User attempted to update LLM settings without permission', @@ -745,7 +745,7 @@ class OrgService: ) # Verify user is a member of the organization - org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) + org_member = await OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) if not org_member: logger.warning( 'User is not a member of organization or organization does not exist', @@ -792,12 +792,12 @@ class OrgService: raise OrgNotFoundError(str(org_id)) # Check if user is a member of the organization - org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) + org_member = await OrgMemberStore.get_org_member(org_id, parse_uuid(user_id)) if not org_member: raise OrgAuthorizationError('User is not a member of this organization') # Check if user has owner role - role = RoleStore.get_role_by_id(org_member.role_id) + role = await RoleStore.get_role_by_id_async(org_member.role_id) if not role or role.name != 'owner': raise OrgAuthorizationError( 'Only organization owners can delete organizations' @@ -918,7 +918,7 @@ class OrgService: raise OrgNotFoundError(str(org_id)) # Step 2: Validate user is a member of the organization - if not OrgService.is_org_member(user_id, org_id): + if not await OrgService.is_org_member(user_id, org_id): logger.warning( 'User attempted to switch to organization they are not a member of', extra={'user_id': user_id, 'org_id': str(org_id)}, diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py index 8462249dde..7aec94c847 100644 --- a/enterprise/tests/unit/server/routes/test_orgs.py +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -1027,7 +1027,7 @@ async def test_get_org_success(mock_app_with_get_user_id, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1067,7 +1067,7 @@ async def test_get_org_user_not_member(mock_app_with_get_user_id): # When user is not a member, get_user_org_role returns None with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): client = TestClient(mock_app_with_get_user_id) @@ -1092,7 +1092,7 @@ async def test_get_org_not_found(mock_app_with_get_user_id, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1167,7 +1167,7 @@ async def test_get_org_unexpected_error(mock_app_with_get_user_id, mock_owner_ro with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1218,7 +1218,7 @@ async def test_get_org_personal_workspace(): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ), patch( @@ -1260,7 +1260,7 @@ async def test_get_org_team_workspace(mock_app_with_get_user_id, mock_owner_role with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1305,7 +1305,7 @@ async def test_get_org_with_credits_none(mock_app_with_get_user_id, mock_owner_r with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1354,7 +1354,7 @@ async def test_get_org_sensitive_fields_not_exposed( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1404,7 +1404,7 @@ async def test_delete_org_success(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1439,7 +1439,7 @@ async def test_delete_org_not_found(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1469,7 +1469,7 @@ async def test_delete_org_not_owner(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1503,7 +1503,7 @@ async def test_delete_org_not_member(mock_app): # When user is not a member, get_user_org_role returns None with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): client = TestClient(mock_app) @@ -1528,7 +1528,7 @@ async def test_delete_org_database_failure(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1558,7 +1558,7 @@ async def test_delete_org_unexpected_error(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1606,7 +1606,7 @@ async def test_delete_org_unauthorized(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1636,7 +1636,7 @@ async def test_delete_org_orphaned_users(mock_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1708,7 +1708,7 @@ async def test_update_org_personal_workspace_preserved(): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ), patch( @@ -1769,7 +1769,7 @@ async def test_update_org_team_workspace_preserved(): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ), patch( @@ -1809,7 +1809,7 @@ async def test_update_org_not_found(mock_update_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1845,7 +1845,7 @@ async def test_update_org_permission_denied_non_member(mock_update_app): # When user is not a member, get_user_org_role returns None with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): async with httpx.AsyncClient( @@ -1876,7 +1876,7 @@ async def test_update_org_permission_denied_llm_settings( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1917,7 +1917,7 @@ async def test_update_org_duplicate_name_returns_409(mock_update_app, mock_owner with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1951,7 +1951,7 @@ async def test_update_org_database_error(mock_update_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -1985,7 +1985,7 @@ async def test_update_org_unexpected_error(mock_update_app, mock_owner_role): with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ), patch( @@ -2041,7 +2041,7 @@ async def test_update_org_invalid_field_values(mock_update_app, mock_owner_role) update_data = {'default_max_iterations': -1} # Invalid: must be > 0 with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ): async with httpx.AsyncClient( @@ -2068,7 +2068,7 @@ async def test_update_org_empty_name_returns_422(mock_update_app, mock_owner_rol update_data = {'name': ' '} with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ): async with httpx.AsyncClient( @@ -2095,7 +2095,7 @@ async def test_update_org_invalid_email_format(mock_update_app, mock_owner_role) update_data = {'contact_email': 'invalid-email'} # Missing @ with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_owner_role), ): async with httpx.AsyncClient( @@ -3031,6 +3031,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): client = TestClient(mock_me_app) @@ -3066,6 +3067,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): client = TestClient(mock_me_app) @@ -3086,6 +3088,7 @@ class TestGetMeEndpoint: """ with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, side_effect=OrgMemberNotFoundError(str(test_org_id), 'user-id'), ): client = TestClient(mock_me_app) @@ -3131,6 +3134,7 @@ class TestGetMeEndpoint: """ with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, side_effect=RuntimeError('Database connection failed'), ): client = TestClient(mock_me_app) @@ -3157,6 +3161,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): client = TestClient(mock_me_app) @@ -3185,6 +3190,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): client = TestClient(mock_me_app) @@ -3210,6 +3216,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): client = TestClient(mock_me_app) @@ -3230,6 +3237,7 @@ class TestGetMeEndpoint: """ with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, side_effect=RoleNotFoundError(role_id=999), ): client = TestClient(mock_me_app) @@ -3250,6 +3258,7 @@ class TestGetMeEndpoint: with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, return_value=me_response, ): result = await get_me(org_id=test_org_id, user_id=test_user_id) @@ -3266,6 +3275,7 @@ class TestGetMeEndpoint: """Test direct function call to get_me raises HTTPException on member not found.""" with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, side_effect=OrgMemberNotFoundError(str(test_org_id), test_user_id), ): with pytest.raises(HTTPException) as exc_info: @@ -3281,6 +3291,7 @@ class TestGetMeEndpoint: """Test direct function call to get_me raises HTTPException on role not found.""" with patch( 'server.routes.orgs.OrgMemberService.get_me', + new_callable=AsyncMock, side_effect=RoleNotFoundError(role_id=999), ): with pytest.raises(HTTPException) as exc_info: @@ -3453,7 +3464,7 @@ async def test_get_org_app_settings_success( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3493,7 +3504,7 @@ async def test_get_org_app_settings_with_null_values( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3527,7 +3538,7 @@ async def test_get_org_app_settings_not_found( # Arrange with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3554,7 +3565,7 @@ async def test_get_org_app_settings_user_not_member(mock_app_with_get_user_id): """ # Arrange - user has no role (not a member) with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): client = TestClient(mock_app_with_get_user_id) @@ -3585,7 +3596,7 @@ async def test_update_org_app_settings_success( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3632,7 +3643,7 @@ async def test_update_org_app_settings_partial_update( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3675,7 +3686,7 @@ async def test_update_org_app_settings_set_null( with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3708,7 +3719,7 @@ async def test_update_org_app_settings_invalid_max_budget( """ # Arrange with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ): client = TestClient(mock_app_with_get_user_id) @@ -3734,7 +3745,7 @@ async def test_update_org_app_settings_zero_max_budget( """ # Arrange with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ): client = TestClient(mock_app_with_get_user_id) @@ -3761,7 +3772,7 @@ async def test_update_org_app_settings_not_found( # Arrange with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3794,7 +3805,7 @@ async def test_update_org_app_settings_database_error( # Arrange with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_member_role), ), patch( @@ -3824,7 +3835,7 @@ async def test_update_org_app_settings_user_not_member(mock_app_with_get_user_id """ # Arrange - user has no role (not a member) with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): client = TestClient(mock_app_with_get_user_id) diff --git a/enterprise/tests/unit/server/services/test_org_member_service.py b/enterprise/tests/unit/server/services/test_org_member_service.py index f3f4aadc13..f992787b0c 100644 --- a/enterprise/tests/unit/server/services/test_org_member_service.py +++ b/enterprise/tests/unit/server/services/test_org_member_service.py @@ -150,7 +150,8 @@ class TestOrgMemberServiceGetOrgMembers: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -188,7 +189,8 @@ class TestOrgMemberServiceGetOrgMembers: """Test that retrieval fails when user is not a member.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = None @@ -212,7 +214,8 @@ class TestOrgMemberServiceGetOrgMembers: """Test that negative page_id returns error.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = requester_membership_owner @@ -236,7 +239,8 @@ class TestOrgMemberServiceGetOrgMembers: """Test that non-integer page_id returns error.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = requester_membership_owner @@ -261,7 +265,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -295,7 +300,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -329,7 +335,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -360,7 +367,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -399,7 +407,8 @@ class TestOrgMemberServiceGetOrgMembers: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -439,7 +448,8 @@ class TestOrgMemberServiceGetOrgMembers: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -488,7 +498,8 @@ class TestOrgMemberServiceGetOrgMembers: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -519,7 +530,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -551,7 +563,8 @@ class TestOrgMemberServiceGetOrgMembers: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_paginated', @@ -596,7 +609,8 @@ class TestOrgMemberServiceGetOrgMembersCount: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_count', @@ -624,7 +638,8 @@ class TestOrgMemberServiceGetOrgMembersCount: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_member_service.OrgMemberStore.get_org_members_count', @@ -650,7 +665,8 @@ class TestOrgMemberServiceGetOrgMembersCount: """Test that non-member raises OrgMemberNotFoundError.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = None @@ -690,16 +706,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -735,16 +755,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -779,16 +803,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -815,7 +843,8 @@ class TestOrgMemberServiceRemoveOrgMember: """Test that removing fails when requester is not a member of the organization.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = None @@ -835,7 +864,8 @@ class TestOrgMemberServiceRemoveOrgMember: """Test that removing fails when trying to remove oneself.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = requester_membership_owner @@ -860,7 +890,8 @@ class TestOrgMemberServiceRemoveOrgMember: """Test that removing fails when target member is not found.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.side_effect = [requester_membership_owner, None] @@ -887,10 +918,12 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_member.side_effect = [ @@ -922,16 +955,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.LiteLlmManager.remove_user_from_team' @@ -970,10 +1007,12 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_member.side_effect = [ @@ -1010,10 +1049,12 @@ class TestOrgMemberServiceRemoveOrgMember: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_member.side_effect = [ @@ -1045,13 +1086,16 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.get_org_members' + 'server.services.org_member_service.OrgMemberStore.get_org_members', + new_callable=AsyncMock, ) as mock_get_members, ): mock_get_member.side_effect = [ @@ -1089,19 +1133,24 @@ class TestOrgMemberServiceRemoveOrgMember: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.get_org_members' + 'server.services.org_member_service.OrgMemberStore.get_org_members', + new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -1142,13 +1191,16 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, ): mock_get_member.side_effect = [ @@ -1187,16 +1239,20 @@ class TestOrgMemberServiceRemoveOrgMember: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.UserStore.update_current_org' @@ -1239,16 +1295,20 @@ class TestOrgMemberServiceRemoveOrgMember: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.UserStore.update_current_org' @@ -1287,16 +1347,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.UserStore.update_current_org' @@ -1335,16 +1399,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.LiteLlmManager.remove_user_from_team', @@ -1385,16 +1453,20 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, patch( 'server.services.org_member_service.LiteLlmManager.remove_user_from_team', @@ -1434,13 +1506,16 @@ class TestOrgMemberServiceRemoveOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.OrgMemberStore.remove_user_from_org' + 'server.services.org_member_service.OrgMemberStore.remove_user_from_org', + new_callable=AsyncMock, ) as mock_remove, patch( 'server.services.org_member_service.LiteLlmManager.remove_user_from_team', @@ -1541,19 +1616,24 @@ class TestOrgMemberServiceUpdateOrgMember: mock_user.email = 'target@example.com' with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, patch( - 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org' + 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org', + new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -1597,19 +1677,24 @@ class TestOrgMemberServiceUpdateOrgMember: mock_user.email = 'target@example.com' with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, patch( - 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org' + 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org', + new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -1651,19 +1736,24 @@ class TestOrgMemberServiceUpdateOrgMember: mock_user.email = 'target@example.com' with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, patch( - 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org' + 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org', + new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -1709,21 +1799,31 @@ class TestOrgMemberServiceUpdateOrgMember: mock_user.email = 'target@example.com' with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, patch( - 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org' + 'server.services.org_member_service.OrgMemberStore.update_user_role_in_org', + new_callable=AsyncMock, ) as mock_update, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, - patch.object(OrgMemberService, '_is_last_owner', return_value=False), + patch.object( + OrgMemberService, + '_is_last_owner', + new_callable=AsyncMock, + return_value=False, + ), ): mock_get_member.side_effect = [ requester_membership_owner, @@ -1754,7 +1854,8 @@ class TestOrgMemberServiceUpdateOrgMember: """GIVEN requester not in org WHEN update_org_member THEN raises OrgMemberNotFoundError.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = None @@ -1774,7 +1875,8 @@ class TestOrgMemberServiceUpdateOrgMember: """GIVEN requester updates self WHEN update_org_member THEN raises CannotModifySelfError.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = requester_membership_owner @@ -1799,7 +1901,8 @@ class TestOrgMemberServiceUpdateOrgMember: """GIVEN target not in org WHEN update_org_member THEN raises OrgMemberNotFoundError.""" # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.side_effect = [requester_membership_owner, None] @@ -1827,13 +1930,16 @@ class TestOrgMemberServiceUpdateOrgMember: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, ): mock_get_member.side_effect = [ @@ -1867,19 +1973,23 @@ class TestOrgMemberServiceUpdateOrgMember: # Arrange: patch _can_update_member_role so we reach the last-owner check (owner cannot normally modify owner) with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name' + 'server.services.org_member_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, ) as mock_get_role_by_name, patch( 'server.services.org_member_service.OrgMemberService._can_update_member_role' ) as mock_can_update, patch( - 'server.services.org_member_service.OrgMemberService._is_last_owner' + 'server.services.org_member_service.OrgMemberService._is_last_owner', + new_callable=AsyncMock, ) as mock_is_last_owner, ): mock_get_member.side_effect = [ @@ -1918,13 +2028,16 @@ class TestOrgMemberServiceUpdateOrgMember: target_membership_user.status = 'active' with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.side_effect = [ @@ -2006,7 +2119,7 @@ class TestOrgMemberServiceCanUpdateMemberRole: class TestOrgMemberServiceIsLastOwner: """Test cases for OrgMemberService._is_last_owner.""" - def test_is_last_owner_when_only_one_owner( + async def test_is_last_owner_when_only_one_owner( self, org_id, target_user_id, owner_role ): """Test that returns True when user is the only owner.""" @@ -2017,22 +2130,24 @@ class TestOrgMemberServiceIsLastOwner: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_members' + 'server.services.org_member_service.OrgMemberStore.get_org_members', + new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_members.return_value = [target_membership] mock_get_role.return_value = owner_role # Act - result = OrgMemberService._is_last_owner(org_id, target_user_id) + result = await OrgMemberService._is_last_owner(org_id, target_user_id) # Assert assert result is True - def test_is_not_last_owner_when_multiple_owners( + async def test_is_not_last_owner_when_multiple_owners( self, org_id, target_user_id, owner_role ): """Test that returns False when there are multiple owners.""" @@ -2047,22 +2162,24 @@ class TestOrgMemberServiceIsLastOwner: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_members' + 'server.services.org_member_service.OrgMemberStore.get_org_members', + new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_members.return_value = [target_membership, another_owner] mock_get_role.return_value = owner_role # Act - result = OrgMemberService._is_last_owner(org_id, target_user_id) + result = await OrgMemberService._is_last_owner(org_id, target_user_id) # Assert assert result is False - def test_is_not_last_owner_when_user_is_not_owner( + async def test_is_not_last_owner_when_user_is_not_owner( self, org_id, target_user_id, member_role ): """Test that returns False when user is not an owner.""" @@ -2073,17 +2190,19 @@ class TestOrgMemberServiceIsLastOwner: with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_members' + 'server.services.org_member_service.OrgMemberStore.get_org_members', + new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_members.return_value = [target_membership] mock_get_role.return_value = member_role # Act - result = OrgMemberService._is_last_owner(org_id, target_user_id) + result = await OrgMemberService._is_last_owner(org_id, target_user_id) # Assert assert result is False @@ -2115,7 +2234,8 @@ class TestOrgMemberServiceGetMe: user.email = 'test@example.com' return user - def test_get_me_success_returns_me_response( + @pytest.mark.asyncio + async def test_get_me_success_returns_me_response( self, org_id, current_user_id, mock_org_member, mock_user, owner_role ): """GIVEN: User is a member of the organization @@ -2125,13 +2245,16 @@ class TestOrgMemberServiceGetMe: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.return_value = mock_org_member @@ -2139,7 +2262,7 @@ class TestOrgMemberServiceGetMe: mock_get_user.return_value = mock_user # Act - result = OrgMemberService.get_me(org_id, current_user_id) + result = await OrgMemberService.get_me(org_id, current_user_id) # Assert assert isinstance(result, MeResponse) @@ -2151,24 +2274,27 @@ class TestOrgMemberServiceGetMe: assert result.max_iterations == 50 assert result.status == 'active' - def test_get_me_member_not_found_raises_error(self, org_id, current_user_id): + @pytest.mark.asyncio + async def test_get_me_member_not_found_raises_error(self, org_id, current_user_id): """GIVEN: User is not a member of the organization WHEN: get_me is called THEN: Raises OrgMemberNotFoundError """ # Arrange with patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member: mock_get_member.return_value = None # Act & Assert with pytest.raises(OrgMemberNotFoundError) as exc_info: - OrgMemberService.get_me(org_id, current_user_id) + await OrgMemberService.get_me(org_id, current_user_id) assert str(org_id) in str(exc_info.value) - def test_get_me_role_not_found_raises_error( + @pytest.mark.asyncio + async def test_get_me_role_not_found_raises_error( self, org_id, current_user_id, mock_org_member ): """GIVEN: Member exists but role lookup fails @@ -2178,10 +2304,12 @@ class TestOrgMemberServiceGetMe: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, ): mock_get_member.return_value = mock_org_member @@ -2189,11 +2317,12 @@ class TestOrgMemberServiceGetMe: # Act & Assert with pytest.raises(RoleNotFoundError) as exc_info: - OrgMemberService.get_me(org_id, current_user_id) + await OrgMemberService.get_me(org_id, current_user_id) assert exc_info.value.role_id == mock_org_member.role_id - def test_get_me_user_not_found_returns_empty_email( + @pytest.mark.asyncio + async def test_get_me_user_not_found_returns_empty_email( self, org_id, current_user_id, mock_org_member, owner_role ): """GIVEN: Member exists but user lookup returns None @@ -2203,13 +2332,16 @@ class TestOrgMemberServiceGetMe: # Arrange with ( patch( - 'server.services.org_member_service.OrgMemberStore.get_org_member' + 'server.services.org_member_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id' + 'server.services.org_member_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.UserStore.get_user_by_id' + 'server.services.org_member_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, ) as mock_get_user, ): mock_get_member.return_value = mock_org_member @@ -2217,7 +2349,7 @@ class TestOrgMemberServiceGetMe: mock_get_user.return_value = None # Act - result = OrgMemberService.get_me(org_id, current_user_id) + result = await OrgMemberService.get_me(org_id, current_user_id) # Assert assert result.email == '' diff --git a/enterprise/tests/unit/test_authorization.py b/enterprise/tests/unit/test_authorization.py index 237f34c6f3..748389d178 100644 --- a/enterprise/tests/unit/test_authorization.py +++ b/enterprise/tests/unit/test_authorization.py @@ -336,7 +336,8 @@ class TestHasPermission: class TestGetUserOrgRole: """Tests for get_user_org_role function.""" - def test_returns_role_when_member_exists(self): + @pytest.mark.asyncio + async def test_returns_role_when_member_exists(self): """ GIVEN: User is a member of organization with role WHEN: get_user_org_role is called @@ -354,17 +355,20 @@ class TestGetUserOrgRole: with ( patch( 'server.auth.authorization.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=mock_org_member, ), patch( - 'server.auth.authorization.RoleStore.get_role_by_id', + 'server.auth.authorization.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, return_value=mock_role, ), ): - result = get_user_org_role(user_id, org_id) + result = await get_user_org_role(user_id, org_id) assert result == mock_role - def test_returns_none_when_not_member(self): + @pytest.mark.asyncio + async def test_returns_none_when_not_member(self): """ GIVEN: User is not a member of organization WHEN: get_user_org_role is called @@ -375,12 +379,14 @@ class TestGetUserOrgRole: with patch( 'server.auth.authorization.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=None, ): - result = get_user_org_role(user_id, org_id) + result = await get_user_org_role(user_id, org_id) assert result is None - def test_returns_role_when_org_id_is_none(self): + @pytest.mark.asyncio + async def test_returns_role_when_org_id_is_none(self): """ GIVEN: User with a current organization WHEN: get_user_org_role is called with org_id=None @@ -397,22 +403,26 @@ class TestGetUserOrgRole: with ( patch( 'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org', + new_callable=AsyncMock, return_value=mock_org_member, ) as mock_get_current, patch( 'server.auth.authorization.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_org_member, patch( - 'server.auth.authorization.RoleStore.get_role_by_id', + 'server.auth.authorization.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, return_value=mock_role, ), ): - result = get_user_org_role(user_id, None) + result = await get_user_org_role(user_id, None) assert result == mock_role mock_get_current.assert_called_once() mock_get_org_member.assert_not_called() - def test_returns_none_when_org_id_is_none_and_no_current_org(self): + @pytest.mark.asyncio + async def test_returns_none_when_org_id_is_none_and_no_current_org(self): """ GIVEN: User with no current organization membership WHEN: get_user_org_role is called with org_id=None @@ -422,9 +432,10 @@ class TestGetUserOrgRole: with patch( 'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org', + new_callable=AsyncMock, return_value=None, ): - result = get_user_org_role(user_id, None) + result = await get_user_org_role(user_id, None) assert result is None @@ -450,7 +461,7 @@ class TestRequirePermission: mock_role.name = 'admin' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) @@ -484,7 +495,7 @@ class TestRequirePermission: org_id = uuid4() with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) @@ -508,7 +519,7 @@ class TestRequirePermission: mock_role.name = 'member' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) @@ -532,7 +543,7 @@ class TestRequirePermission: mock_role.name = 'owner' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) @@ -553,7 +564,7 @@ class TestRequirePermission: mock_role.name = 'admin' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) @@ -577,7 +588,7 @@ class TestRequirePermission: with ( patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ), patch('server.auth.authorization.logger') as mock_logger, @@ -605,7 +616,7 @@ class TestRequirePermission: mock_role.name = 'admin' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ) as mock_get_role: permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) @@ -623,7 +634,7 @@ class TestRequirePermission: user_id = str(uuid4()) with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=None), ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) @@ -656,7 +667,7 @@ class TestPermissionScenarios: mock_role.name = 'member' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.MANAGE_SECRETS) @@ -677,7 +688,7 @@ class TestPermissionScenarios: mock_role.name = 'member' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission( @@ -702,7 +713,7 @@ class TestPermissionScenarios: mock_role.name = 'admin' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission( @@ -725,7 +736,7 @@ class TestPermissionScenarios: mock_role.name = 'admin' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) @@ -748,7 +759,7 @@ class TestPermissionScenarios: mock_role.name = 'owner' with patch( - 'server.auth.authorization.get_user_org_role_async', + 'server.auth.authorization.get_user_org_role', AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) diff --git a/enterprise/tests/unit/test_org_invitation_service.py b/enterprise/tests/unit/test_org_invitation_service.py index 06c0d258ed..822fd5c5a6 100644 --- a/enterprise/tests/unit/test_org_invitation_service.py +++ b/enterprise/tests/unit/test_org_invitation_service.py @@ -113,14 +113,16 @@ class TestAcceptInvitationEmailValidation: 'server.services.org_invitation_service.TokenManager' ) as mock_token_manager_class, patch( - 'server.services.org_invitation_service.OrgMemberStore.get_org_member' + 'server.services.org_invitation_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_invitation_service.OrgService.create_litellm_integration', new_callable=AsyncMock, ) as mock_create_litellm, patch( - 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org' + 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org', + new_callable=AsyncMock, ), patch( 'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status', @@ -222,14 +224,16 @@ class TestAcceptInvitationEmailValidation: new_callable=AsyncMock, ) as mock_get_user, patch( - 'server.services.org_invitation_service.OrgMemberStore.get_org_member' + 'server.services.org_invitation_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, ) as mock_get_member, patch( 'server.services.org_invitation_service.OrgService.create_litellm_integration', new_callable=AsyncMock, ) as mock_create_litellm, patch( - 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org' + 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org', + new_callable=AsyncMock, ), patch( 'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status', @@ -323,11 +327,13 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id', + 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name', + 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, return_value=mock_member_role, ), patch.object( @@ -377,11 +383,13 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id', + 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name', + 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, return_value=mock_member_role, ), patch.object( @@ -444,11 +452,13 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id', + 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name', + 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + new_callable=AsyncMock, return_value=None, # Invalid role ), ): diff --git a/enterprise/tests/unit/test_org_member_store.py b/enterprise/tests/unit/test_org_member_store.py index 26a0b27ab8..c2901358a8 100644 --- a/enterprise/tests/unit/test_org_member_store.py +++ b/enterprise/tests/unit/test_org_member_store.py @@ -37,19 +37,20 @@ async def async_session_maker(async_engine): return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) -def test_get_org_members(session_maker): +@pytest.mark.asyncio +async def test_get_org_members(async_session_maker): # Test getting org_members by org ID - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() user1 = User(id=uuid.uuid4(), current_org_id=org.id) user2 = User(id=uuid.uuid4(), current_org_id=org.id) role = Role(name='admin', rank=1) session.add_all([user1, user2, role]) - session.flush() + await session.flush() org_member1 = OrgMember( org_id=org.id, @@ -66,31 +67,32 @@ def test_get_org_members(session_maker): status='active', ) session.add_all([org_member1, org_member2]) - session.commit() + await session.commit() org_id = org.id # Test retrieval - with patch('storage.org_member_store.session_maker', session_maker): - org_members = OrgMemberStore.get_org_members(org_id) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + org_members = await OrgMemberStore.get_org_members(org_id) assert len(org_members) == 2 api_keys = [om.llm_api_key.get_secret_value() for om in org_members] assert 'test-key-1' in api_keys assert 'test-key-2' in api_keys -def test_get_user_orgs(session_maker): +@pytest.mark.asyncio +async def test_get_user_orgs(async_session_maker): # Test getting org_members by user ID - with session_maker() as session: + async with async_session_maker() as session: # Create test data org1 = Org(name='test-org-1') org2 = Org(name='test-org-2') session.add_all([org1, org2]) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org1.id) role = Role(name='admin', rank=1) session.add_all([user, role]) - session.flush() + await session.flush() org_member1 = OrgMember( org_id=org1.id, @@ -107,30 +109,31 @@ def test_get_user_orgs(session_maker): status='active', ) session.add_all([org_member1, org_member2]) - session.commit() + await session.commit() user_id = user.id # Test retrieval - with patch('storage.org_member_store.session_maker', session_maker): - org_members = OrgMemberStore.get_user_orgs(user_id) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + org_members = await OrgMemberStore.get_user_orgs(user_id) assert len(org_members) == 2 api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members] assert 'test-key-1' in api_keys assert 'test-key-2' in api_keys -def test_get_org_member(session_maker): +@pytest.mark.asyncio +async def test_get_org_member(async_session_maker): # Test getting org_member by org and user ID - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org.id) role = Role(name='admin', rank=1) session.add_all([user, role]) - session.flush() + await session.flush() org_member = OrgMember( org_id=org.id, @@ -140,32 +143,33 @@ def test_get_org_member(session_maker): status='active', ) session.add(org_member) - session.commit() + await session.commit() org_id = org.id user_id = user.id # Test retrieval - with patch('storage.org_member_store.session_maker', session_maker): - retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + retrieved_org_member = await OrgMemberStore.get_org_member(org_id, user_id) assert retrieved_org_member is not None assert retrieved_org_member.org_id == org_id assert retrieved_org_member.user_id == user_id assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key' -def test_get_org_member_for_current_org(session_maker): +@pytest.mark.asyncio +async def test_get_org_member_for_current_org(async_session_maker): # Test getting org_member for user's current organization - with session_maker() as session: + async with async_session_maker() as session: # Create test data - user belongs to two orgs but current_org is org1 org1 = Org(name='test-org-1') org2 = Org(name='test-org-2') session.add_all([org1, org2]) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org1.id) role = Role(name='admin', rank=1) session.add_all([user, role]) - session.flush() + await session.flush() org_member1 = OrgMember( org_id=org1.id, @@ -182,47 +186,51 @@ def test_get_org_member_for_current_org(session_maker): status='active', ) session.add_all([org_member1, org_member2]) - session.commit() + await session.commit() user_id = user.id org1_id = org1.id # Test retrieval - should return org_member for current_org (org1) - with patch('storage.org_member_store.session_maker', session_maker): - retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + retrieved_org_member = await OrgMemberStore.get_org_member_for_current_org( + user_id + ) assert retrieved_org_member is not None assert retrieved_org_member.org_id == org1_id assert retrieved_org_member.user_id == user_id assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1' -def test_get_org_member_for_current_org_user_not_found(session_maker): +@pytest.mark.asyncio +async def test_get_org_member_for_current_org_user_not_found(async_session_maker): # Test getting org_member for non-existent user - with patch('storage.org_member_store.session_maker', session_maker): - retrieved_org_member = OrgMemberStore.get_org_member_for_current_org( + with patch('storage.org_member_store.a_session_maker', async_session_maker): + retrieved_org_member = await OrgMemberStore.get_org_member_for_current_org( uuid.uuid4() ) assert retrieved_org_member is None -def test_add_user_to_org(session_maker): +@pytest.mark.asyncio +async def test_add_user_to_org(async_session_maker): # Test adding a user to an org - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org.id) role = Role(name='admin', rank=1) session.add_all([user, role]) - session.commit() + await session.commit() org_id = org.id user_id = user.id role_id = role.id # Test creation - with patch('storage.org_member_store.session_maker', session_maker): - org_member = OrgMemberStore.add_user_to_org( + with patch('storage.org_member_store.a_session_maker', async_session_maker): + org_member = await OrgMemberStore.add_user_to_org( org_id=org_id, user_id=user_id, role_id=role_id, @@ -238,19 +246,20 @@ def test_add_user_to_org(session_maker): assert org_member.status == 'active' -def test_update_user_role_in_org(session_maker): +@pytest.mark.asyncio +async def test_update_user_role_in_org(async_session_maker): # Test updating user role in org - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org.id) role1 = Role(name='admin', rank=1) role2 = Role(name='user', rank=2) session.add_all([user, role1, role2]) - session.flush() + await session.flush() org_member = OrgMember( org_id=org.id, @@ -260,14 +269,14 @@ def test_update_user_role_in_org(session_maker): status='active', ) session.add(org_member) - session.commit() + await session.commit() org_id = org.id user_id = user.id role2_id = role2.id # Test update - with patch('storage.org_member_store.session_maker', session_maker): - updated_org_member = OrgMemberStore.update_user_role_in_org( + with patch('storage.org_member_store.a_session_maker', async_session_maker): + updated_org_member = await OrgMemberStore.update_user_role_in_org( org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive' ) @@ -276,29 +285,31 @@ def test_update_user_role_in_org(session_maker): assert updated_org_member.status == 'inactive' -def test_update_user_role_in_org_not_found(session_maker): +@pytest.mark.asyncio +async def test_update_user_role_in_org_not_found(async_session_maker): # Test updating org_member that doesn't exist from uuid import uuid4 - with patch('storage.org_member_store.session_maker', session_maker): - updated_org_member = OrgMemberStore.update_user_role_in_org( - org_id=uuid4(), user_id=99999, role_id=1 + with patch('storage.org_member_store.a_session_maker', async_session_maker): + updated_org_member = await OrgMemberStore.update_user_role_in_org( + org_id=uuid4(), user_id=uuid4(), role_id=1 ) assert updated_org_member is None -def test_remove_user_from_org(session_maker): +@pytest.mark.asyncio +async def test_remove_user_from_org(async_session_maker): # Test removing a user from an org - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() user = User(id=uuid.uuid4(), current_org_id=org.id) role = Role(name='admin', rank=1) session.add_all([user, role]) - session.flush() + await session.flush() org_member = OrgMember( org_id=org.id, @@ -308,26 +319,27 @@ def test_remove_user_from_org(session_maker): status='active', ) session.add(org_member) - session.commit() + await session.commit() org_id = org.id user_id = user.id # Test removal - with patch('storage.org_member_store.session_maker', session_maker): - result = OrgMemberStore.remove_user_from_org(org_id, user_id) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + result = await OrgMemberStore.remove_user_from_org(org_id, user_id) assert result is True # Verify it's removed - retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id) + retrieved_org_member = await OrgMemberStore.get_org_member(org_id, user_id) assert retrieved_org_member is None -def test_remove_user_from_org_not_found(session_maker): +@pytest.mark.asyncio +async def test_remove_user_from_org_not_found(async_session_maker): # Test removing user from org that doesn't exist from uuid import uuid4 - with patch('storage.org_member_store.session_maker', session_maker): - result = OrgMemberStore.remove_user_from_org(uuid4(), 99999) + with patch('storage.org_member_store.a_session_maker', async_session_maker): + result = await OrgMemberStore.remove_user_from_org(uuid4(), uuid4()) assert result is False diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index 3b823d9994..bf71d1c1d7 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -76,7 +76,7 @@ async def test_validate_name_uniqueness_with_unique_name(async_session_maker): # Act & Assert - should not raise with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker'), + patch('storage.org_member_store.a_session_maker'), patch('storage.role_store.session_maker'), ): await OrgService.validate_name_uniqueness(unique_name) @@ -591,7 +591,10 @@ async def test_get_org_by_id_success(session_maker, owner_role): ) with ( - patch('storage.org_service.OrgMemberStore.get_org_member') as mock_get_member, + patch( + 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member, patch( 'storage.org_service.OrgStore.get_org_by_id', new_callable=AsyncMock, @@ -624,6 +627,7 @@ async def test_get_org_by_id_user_not_member(): with patch( 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=None, ): # Act & Assert @@ -656,6 +660,7 @@ async def test_get_org_by_id_org_not_found(): with ( patch( 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=mock_org_member, ), patch( @@ -834,10 +839,13 @@ async def test_verify_owner_authorization_success(session_maker, owner_role): ), patch( 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=mock_org_member, ), patch( - 'storage.org_service.RoleStore.get_role_by_id', return_value=mock_owner_role + 'storage.org_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, + return_value=mock_owner_role, ), ): # Act & Assert - should not raise @@ -891,7 +899,11 @@ async def test_verify_owner_authorization_user_not_member(session_maker, owner_r new_callable=AsyncMock, return_value=mock_org, ), - patch('storage.org_service.OrgMemberStore.get_org_member', return_value=None), + patch( + 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + return_value=None, + ), ): # Act & Assert with pytest.raises(OrgAuthorizationError) as exc_info: @@ -934,9 +946,14 @@ async def test_verify_owner_authorization_user_not_owner(session_maker): ), patch( 'storage.org_service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, return_value=mock_org_member, ), - patch('storage.org_service.RoleStore.get_role_by_id', return_value=admin_role), + patch( + 'storage.org_service.RoleStore.get_role_by_id_async', + new_callable=AsyncMock, + return_value=admin_role, + ), ): # Act & Assert with pytest.raises(OrgAuthorizationError) as exc_info: @@ -1118,7 +1135,7 @@ async def test_update_org_with_permissions_success_non_llm_fields( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act @@ -1181,8 +1198,8 @@ async def test_update_org_with_permissions_success_llm_fields_admin( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1243,8 +1260,8 @@ async def test_update_org_with_permissions_success_llm_fields_owner( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1306,8 +1323,8 @@ async def test_update_org_with_permissions_success_mixed_fields_admin( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1366,7 +1383,7 @@ async def test_update_org_with_permissions_empty_update( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act @@ -1401,7 +1418,7 @@ async def test_update_org_with_permissions_org_not_found( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act & Assert @@ -1449,7 +1466,7 @@ async def test_update_org_with_permissions_non_member( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act & Assert @@ -1507,7 +1524,7 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act & Assert @@ -1567,7 +1584,7 @@ async def test_update_org_with_permissions_database_error( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.update_org', @@ -1622,7 +1639,7 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', @@ -1675,7 +1692,7 @@ async def test_update_org_with_permissions_same_name_allowed( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', @@ -1753,8 +1770,8 @@ async def test_update_org_with_permissions_only_llm_fields( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1817,7 +1834,7 @@ async def test_update_org_with_permissions_only_non_llm_fields( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.org_member_store.session_maker', session_maker), + patch('storage.org_member_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), ): # Act