mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
feat(backend): add pagination and email filtering for organization members (#12999)
This commit is contained in:
@@ -267,7 +267,8 @@ class OrgMemberPage(BaseModel):
|
||||
"""Paginated response for organization members."""
|
||||
|
||||
items: list[OrgMemberResponse]
|
||||
next_page_id: str | None = None
|
||||
current_page: int = 1
|
||||
per_page: int = 10
|
||||
|
||||
|
||||
class OrgMemberUpdate(BaseModel):
|
||||
|
||||
@@ -519,7 +519,7 @@ async def get_org_members(
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
Query(title='Optional page offset for pagination'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
@@ -528,10 +528,18 @@ async def get_org_members(
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Filter members by email (case-insensitive partial match)',
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
),
|
||||
] = None,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgMemberPage:
|
||||
"""Get all members of an organization with cursor-based pagination.
|
||||
"""Get all members of an organization with pagination and optional email filter.
|
||||
|
||||
This endpoint retrieves a paginated list of organization members. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
@@ -539,12 +547,15 @@ async def get_org_members(
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of members to return (1-100, default 100)
|
||||
page_id: Optional page offset for pagination
|
||||
limit: Maximum number of members to return (1-100, default 10)
|
||||
email: Optional email filter (case-insensitive partial match)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgMemberPage: Paginated list of organization members
|
||||
OrgMemberPage: Paginated list of organization members with
|
||||
current_page and per_page metadata. Use the /count endpoint
|
||||
to get the total count separately.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
@@ -558,6 +569,7 @@ async def get_org_members(
|
||||
current_user_id=UUID(user_id),
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
email_filter=email,
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -600,6 +612,64 @@ async def get_org_members(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/members/count')
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
email: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Filter members by email (case-insensitive partial match)',
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
),
|
||||
] = None,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> int:
|
||||
"""Get count of organization members with optional email filter.
|
||||
|
||||
This endpoint returns the total count of organization members matching
|
||||
the filter criteria. Access requires the VIEW_ORG_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
email: Optional email filter (case-insensitive partial match)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
int: Total count of organization members matching the filter
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission or is not a member
|
||||
HTTPException: 400 if org_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.get_org_members_count(
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
email_filter=email,
|
||||
)
|
||||
except OrgMemberNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You are not a member of this organization',
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error retrieving organization member count')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve member count',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
|
||||
@@ -67,10 +67,18 @@ class OrgMemberService:
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
limit: int = 10,
|
||||
email_filter: str | None = None,
|
||||
) -> tuple[bool, str | None, OrgMemberPage | None]:
|
||||
"""Get organization members with authorization check.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
current_user_id: Requesting user's UUID.
|
||||
page_id: Offset encoded as string (e.g., "0", "10", "20").
|
||||
limit: Items per page (default 10).
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_code, data). If success is True, error_code is None.
|
||||
"""
|
||||
@@ -90,8 +98,11 @@ class OrgMemberService:
|
||||
return False, 'invalid_page_id', None
|
||||
|
||||
# Call store to get paginated members
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=offset, limit=limit
|
||||
members, _ = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
email_filter=email_filter,
|
||||
)
|
||||
|
||||
# Transform data to response format
|
||||
@@ -112,12 +123,47 @@ class OrgMemberService:
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate next_page_id
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
# Calculate current page (1-indexed)
|
||||
current_page = (offset // limit) + 1
|
||||
|
||||
return True, None, OrgMemberPage(items=items, next_page_id=next_page_id)
|
||||
return (
|
||||
True,
|
||||
None,
|
||||
OrgMemberPage(
|
||||
items=items,
|
||||
current_page=current_page,
|
||||
per_page=limit,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
email_filter: str | None = None,
|
||||
) -> int:
|
||||
"""Get count of organization members with authorization check.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
current_user_id: Requesting user's UUID.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
int: Count of organization members matching the filter.
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If requesting user is not a member of the organization.
|
||||
"""
|
||||
# Verify current user is a member of the organization
|
||||
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
|
||||
if not requester_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
|
||||
|
||||
return await OrgMemberStore.get_org_members_count(
|
||||
org_id=org_id,
|
||||
email_filter=email_filter,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def remove_org_member(
|
||||
|
||||
@@ -5,7 +5,7 @@ Store class for managing organization-member relationships.
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.org_member import OrgMember
|
||||
@@ -183,14 +183,48 @@ class OrgMemberStore:
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
email_filter: str | None = None,
|
||||
) -> int:
|
||||
"""Get total count of organization members, optionally filtered by email.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Total count of matching members.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
query = select(func.count(OrgMember.user_id)).filter(
|
||||
OrgMember.org_id == org_id
|
||||
)
|
||||
|
||||
if email_filter:
|
||||
query = query.join(User, User.id == OrgMember.user_id).filter(
|
||||
User.email.ilike(f'%{email_filter}%')
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_paginated(
|
||||
org_id: UUID,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
email_filter: str | None = None,
|
||||
) -> tuple[list[OrgMember], bool]:
|
||||
"""Get paginated list of organization members with user and role info.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
offset: Number of records to skip.
|
||||
limit: Maximum number of records to return.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Tuple of (members_list, has_more) where has_more indicates if there are more results.
|
||||
"""
|
||||
@@ -200,13 +234,18 @@ class OrgMemberStore:
|
||||
query = (
|
||||
select(OrgMember)
|
||||
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(OrgMember.org_id == org_id)
|
||||
.order_by(OrgMember.user_id)
|
||||
.offset(offset)
|
||||
.limit(limit + 1)
|
||||
)
|
||||
|
||||
# Apply email filter if provided
|
||||
if email_filter:
|
||||
query = query.filter(User.email.ilike(f'%{email_filter}%'))
|
||||
|
||||
query = query.order_by(OrgMember.user_id).offset(offset).limit(limit + 1)
|
||||
|
||||
result = await session.execute(query)
|
||||
members = list(result.scalars().all())
|
||||
members = list(result.unique().scalars().all())
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(members) > limit
|
||||
|
||||
@@ -2132,7 +2132,8 @@ class TestGetOrgMembersEndpoint:
|
||||
status='active',
|
||||
)
|
||||
],
|
||||
next_page_id=None,
|
||||
current_page=1,
|
||||
per_page=100,
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -2150,7 +2151,7 @@ class TestGetOrgMembersEndpoint:
|
||||
# Assert
|
||||
assert isinstance(result, OrgMemberPage)
|
||||
assert len(result.items) == 1
|
||||
assert result.next_page_id is None
|
||||
assert result.current_page == 1
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -2326,7 +2327,8 @@ class TestGetOrgMembersEndpoint:
|
||||
status='active',
|
||||
)
|
||||
],
|
||||
next_page_id='200',
|
||||
current_page=2,
|
||||
per_page=100,
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -2343,15 +2345,132 @@ class TestGetOrgMembersEndpoint:
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgMemberPage)
|
||||
assert result.next_page_id == '200'
|
||||
assert result.current_page == 2
|
||||
mock_get.assert_called_once_with(
|
||||
org_id=uuid.UUID(org_id),
|
||||
current_user_id=uuid.UUID(current_user_id),
|
||||
page_id='100',
|
||||
limit=100,
|
||||
email_filter=None,
|
||||
)
|
||||
|
||||
|
||||
class TestGetOrgMembersCountEndpoint:
|
||||
"""Test cases for GET /api/organizations/{org_id}/members/count endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_succeeds_returns_int(self, org_id, current_user_id):
|
||||
"""Test that successful count returns an integer."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgMemberService.get_org_members_count',
|
||||
AsyncMock(return_value=42),
|
||||
) as mock_get_count:
|
||||
# Import here to avoid circular import issues
|
||||
from server.routes.orgs import get_org_members_count
|
||||
|
||||
# Act
|
||||
result = await get_org_members_count(
|
||||
org_id=uuid.UUID(org_id),
|
||||
email=None,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 42
|
||||
mock_get_count.assert_called_once_with(
|
||||
org_id=uuid.UUID(org_id),
|
||||
current_user_id=uuid.UUID(current_user_id),
|
||||
email_filter=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_with_email_filter(self, org_id, current_user_id):
|
||||
"""Test that email filter is passed to service."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgMemberService.get_org_members_count',
|
||||
AsyncMock(return_value=5),
|
||||
) as mock_get_count:
|
||||
from server.routes.orgs import get_org_members_count
|
||||
|
||||
# Act
|
||||
result = await get_org_members_count(
|
||||
org_id=uuid.UUID(org_id),
|
||||
email='alice',
|
||||
user_id=current_user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 5
|
||||
mock_get_count.assert_called_once_with(
|
||||
org_id=uuid.UUID(org_id),
|
||||
current_user_id=uuid.UUID(current_user_id),
|
||||
email_filter='alice',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_a_member_returns_403(self, org_id, current_user_id):
|
||||
"""Test that OrgMemberNotFoundError returns 403 Forbidden."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgMemberService.get_org_members_count',
|
||||
AsyncMock(side_effect=OrgMemberNotFoundError(org_id, current_user_id)),
|
||||
):
|
||||
from server.routes.orgs import get_org_members_count
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_org_members_count(
|
||||
org_id=uuid.UUID(org_id),
|
||||
email=None,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'not a member of this organization' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_uuid_returns_400(self, org_id):
|
||||
"""Test that invalid user_id UUID format returns 400 Bad Request."""
|
||||
# Arrange
|
||||
invalid_user_id = 'not-a-uuid'
|
||||
|
||||
from server.routes.orgs import get_org_members_count
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_org_members_count(
|
||||
org_id=uuid.UUID(org_id),
|
||||
email=None,
|
||||
user_id=invalid_user_id,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert 'Invalid organization ID format' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_exception_returns_500(self, org_id, current_user_id):
|
||||
"""Test that generic exception returns 500 Internal Server Error."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgMemberService.get_org_members_count',
|
||||
AsyncMock(side_effect=Exception('Database error')),
|
||||
):
|
||||
from server.routes.orgs import get_org_members_count
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_org_members_count(
|
||||
org_id=uuid.UUID(org_id),
|
||||
email=None,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'Failed to retrieve member count' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestRemoveOrgMemberEndpoint:
|
||||
"""Test cases for DELETE /api/organizations/{org_id}/members/{user_id} endpoint."""
|
||||
|
||||
|
||||
@@ -175,7 +175,8 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
assert data is not None
|
||||
assert isinstance(data, OrgMemberPage)
|
||||
assert len(data.items) == 1
|
||||
assert data.next_page_id is None
|
||||
assert data.current_page == 1
|
||||
assert data.per_page == 100
|
||||
assert data.items[0].user_id == str(current_user_id)
|
||||
assert data.items[0].email == 'test@example.com'
|
||||
assert data.items[0].role_id == 1
|
||||
@@ -282,9 +283,9 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert data.next_page_id is None
|
||||
assert data.current_page == 1
|
||||
mock_get_paginated.assert_called_once_with(
|
||||
org_id=org_id, offset=0, limit=100
|
||||
org_id=org_id, offset=0, limit=100, email_filter=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -316,9 +317,9 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert data.next_page_id == '150' # offset (100) + limit (50)
|
||||
assert data.current_page == 3 # offset (100) / limit (50) + 1
|
||||
mock_get_paginated.assert_called_once_with(
|
||||
org_id=org_id, offset=100, limit=50
|
||||
org_id=org_id, offset=100, limit=50, email_filter=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -350,7 +351,7 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
# Assert
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert data.next_page_id is None
|
||||
assert data.current_page == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_organization_no_members(
|
||||
@@ -382,7 +383,6 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert len(data.items) == 0
|
||||
assert data.next_page_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_relationship_handles_gracefully(
|
||||
@@ -512,6 +512,156 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
assert data is not None
|
||||
assert len(data.items) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_email_filter_passed_to_store(
|
||||
self, org_id, current_user_id, mock_org_member, requester_membership_owner
|
||||
):
|
||||
"""Test that email filter is passed to store methods."""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
):
|
||||
mock_get_member.return_value = requester_membership_owner
|
||||
mock_get_paginated.return_value = ([mock_org_member], False)
|
||||
|
||||
# Act
|
||||
await OrgMemberService.get_org_members(
|
||||
org_id=org_id,
|
||||
current_user_id=current_user_id,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
email_filter='alice',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_paginated.assert_called_once_with(
|
||||
org_id=org_id, offset=0, limit=10, email_filter='alice'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_metadata_correct_for_page_2(
|
||||
self, org_id, current_user_id, mock_org_member, requester_membership_owner
|
||||
):
|
||||
"""Test pagination metadata is correct for page 2."""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
):
|
||||
mock_get_member.return_value = requester_membership_owner
|
||||
mock_get_paginated.return_value = ([mock_org_member], True)
|
||||
|
||||
# Act - Request page 2 (offset 10) with limit 10
|
||||
success, error_code, data = await OrgMemberService.get_org_members(
|
||||
org_id=org_id,
|
||||
current_user_id=current_user_id,
|
||||
page_id='10',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert data.current_page == 2
|
||||
assert data.per_page == 10
|
||||
|
||||
|
||||
class TestOrgMemberServiceGetOrgMembersCount:
|
||||
"""Test cases for OrgMemberService.get_org_members_count."""
|
||||
|
||||
@pytest.fixture
|
||||
def requester_membership(self, org_id, current_user_id):
|
||||
"""Create a mock requester membership."""
|
||||
membership = MagicMock(spec=OrgMember)
|
||||
membership.org_id = org_id
|
||||
membership.user_id = current_user_id
|
||||
membership.role_id = 1
|
||||
return membership
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_succeeds_returns_count(
|
||||
self, org_id, current_user_id, requester_membership
|
||||
):
|
||||
"""Test that successful count returns the member count."""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_members_count',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_count,
|
||||
):
|
||||
mock_get_member.return_value = requester_membership
|
||||
mock_get_count.return_value = 42
|
||||
|
||||
# Act
|
||||
count = await OrgMemberService.get_org_members_count(
|
||||
org_id=org_id,
|
||||
current_user_id=current_user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert count == 42
|
||||
mock_get_count.assert_called_once_with(org_id=org_id, email_filter=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_with_email_filter(
|
||||
self, org_id, current_user_id, requester_membership
|
||||
):
|
||||
"""Test that email filter is passed to store method."""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_members_count',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_count,
|
||||
):
|
||||
mock_get_member.return_value = requester_membership
|
||||
mock_get_count.return_value = 5
|
||||
|
||||
# Act
|
||||
count = await OrgMemberService.get_org_members_count(
|
||||
org_id=org_id,
|
||||
current_user_id=current_user_id,
|
||||
email_filter='alice',
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert count == 5
|
||||
mock_get_count.assert_called_once_with(org_id=org_id, email_filter='alice')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_a_member_raises_error(self, org_id, current_user_id):
|
||||
"""Test that non-member raises OrgMemberNotFoundError."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.services.org_member_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member:
|
||||
mock_get_member.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgMemberNotFoundError):
|
||||
await OrgMemberService.get_org_members_count(
|
||||
org_id=org_id,
|
||||
current_user_id=current_user_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_membership_owner(org_id, target_user_id, owner_role):
|
||||
|
||||
@@ -655,3 +655,180 @@ async def test_get_org_members_paginated_eager_loading(async_session_maker):
|
||||
assert member.role is not None
|
||||
assert member.role.name == 'owner'
|
||||
assert member.role.rank == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_count_no_filter(async_session_maker):
|
||||
"""Test get_org_members_count returns correct count without email filter."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i in range(5)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
count = await OrgMemberStore.get_org_members_count(org_id=org_id)
|
||||
|
||||
# Assert
|
||||
assert count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_count_with_email_filter(async_session_maker):
|
||||
"""Test get_org_members_count filters by email correctly."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email='alice@example.com'),
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email='bob@example.com'),
|
||||
User(
|
||||
id=uuid.uuid4(), current_org_id=org.id, email='alice.smith@example.com'
|
||||
),
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
count = await OrgMemberStore.get_org_members_count(
|
||||
org_id=org_id, email_filter='alice'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_with_email_filter(async_session_maker):
|
||||
"""Test get_org_members_paginated filters by email correctly."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email='alice@example.com'),
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email='bob@example.com'),
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email='charlie@example.com'),
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=10, email_filter='bob'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 1
|
||||
assert members[0].user.email == 'bob@example.com'
|
||||
assert has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_email_filter_case_insensitive(
|
||||
async_session_maker,
|
||||
):
|
||||
"""Test email filter is case-insensitive."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id, email='Alice@Example.COM')
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=10, email_filter='alice@example'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 1
|
||||
assert members[0].user.email == 'Alice@Example.COM'
|
||||
|
||||
Reference in New Issue
Block a user