diff --git a/enterprise/server/auth/authorization.py b/enterprise/server/auth/authorization.py index 4595ea07a3..c8d72021c6 100644 --- a/enterprise/server/auth/authorization.py +++ b/enterprise/server/auth/authorization.py @@ -179,7 +179,7 @@ async def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None: if not org_member: return None - return await RoleStore.get_role_by_id_async(org_member.role_id) + return await RoleStore.get_role_by_id(org_member.role_id) def get_role_permissions(role_name: str) -> frozenset[Permission]: diff --git a/enterprise/server/services/org_invitation_service.py b/enterprise/server/services/org_invitation_service.py index 5518ab5dd0..3ef0d9d8fa 100644 --- a/enterprise/server/services/org_invitation_service.py +++ b/enterprise/server/services/org_invitation_service.py @@ -91,7 +91,7 @@ class OrgInvitationService: 'You are not a member of this organization' ) - inviter_role = await RoleStore.get_role_by_id_async(inviter_member.role_id) + inviter_role = await RoleStore.get_role_by_id(inviter_member.role_id) if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: raise InsufficientPermissionError('Only owners and admins can invite users') @@ -101,7 +101,7 @@ class OrgInvitationService: raise InsufficientPermissionError('Only owners can invite with owner role') # Get the target role - target_role = await RoleStore.get_role_by_name_async(role_name_lower) + target_role = await RoleStore.get_role_by_name(role_name_lower) if not target_role: raise ValueError(f'Invalid role: {role_name}') @@ -204,7 +204,7 @@ class OrgInvitationService: 'You are not a member of this organization' ) - inviter_role = await RoleStore.get_role_by_id_async(inviter_member.role_id) + inviter_role = await RoleStore.get_role_by_id(inviter_member.role_id) if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: raise InsufficientPermissionError('Only owners and admins can invite users') @@ -212,7 +212,7 @@ class OrgInvitationService: if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER: raise InsufficientPermissionError('Only owners can invite with owner role') - target_role = await RoleStore.get_role_by_name_async(role_name_lower) + target_role = await RoleStore.get_role_by_name(role_name_lower) if not target_role: raise ValueError(f'Invalid role: {role_name}') diff --git a/enterprise/server/services/org_member_service.py b/enterprise/server/services/org_member_service.py index 264d8fa135..7168d0954e 100644 --- a/enterprise/server/services/org_member_service.py +++ b/enterprise/server/services/org_member_service.py @@ -51,7 +51,7 @@ class OrgMemberService: raise OrgMemberNotFoundError(str(org_id), str(user_id)) # Resolve role name from role_id - role = await RoleStore.get_role_by_id_async(org_member.role_id) + role = await RoleStore.get_role_by_id(org_member.role_id) if role is None: raise RoleNotFoundError(org_member.role_id) @@ -195,10 +195,8 @@ class OrgMemberService: if not target_membership: return False, 'member_not_found' - requester_role = await RoleStore.get_role_by_id_async( - requester_membership.role_id - ) - target_role = await RoleStore.get_role_by_id_async(target_membership.role_id) + requester_role = await RoleStore.get_role_by_id(requester_membership.role_id) + target_role = await RoleStore.get_role_by_id(target_membership.role_id) if not requester_role or not target_role: return False, 'role_not_found' @@ -300,10 +298,8 @@ class OrgMemberService: raise OrgMemberNotFoundError(str(org_id), str(target_user_id)) # Get roles - requester_role = await RoleStore.get_role_by_id_async( - requester_membership.role_id - ) - target_role = await RoleStore.get_role_by_id_async(target_membership.role_id) + requester_role = await RoleStore.get_role_by_id(requester_membership.role_id) + target_role = await RoleStore.get_role_by_id(target_membership.role_id) if not requester_role: raise RoleNotFoundError(requester_membership.role_id) @@ -323,7 +319,7 @@ class OrgMemberService: ) # Validate new role exists - new_role = await RoleStore.get_role_by_name_async(new_role_name.lower()) + new_role = await RoleStore.get_role_by_name(new_role_name.lower()) if not new_role: raise InvalidRoleError(new_role_name) @@ -406,7 +402,7 @@ class OrgMemberService: owners = [] for m in members: # Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError - role = await RoleStore.get_role_by_id_async(m.role_id) + role = await RoleStore.get_role_by_id(m.role_id) if role and role.name == ROLE_OWNER: owners.append(m) return len(owners) == 1 and str(owners[0].user_id) == str(user_id) diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 780fca890e..a5108137dc 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -130,7 +130,7 @@ class OrgService: setattr(org, key, value) @staticmethod - def get_owner_role(): + async def get_owner_role(): """ Get the owner role from the database. @@ -140,7 +140,7 @@ class OrgService: Raises: Exception: If owner role not found """ - owner_role = RoleStore.get_role_by_name('owner') + owner_role = await RoleStore.get_role_by_name('owner') if not owner_role: raise Exception('Owner role not found in database') return owner_role @@ -237,7 +237,7 @@ class OrgService: OrgService.apply_litellm_settings_to_org(org, settings) # Step 6: Get owner role and create member entity - owner_role = OrgService.get_owner_role() + owner_role = await OrgService.get_owner_role() org_member = OrgService.create_org_member_entity( org_id=org_id, user_id=user_id, @@ -420,7 +420,7 @@ class OrgService: return False # Get the role details - role = await RoleStore.get_role_by_id_async(org_member.role_id) + role = await RoleStore.get_role_by_id(org_member.role_id) if not role: return False @@ -797,7 +797,7 @@ class OrgService: raise OrgAuthorizationError('User is not a member of this organization') # Check if user has owner role - role = await RoleStore.get_role_by_id_async(org_member.role_id) + role = await RoleStore.get_role_by_id(org_member.role_id) if not role or role.name != 'owner': raise OrgAuthorizationError( 'Only organization owners can delete organizations' diff --git a/enterprise/storage/role_store.py b/enterprise/storage/role_store.py index fa35cc461f..9f5d028b3d 100644 --- a/enterprise/storage/role_store.py +++ b/enterprise/storage/role_store.py @@ -2,11 +2,11 @@ Store class for managing roles. """ -from typing import List, Optional +from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from storage.database import a_session_maker, session_maker +from storage.database import a_session_maker from storage.role import Role @@ -14,57 +14,70 @@ class RoleStore: """Store for managing roles.""" @staticmethod - def create_role(name: str, rank: int) -> Role: + async def _create_role(name: str, rank: int, session: AsyncSession) -> Role: + role = Role(name=name, rank=rank) + session.add(role) + await session.flush() + await session.refresh(role) + return role + + @staticmethod + async def create_role( + name: str, + rank: int, + session: Optional[AsyncSession] = None, + ) -> Role: """Create a new role.""" - with session_maker() as session: - role = Role(name=name, rank=rank) - session.add(role) - session.commit() - session.refresh(role) + if session is not None: + return await RoleStore._create_role(name, rank, session) + async with a_session_maker() as new_session: + role = await RoleStore._create_role(name, rank, new_session) + await new_session.commit() return role @staticmethod - def get_role_by_id(role_id: int) -> Optional[Role]: - """Get role by ID.""" - with session_maker() as session: - return session.query(Role).filter(Role.id == role_id).first() + async def _get_role_by_id(role_id: int, session: AsyncSession) -> Optional[Role]: + result = await session.execute(select(Role).where(Role.id == role_id)) + return result.scalars().first() @staticmethod - async def get_role_by_id_async( + async def get_role_by_id( role_id: int, session: Optional[AsyncSession] = None, ) -> Optional[Role]: - """Get role by ID (async version).""" + """Get role by ID.""" if session is not None: - result = await session.execute(select(Role).where(Role.id == role_id)) - return result.scalars().first() - - async with a_session_maker() as session: - result = await session.execute(select(Role).where(Role.id == role_id)) - return result.scalars().first() + return await RoleStore._get_role_by_id(role_id, session) + async with a_session_maker() as new_session: + return await RoleStore._get_role_by_id(role_id, new_session) @staticmethod - def get_role_by_name(name: str) -> Optional[Role]: - """Get role by name.""" - with session_maker() as session: - return session.query(Role).filter(Role.name == name).first() + async def _get_role_by_name(name: str, session: AsyncSession) -> Optional[Role]: + result = await session.execute(select(Role).where(Role.name == name)) + return result.scalars().first() @staticmethod - async def get_role_by_name_async( + async def get_role_by_name( name: str, session: Optional[AsyncSession] = None, ) -> Optional[Role]: """Get role by name.""" if session is not None: - result = await session.execute(select(Role).where(Role.name == name)) - return result.scalars().first() - - async with a_session_maker() as session: - result = await session.execute(select(Role).where(Role.name == name)) - return result.scalars().first() + return await RoleStore._get_role_by_name(name, session) + async with a_session_maker() as new_session: + return await RoleStore._get_role_by_name(name, new_session) @staticmethod - def list_roles() -> List[Role]: + async def _list_roles(session: AsyncSession) -> list[Role]: + result = await session.execute(select(Role).order_by(Role.rank)) + return list(result.scalars().all()) + + @staticmethod + async def list_roles( + session: Optional[AsyncSession] = None, + ) -> list[Role]: """List all roles.""" - with session_maker() as session: - return session.query(Role).order_by(Role.rank).all() + if session is not None: + return await RoleStore._list_roles(session) + async with a_session_maker() as new_session: + return await RoleStore._list_roles(new_session) diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 67585f154d..8c20bd013c 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -87,7 +87,7 @@ class UserStore: user.email_verified = user_info.get('email_verified') session.add(user) - role = RoleStore.get_role_by_name('owner') + role = await RoleStore.get_role_by_name('owner') if role is None: raise ValueError('Owner role not found in database') @@ -266,7 +266,7 @@ class UserStore: 'user_store:migrate_user:calling_get_role_by_name', extra={'user_id': user_id}, ) - role = await RoleStore.get_role_by_name_async('owner') + role = await RoleStore.get_role_by_name('owner') logger.debug( 'user_store:migrate_user:done_get_role_by_name', extra={'user_id': user_id}, diff --git a/enterprise/tests/unit/server/services/test_org_member_service.py b/enterprise/tests/unit/server/services/test_org_member_service.py index f992787b0c..001b958c03 100644 --- a/enterprise/tests/unit/server/services/test_org_member_service.py +++ b/enterprise/tests/unit/server/services/test_org_member_service.py @@ -710,7 +710,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -759,7 +759,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -807,7 +807,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -922,7 +922,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -959,7 +959,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1011,7 +1011,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -1053,7 +1053,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -1090,7 +1090,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1137,7 +1137,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1195,7 +1195,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1243,7 +1243,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1299,7 +1299,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1351,7 +1351,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1403,7 +1403,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1457,7 +1457,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1510,7 +1510,7 @@ class TestOrgMemberServiceRemoveOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -1620,11 +1620,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, patch( @@ -1681,11 +1681,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, patch( @@ -1740,11 +1740,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, patch( @@ -1803,11 +1803,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, patch( @@ -1934,11 +1934,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, ): @@ -1977,11 +1977,11 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( - 'server.services.org_member_service.RoleStore.get_role_by_name_async', + 'server.services.org_member_service.RoleStore.get_role_by_name', new_callable=AsyncMock, ) as mock_get_role_by_name, patch( @@ -2032,7 +2032,7 @@ class TestOrgMemberServiceUpdateOrgMember: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -2134,7 +2134,7 @@ class TestOrgMemberServiceIsLastOwner: new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -2166,7 +2166,7 @@ class TestOrgMemberServiceIsLastOwner: new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -2194,7 +2194,7 @@ class TestOrgMemberServiceIsLastOwner: new_callable=AsyncMock, ) as mock_get_members, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -2249,7 +2249,7 @@ class TestOrgMemberServiceGetMe: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( @@ -2308,7 +2308,7 @@ class TestOrgMemberServiceGetMe: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, ): @@ -2336,7 +2336,7 @@ class TestOrgMemberServiceGetMe: new_callable=AsyncMock, ) as mock_get_member, patch( - 'server.services.org_member_service.RoleStore.get_role_by_id_async', + 'server.services.org_member_service.RoleStore.get_role_by_id', new_callable=AsyncMock, ) as mock_get_role, patch( diff --git a/enterprise/tests/unit/test_authorization.py b/enterprise/tests/unit/test_authorization.py index 748389d178..c751e6454a 100644 --- a/enterprise/tests/unit/test_authorization.py +++ b/enterprise/tests/unit/test_authorization.py @@ -359,7 +359,7 @@ class TestGetUserOrgRole: return_value=mock_org_member, ), patch( - 'server.auth.authorization.RoleStore.get_role_by_id_async', + 'server.auth.authorization.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_role, ), @@ -411,7 +411,7 @@ class TestGetUserOrgRole: new_callable=AsyncMock, ) as mock_get_org_member, patch( - 'server.auth.authorization.RoleStore.get_role_by_id_async', + 'server.auth.authorization.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_role, ), diff --git a/enterprise/tests/unit/test_org_invitation_service.py b/enterprise/tests/unit/test_org_invitation_service.py index 822fd5c5a6..5f797dedde 100644 --- a/enterprise/tests/unit/test_org_invitation_service.py +++ b/enterprise/tests/unit/test_org_invitation_service.py @@ -327,12 +327,12 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_name', new_callable=AsyncMock, return_value=mock_member_role, ), @@ -383,12 +383,12 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_name', new_callable=AsyncMock, return_value=mock_member_role, ), @@ -452,12 +452,12 @@ class TestCreateInvitationsBatch: return_value=mock_inviter_member, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_id_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_owner_role, ), patch( - 'server.services.org_invitation_service.RoleStore.get_role_by_name_async', + 'server.services.org_invitation_service.RoleStore.get_role_by_name', new_callable=AsyncMock, return_value=None, # Invalid role ), diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index bf71d1c1d7..94edcbff3f 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -77,7 +77,7 @@ async def test_validate_name_uniqueness_with_unique_name(async_session_maker): with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker'), - patch('storage.role_store.session_maker'), + patch('storage.role_store.a_session_maker'), ): await OrgService.validate_name_uniqueness(unique_name) @@ -132,7 +132,7 @@ async def test_create_org_with_owner_success( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.UserStore.create_default_settings', AsyncMock(return_value=mock_settings), @@ -200,7 +200,7 @@ async def test_create_org_with_owner_duplicate_name( # Act & Assert with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.UserStore.create_default_settings', mock_create_settings, @@ -276,7 +276,7 @@ async def test_create_org_with_owner_database_failure_triggers_cleanup( with ( patch('storage.org_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.UserStore.create_default_settings', AsyncMock(return_value=mock_settings), @@ -843,7 +843,7 @@ async def test_verify_owner_authorization_success(session_maker, owner_role): return_value=mock_org_member, ), patch( - 'storage.org_service.RoleStore.get_role_by_id_async', + 'storage.org_service.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=mock_owner_role, ), @@ -950,7 +950,7 @@ async def test_verify_owner_authorization_user_not_owner(session_maker): return_value=mock_org_member, ), patch( - 'storage.org_service.RoleStore.get_role_by_id_async', + 'storage.org_service.RoleStore.get_role_by_id', new_callable=AsyncMock, return_value=admin_role, ), @@ -1136,7 +1136,7 @@ async def test_update_org_with_permissions_success_non_llm_fields( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1384,7 +1384,7 @@ async def test_update_org_with_permissions_empty_update( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( @@ -1419,7 +1419,7 @@ async def test_update_org_with_permissions_org_not_found( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1467,7 +1467,7 @@ async def test_update_org_with_permissions_non_member( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act & Assert with pytest.raises(PermissionError) as exc_info: @@ -1525,7 +1525,7 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act & Assert with pytest.raises(PermissionError) as exc_info: @@ -1585,7 +1585,7 @@ async def test_update_org_with_permissions_database_error( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.OrgStore.update_org', new_callable=AsyncMock, @@ -1640,7 +1640,7 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', new_callable=AsyncMock, @@ -1693,7 +1693,7 @@ async def test_update_org_with_permissions_same_name_allowed( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', new_callable=AsyncMock, @@ -1835,7 +1835,7 @@ async def test_update_org_with_permissions_only_non_llm_fields( with ( patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.a_session_maker', async_session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.role_store.a_session_maker', async_session_maker), ): # Act result = await OrgService.update_org_with_permissions( diff --git a/enterprise/tests/unit/test_role_store.py b/enterprise/tests/unit/test_role_store.py index 6de5549062..9ed50ec1c6 100644 --- a/enterprise/tests/unit/test_role_store.py +++ b/enterprise/tests/unit/test_role_store.py @@ -33,86 +33,9 @@ async def async_session_maker(async_engine): return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) -def test_get_role_by_id(session_maker): - # Test getting role by ID - with session_maker() as session: - # Create a test role - role = Role(name='admin', rank=1) - session.add(role) - session.commit() - role_id = role.id - - # Test retrieval - with patch('storage.role_store.session_maker', session_maker): - retrieved_role = RoleStore.get_role_by_id(role_id) - assert retrieved_role is not None - assert retrieved_role.id == role_id - assert retrieved_role.name == 'admin' - - -def test_get_role_by_id_not_found(session_maker): - # Test getting role by ID when it doesn't exist - with patch('storage.role_store.session_maker', session_maker): - retrieved_role = RoleStore.get_role_by_id(99999) - assert retrieved_role is None - - -def test_get_role_by_name(session_maker): - # Test getting role by name - with session_maker() as session: - # Create a test role - role = Role(name='admin', rank=1) - session.add(role) - session.commit() - role_id = role.id - - # Test retrieval - with patch('storage.role_store.session_maker', session_maker): - retrieved_role = RoleStore.get_role_by_name('admin') - assert retrieved_role is not None - assert retrieved_role.id == role_id - assert retrieved_role.name == 'admin' - - -def test_get_role_by_name_not_found(session_maker): - # Test getting role by name when it doesn't exist - with patch('storage.role_store.session_maker', session_maker): - retrieved_role = RoleStore.get_role_by_name('nonexistent') - assert retrieved_role is None - - -def test_list_roles(session_maker): - # Test listing all roles - with session_maker() as session: - # Create test roles - role1 = Role(name='admin', rank=1) - role2 = Role(name='user', rank=2) - session.add_all([role1, role2]) - session.commit() - - # Test listing - with patch('storage.role_store.session_maker', session_maker): - roles = RoleStore.list_roles() - assert len(roles) >= 2 - role_names = [role.name for role in roles] - assert 'admin' in role_names - assert 'user' in role_names - - -def test_create_role(session_maker): - # Test creating a new role - with patch('storage.role_store.session_maker', session_maker): - role = RoleStore.create_role(name='moderator', rank=2) - - assert role is not None - assert role.name == 'moderator' - assert role.rank == 2 - assert role.id is not None - - @pytest.mark.asyncio -async def test_get_role_by_name_async_with_session(async_session_maker): - """Test getting role by name asynchronously with an explicit session.""" +async def test_get_role_by_id_with_session(async_session_maker): + """Test getting role by ID with an explicit session.""" # Create a test role async with async_session_maker() as session: role = Role(name='admin', rank=1) @@ -123,9 +46,53 @@ async def test_get_role_by_name_async_with_session(async_session_maker): # Test retrieval with explicit session async with async_session_maker() as session: - retrieved_role = await RoleStore.get_role_by_name_async( - 'admin', session=session - ) + retrieved_role = await RoleStore.get_role_by_id(role_id, session=session) + assert retrieved_role is not None + assert retrieved_role.id == role_id + assert retrieved_role.name == 'admin' + + +@pytest.mark.asyncio +async def test_get_role_by_id_without_session(async_session_maker): + """Test getting role by ID using internal session maker.""" + # Create a test role + async with async_session_maker() as session: + role = Role(name='admin', rank=1) + session.add(role) + await session.commit() + await session.refresh(role) + role_id = role.id + + # Test retrieval without explicit session (using patched a_session_maker) + with patch('storage.role_store.a_session_maker', async_session_maker): + retrieved_role = await RoleStore.get_role_by_id(role_id) + assert retrieved_role is not None + assert retrieved_role.id == role_id + assert retrieved_role.name == 'admin' + + +@pytest.mark.asyncio +async def test_get_role_by_id_not_found(async_session_maker): + """Test getting role by ID when it doesn't exist.""" + with patch('storage.role_store.a_session_maker', async_session_maker): + retrieved_role = await RoleStore.get_role_by_id(99999) + assert retrieved_role is None + + +@pytest.mark.asyncio +async def test_get_role_by_name_with_session(async_session_maker): + """Test getting role by name with an explicit session.""" + # Create a test role + async with async_session_maker() as session: + role = Role(name='admin', rank=1) + session.add(role) + await session.commit() + await session.refresh(role) + role_id = role.id + + # Test retrieval with explicit session + async with async_session_maker() as session: + retrieved_role = await RoleStore.get_role_by_name('admin', session=session) assert retrieved_role is not None assert retrieved_role.id == role_id assert retrieved_role.name == 'admin' @@ -133,8 +100,8 @@ async def test_get_role_by_name_async_with_session(async_session_maker): @pytest.mark.asyncio -async def test_get_role_by_name_async_without_session(async_session_maker): - """Test getting role by name asynchronously using internal session maker.""" +async def test_get_role_by_name_without_session(async_session_maker): + """Test getting role by name using internal session maker.""" # Create a test role async with async_session_maker() as session: role = Role(name='editor', rank=2) @@ -145,7 +112,7 @@ async def test_get_role_by_name_async_without_session(async_session_maker): # Test retrieval without explicit session (using patched a_session_maker) with patch('storage.role_store.a_session_maker', async_session_maker): - retrieved_role = await RoleStore.get_role_by_name_async('editor') + retrieved_role = await RoleStore.get_role_by_name('editor') assert retrieved_role is not None assert retrieved_role.id == role_id assert retrieved_role.name == 'editor' @@ -153,18 +120,81 @@ async def test_get_role_by_name_async_without_session(async_session_maker): @pytest.mark.asyncio -async def test_get_role_by_name_async_not_found_with_session(async_session_maker): +async def test_get_role_by_name_not_found_with_session(async_session_maker): """Test getting role by name when it doesn't exist (with explicit session).""" async with async_session_maker() as session: - retrieved_role = await RoleStore.get_role_by_name_async( + retrieved_role = await RoleStore.get_role_by_name( 'nonexistent', session=session ) assert retrieved_role is None @pytest.mark.asyncio -async def test_get_role_by_name_async_not_found_without_session(async_session_maker): +async def test_get_role_by_name_not_found_without_session(async_session_maker): """Test getting role by name when it doesn't exist (without explicit session).""" with patch('storage.role_store.a_session_maker', async_session_maker): - retrieved_role = await RoleStore.get_role_by_name_async('nonexistent') + retrieved_role = await RoleStore.get_role_by_name('nonexistent') assert retrieved_role is None + + +@pytest.mark.asyncio +async def test_list_roles_with_session(async_session_maker): + """Test listing all roles with an explicit session.""" + # Create test roles + async with async_session_maker() as session: + role1 = Role(name='admin', rank=1) + role2 = Role(name='user', rank=2) + session.add_all([role1, role2]) + await session.commit() + + # Test listing with explicit session + async with async_session_maker() as session: + roles = await RoleStore.list_roles(session=session) + assert len(roles) >= 2 + role_names = [role.name for role in roles] + assert 'admin' in role_names + assert 'user' in role_names + + +@pytest.mark.asyncio +async def test_list_roles_without_session(async_session_maker): + """Test listing all roles using internal session maker.""" + # Create test roles + async with async_session_maker() as session: + role1 = Role(name='admin', rank=1) + role2 = Role(name='user', rank=2) + session.add_all([role1, role2]) + await session.commit() + + # Test listing without explicit session (using patched a_session_maker) + with patch('storage.role_store.a_session_maker', async_session_maker): + roles = await RoleStore.list_roles() + assert len(roles) >= 2 + role_names = [role.name for role in roles] + assert 'admin' in role_names + assert 'user' in role_names + + +@pytest.mark.asyncio +async def test_create_role_with_session(async_session_maker): + """Test creating a new role with an explicit session.""" + async with async_session_maker() as session: + role = await RoleStore.create_role(name='moderator', rank=2, session=session) + await session.commit() + + assert role is not None + assert role.name == 'moderator' + assert role.rank == 2 + assert role.id is not None + + +@pytest.mark.asyncio +async def test_create_role_without_session(async_session_maker): + """Test creating a new role using internal session maker.""" + with patch('storage.role_store.a_session_maker', async_session_maker): + role = await RoleStore.create_role(name='moderator', rank=2) + + assert role is not None + assert role.name == 'moderator' + assert role.rank == 2 + assert role.id is not None