mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
feat(backend): develop get /api/organizations api (#12373)
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com> Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Chuck Butkus <chuck@all-hands.dev>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -65,3 +66,54 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
"""
|
||||
return cls(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_api_key_for_byor=None,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version if org.org_version is not None else 0,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=None,
|
||||
sandbox_api_key=None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
|
||||
|
||||
class OrgPage(BaseModel):
|
||||
"""Paginated response model for organization list."""
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
@@ -1,20 +1,94 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
)
|
||||
from storage.org_service import OrgService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
async def list_user_orgs(
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
"""List organizations for the authenticated user.
|
||||
|
||||
This endpoint returns a paginated list of all organizations that the
|
||||
authenticated user is a member of.
|
||||
|
||||
Args:
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of organizations to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgPage: Paginated list of organizations
|
||||
|
||||
Raises:
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Listing organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'page_id': page_id,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(org_responses),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error listing organizations',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve organizations',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_org(
|
||||
org_data: OrgCreate,
|
||||
@@ -58,31 +132,7 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
|
||||
@@ -441,3 +441,42 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: str, page_id: str | None = None, limit: int = 100
|
||||
):
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
logger.debug(
|
||||
'Fetching paginated organizations for user',
|
||||
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
|
||||
)
|
||||
|
||||
# Convert user_id string to UUID
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Fetch organizations from store
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_uuid, page_id=page_id, limit=limit
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(orgs),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return orgs, next_page_id
|
||||
|
||||
@@ -96,6 +96,63 @@ class OrgStore:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: UUID, page_id: str | None = None, limit: int = 100
|
||||
) -> tuple[list[Org], str | None]:
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Build query joining OrgMember with Org
|
||||
query = (
|
||||
session.query(Org)
|
||||
.join(OrgMember, Org.id == OrgMember.org_id)
|
||||
.filter(OrgMember.user_id == user_id)
|
||||
.order_by(Org.name)
|
||||
)
|
||||
|
||||
# Apply pagination offset
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
orgs = query.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(orgs) > limit
|
||||
if has_more:
|
||||
orgs = orgs[:limit]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Validate org versions
|
||||
validated_orgs = [
|
||||
OrgStore._validate_org_version(org) for org in orgs if org
|
||||
]
|
||||
validated_orgs = [org for org in validated_orgs if org is not None]
|
||||
|
||||
return validated_orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
|
||||
@@ -24,6 +24,8 @@ with patch('storage.database.engine', create=True), patch(
|
||||
from server.routes.orgs import org_router
|
||||
from storage.org import Org
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
@@ -32,10 +34,10 @@ def mock_app():
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override the auth dependency to return a test user
|
||||
def mock_get_openhands_user_id():
|
||||
def mock_get_admin_user_id():
|
||||
return 'test-user-123'
|
||||
|
||||
app.dependency_overrides[get_admin_user_id] = mock_get_openhands_user_id
|
||||
app.dependency_overrides[get_admin_user_id] = mock_get_admin_user_id
|
||||
|
||||
return app
|
||||
|
||||
@@ -375,3 +377,276 @@ async def test_create_org_sensitive_fields_not_exposed(mock_app):
|
||||
'sandbox_api_key' not in response_data
|
||||
or response_data.get('sandbox_api_key') is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_list():
|
||||
"""Create a test FastAPI app with organization routes and mocked auth for list endpoint."""
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override the auth dependency to return a test user
|
||||
test_user_id = str(uuid.uuid4())
|
||||
|
||||
def mock_get_user_id():
|
||||
return test_user_id
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_get_user_id
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_success(mock_app_list):
|
||||
"""
|
||||
GIVEN: User has organizations
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: Paginated list of organizations is returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
default_llm_model='claude-opus-4-5-20251101',
|
||||
)
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([mock_org], None),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert 'items' in response_data
|
||||
assert 'next_page_id' in response_data
|
||||
assert len(response_data['items']) == 1
|
||||
assert response_data['items'][0]['name'] == 'Test Organization'
|
||||
assert response_data['items'][0]['id'] == str(org_id)
|
||||
assert response_data['next_page_id'] is None
|
||||
# Credits should be None in list view
|
||||
assert response_data['items'][0]['credits'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_with_pagination(mock_app_list):
|
||||
"""
|
||||
GIVEN: User has multiple organizations
|
||||
WHEN: GET /api/organizations is called with pagination params
|
||||
THEN: Paginated results are returned with next_page_id
|
||||
"""
|
||||
# Arrange
|
||||
org1 = Org(
|
||||
id=uuid.uuid4(),
|
||||
name='Alpha Org',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
org2 = Org(
|
||||
id=uuid.uuid4(),
|
||||
name='Beta Org',
|
||||
contact_name='Jane Doe',
|
||||
contact_email='jane@example.com',
|
||||
)
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([org1, org2], '2'),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations?page_id=0&limit=2')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert len(response_data['items']) == 2
|
||||
assert response_data['next_page_id'] == '2'
|
||||
assert response_data['items'][0]['name'] == 'Alpha Org'
|
||||
assert response_data['items'][1]['name'] == 'Beta Org'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_empty(mock_app_list):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: Empty list is returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([], None),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['items'] == []
|
||||
assert response_data['next_page_id'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_invalid_limit_negative(mock_app_list):
|
||||
"""
|
||||
GIVEN: Invalid limit parameter (negative)
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: 422 validation error is returned
|
||||
"""
|
||||
# Arrange
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act - FastAPI should validate and reject limit <= 0
|
||||
response = client.get('/api/organizations?limit=-1')
|
||||
|
||||
# Assert - FastAPI should return 422 for validation error
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_invalid_limit_zero(mock_app_list):
|
||||
"""
|
||||
GIVEN: Invalid limit parameter (zero or negative)
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: 422 validation error is returned
|
||||
"""
|
||||
# Arrange
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations?limit=0')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_service_error(mock_app_list):
|
||||
"""
|
||||
GIVEN: Service layer raises an exception
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
side_effect=Exception('Database error'),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'Failed to retrieve organizations' in response.json()['detail']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_unauthorized():
|
||||
"""
|
||||
GIVEN: User is not authenticated
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: 401 Unauthorized error is returned
|
||||
"""
|
||||
# Arrange
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override to simulate unauthenticated user
|
||||
async def mock_unauthenticated():
|
||||
raise HTTPException(status_code=401, detail='User not authenticated')
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_unauthenticated
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_orgs_all_fields_present(mock_app_list):
|
||||
"""
|
||||
GIVEN: Organization with all fields populated
|
||||
WHEN: GET /api/organizations is called
|
||||
THEN: All organization fields are included in response
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Complete Org',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
conversation_expiration=3600,
|
||||
agent='CodeActAgent',
|
||||
default_max_iterations=50,
|
||||
security_analyzer='enabled',
|
||||
confirmation_mode=True,
|
||||
default_llm_model='claude-opus-4-5-20251101',
|
||||
default_llm_base_url='https://api.example.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
enable_default_condenser=True,
|
||||
billing_margin=0.15,
|
||||
enable_proactive_conversation_starters=True,
|
||||
sandbox_base_container_image='test-image',
|
||||
sandbox_runtime_container_image='test-runtime',
|
||||
org_version=5,
|
||||
mcp_config={'key': 'value'},
|
||||
max_budget_per_task=1000.0,
|
||||
enable_solvability_analysis=True,
|
||||
v1_enabled=True,
|
||||
)
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([mock_org], None),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
org_data = response_data['items'][0]
|
||||
assert org_data['name'] == 'Complete Org'
|
||||
assert org_data['contact_name'] == 'John Doe'
|
||||
assert org_data['contact_email'] == 'john@example.com'
|
||||
assert org_data['conversation_expiration'] == 3600
|
||||
assert org_data['agent'] == 'CodeActAgent'
|
||||
assert org_data['default_max_iterations'] == 50
|
||||
assert org_data['security_analyzer'] == 'enabled'
|
||||
assert org_data['confirmation_mode'] is True
|
||||
assert org_data['default_llm_model'] == 'claude-opus-4-5-20251101'
|
||||
assert org_data['default_llm_base_url'] == 'https://api.example.com'
|
||||
assert org_data['remote_runtime_resource_factor'] == 2
|
||||
assert org_data['enable_default_condenser'] is True
|
||||
assert org_data['billing_margin'] == 0.15
|
||||
assert org_data['enable_proactive_conversation_starters'] is True
|
||||
assert org_data['sandbox_base_container_image'] == 'test-image'
|
||||
assert org_data['sandbox_runtime_container_image'] == 'test-runtime'
|
||||
assert org_data['org_version'] == 5
|
||||
assert org_data['mcp_config'] == {'key': 'value'}
|
||||
assert org_data['max_budget_per_task'] == 1000.0
|
||||
assert org_data['enable_solvability_analysis'] is True
|
||||
assert org_data['v1_enabled'] is True
|
||||
assert org_data['credits'] is None
|
||||
|
||||
@@ -562,3 +562,120 @@ async def test_get_org_credits_api_failure_returns_none(mock_litellm_api):
|
||||
|
||||
# Assert
|
||||
assert credits is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations in database
|
||||
WHEN: get_user_orgs_paginated is called with valid user_id
|
||||
THEN: Organizations are returned with pagination info
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org = Org(id=org_id, name='Test Org')
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([org, user, role])
|
||||
session.flush()
|
||||
|
||||
member = OrgMember(
|
||||
org_id=org_id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
session.add(member)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=str(user_id), page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Test Org'
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called with page_id and limit
|
||||
THEN: Paginated results are returned correctly
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
session.add_all([org1, org2, org3])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=str(user_id), page_id='0', limit=2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert orgs[1].name == 'Beta Org'
|
||||
assert next_page_id == '2'
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Empty list and None next_page_id are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 0
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_user_id_format():
|
||||
"""
|
||||
GIVEN: Invalid user_id format (not a valid UUID string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
invalid_user_id = 'not-a-uuid'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
OrgService.get_user_orgs_paginated(
|
||||
user_id=invalid_user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
@@ -415,3 +415,251 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
)
|
||||
assert persisted_member.max_iterations == 100
|
||||
assert persisted_member.llm_model == 'gpt-4'
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User is member of multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called without page_id
|
||||
THEN: First page of organizations is returned in alphabetical order
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
other_user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs for the user
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
# Create org for another user (should not be included)
|
||||
org4 = Org(name='Other Org')
|
||||
session.add_all([org1, org2, org3, org4])
|
||||
session.flush()
|
||||
|
||||
# Create user and role
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
other_user = User(id=other_user_id, current_org_id=org4.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, other_user, role])
|
||||
session.flush()
|
||||
|
||||
# Create memberships
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
other_member = OrgMember(
|
||||
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
|
||||
)
|
||||
session.add_all([member1, member2, member3, other_member])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert orgs[1].name == 'Beta Org'
|
||||
assert next_page_id == '2' # Has more results
|
||||
# Verify other user's org is not included
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'Other Org' not in org_names
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has multiple organizations and page_id is provided
|
||||
WHEN: get_user_orgs_paginated is called with page_id
|
||||
THEN: Organizations starting from offset are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
session.add_all([org1, org2, org3])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='1', limit=1
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Beta Org' # Second org (offset 1)
|
||||
assert next_page_id == '2' # Has more results
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations but fewer than limit
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: All organizations are returned and next_page_id is None
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
session.add_all([member1, member2])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Invalid page_id (non-numeric string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Results start from beginning (offset 0)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
session.add(org1)
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
session.add(member1)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='invalid', limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Empty list and None next_page_id are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 0
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations with different names
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Organizations are returned in alphabetical order by name
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs in non-alphabetical order
|
||||
org3 = Org(name='Zebra Org')
|
||||
org1 = Org(name='Apple Org')
|
||||
org2 = Org(name='Banana Org')
|
||||
session.add_all([org3, org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, _ = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 3
|
||||
assert orgs[0].name == 'Apple Org'
|
||||
assert orgs[1].name == 'Banana Org'
|
||||
assert orgs[2].name == 'Zebra Org'
|
||||
|
||||
Reference in New Issue
Block a user