From 030ff59c4096a536f903ca2ad9eed88bf1bb2051 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:14:27 +0700 Subject: [PATCH] feat(backend): develop post /api/organizations api (org project) (#12263) Co-authored-by: rohitvinodmalhotra@gmail.com Co-authored-by: openhands Co-authored-by: Chuck Butkus --- enterprise/poetry.lock | 2 +- enterprise/saas_server.py | 2 + enterprise/server/auth/saas_user_auth.py | 13 +- enterprise/server/email_validation.py | 68 +++ enterprise/server/routes/org_models.py | 67 +++ enterprise/server/routes/orgs.py | 117 ++++ enterprise/storage/lite_llm_manager.py | 46 +- enterprise/storage/org_service.py | 443 ++++++++++++++ enterprise/storage/org_store.py | 26 + enterprise/storage/user_store.py | 6 +- .../tests/unit/server/routes/test_orgs.py | 377 ++++++++++++ .../unit/test_email_validation_dependency.py | 272 +++++++++ .../tests/unit/test_lite_llm_manager.py | 146 ++++- enterprise/tests/unit/test_org_service.py | 564 ++++++++++++++++++ enterprise/tests/unit/test_org_store.py | 222 ++++++- enterprise/tests/unit/test_saas_user_auth.py | 53 +- 16 files changed, 2399 insertions(+), 25 deletions(-) create mode 100644 enterprise/server/email_validation.py create mode 100644 enterprise/server/routes/org_models.py create mode 100644 enterprise/server/routes/orgs.py create mode 100644 enterprise/storage/org_service.py create mode 100644 enterprise/tests/unit/server/routes/test_orgs.py create mode 100644 enterprise/tests/unit/test_email_validation_dependency.py create mode 100644 enterprise/tests/unit/test_org_service.py diff --git a/enterprise/poetry.lock b/enterprise/poetry.lock index 6096729fe1..274e138a06 100644 --- a/enterprise/poetry.lock +++ b/enterprise/poetry.lock @@ -6126,7 +6126,7 @@ wsproto = ">=1.2.0" [[package]] name = "openhands-ai" -version = "1.1.0" +version = "1.2.1" description = "OpenHands: Code Less, Make More" optional = false python-versions = "^3.12,<3.14" diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 803530a794..1734536d53 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -38,6 +38,7 @@ from server.routes.integration.linear import linear_integration_router # noqa: from server.routes.integration.slack import slack_router # noqa: E402 from server.routes.mcp_patch import patch_mcp_server # noqa: E402 from server.routes.oauth_device import oauth_device_router # noqa: E402 +from server.routes.orgs import org_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 from server.sharing.shared_conversation_router import ( # noqa: E402 @@ -90,6 +91,7 @@ if GITLAB_APP_CLIENT_ID: base_app.include_router(gitlab_integration_router) base_app.include_router(api_keys_router) # Add routes for API key management +base_app.include_router(org_router) # Add routes for organization management add_github_proxy_routes(base_app) add_debugging_routes( base_app diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 893cda0527..5cd6a1e2c4 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -77,6 +77,15 @@ class SaasUserAuth(UserAuth): self.access_token = SecretStr(tokens['access_token']) self.refresh_token = SecretStr(tokens['refresh_token']) self.refreshed = True + if not self.email or not self.email_verified or not self.user_id: + # We don't need to verify the signature here because we just refreshed + # this token from the IDP via token_manager.refresh() + access_token_payload = jwt.decode( + tokens['access_token'], options={'verify_signature': False} + ) + self.user_id = access_token_payload['sub'] + self.email = access_token_payload['email'] + self.email_verified = access_token_payload['email_verified'] def _is_token_expired(self, token: SecretStr): logger.debug('saas_user_auth_is_token_expired') @@ -273,11 +282,13 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None: if not user_id: return None offline_token = await token_manager.load_offline_token(user_id) - return SaasUserAuth( + saas_user_auth = SaasUserAuth( user_id=user_id, refresh_token=SecretStr(offline_token), auth_type=AuthType.BEARER, ) + await saas_user_auth.refresh() + return saas_user_auth except Exception as exc: raise BearerTokenError from exc diff --git a/enterprise/server/email_validation.py b/enterprise/server/email_validation.py new file mode 100644 index 0000000000..b3242cd281 --- /dev/null +++ b/enterprise/server/email_validation.py @@ -0,0 +1,68 @@ +""" +Email domain validation utilities for enterprise endpoints. +""" + +from fastapi import Depends, HTTPException, Request, status + +from openhands.core.logger import openhands_logger as logger +from openhands.server.user_auth import get_user_auth, get_user_id + + +async def get_admin_user_id( + request: Request, user_id: str | None = Depends(get_user_id) +) -> str: + """ + Dependency that validates user has @openhands.dev email domain. + + This dependency can be used in place of get_user_id for endpoints that + should only be accessible to admin users. Currently, this is implemented + by checking for @openhands.dev email domain. + + TODO: In the future, this should be replaced with an explicit is_admin flag + in user/org settings instead of relying on email domain validation. + + Args: + request: FastAPI request object + user_id: User ID from get_user_id dependency + + Returns: + str: User ID if email domain is valid + + Raises: + HTTPException: 403 if email domain is not @openhands.dev + HTTPException: 401 if user is not authenticated + + Example: + @router.post('/endpoint') + async def create_resource( + user_id: str = Depends(get_admin_user_id), + ): + # Only admin users can access this endpoint + pass + """ + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='User not authenticated', + ) + + user_auth = await get_user_auth(request) + user_email = await user_auth.get_user_email() + + if not user_email: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='User email not available', + ) + + if not user_email.endswith('@openhands.dev'): + logger.warning( + 'Access denied - invalid email domain', + extra={'user_id': user_id, 'email_domain': user_email.split('@')[-1]}, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail='Access restricted to @openhands.dev users', + ) + + return user_id diff --git a/enterprise/server/routes/org_models.py b/enterprise/server/routes/org_models.py new file mode 100644 index 0000000000..4fe480bf3a --- /dev/null +++ b/enterprise/server/routes/org_models.py @@ -0,0 +1,67 @@ +from pydantic import BaseModel, EmailStr, Field + + +class OrgCreationError(Exception): + """Base exception for organization creation errors.""" + + pass + + +class OrgNameExistsError(OrgCreationError): + """Raised when an organization name already exists.""" + + def __init__(self, name: str): + self.name = name + super().__init__(f'Organization with name "{name}" already exists') + + +class LiteLLMIntegrationError(OrgCreationError): + """Raised when LiteLLM integration fails.""" + + pass + + +class OrgDatabaseError(OrgCreationError): + """Raised when database operations fail.""" + + pass + + +class OrgCreate(BaseModel): + """Request model for creating a new organization.""" + + # Required fields + name: str = Field(min_length=1, max_length=255, strip_whitespace=True) + contact_name: str + contact_email: EmailStr = Field(strip_whitespace=True) + + +class OrgResponse(BaseModel): + """Response model for organization.""" + + id: str + name: str + contact_name: str + contact_email: str + conversation_expiration: int | None = None + agent: str | None = None + default_max_iterations: int | None = None + security_analyzer: str | None = None + confirmation_mode: bool | None = None + default_llm_model: str | None = None + default_llm_api_key_for_byor: str | None = None + default_llm_base_url: str | None = None + remote_runtime_resource_factor: int | None = None + enable_default_condenser: bool = True + billing_margin: float | None = None + enable_proactive_conversation_starters: bool = True + sandbox_base_container_image: str | None = None + sandbox_runtime_container_image: str | None = None + org_version: int = 0 + mcp_config: dict | None = None + search_api_key: str | None = None + sandbox_api_key: str | None = None + max_budget_per_task: float | None = None + enable_solvability_analysis: bool | None = None + v1_enabled: bool | None = None + credits: float | None = None diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py new file mode 100644 index 0000000000..fa61cb3664 --- /dev/null +++ b/enterprise/server/routes/orgs.py @@ -0,0 +1,117 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from server.email_validation import get_admin_user_id +from server.routes.org_models import ( + LiteLLMIntegrationError, + OrgCreate, + OrgDatabaseError, + OrgNameExistsError, + OrgResponse, +) +from storage.org_service import OrgService + +from openhands.core.logger import openhands_logger as logger + +# Initialize API router +org_router = APIRouter(prefix='/api/organizations') + + +@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED) +async def create_org( + org_data: OrgCreate, + user_id: str = Depends(get_admin_user_id), +) -> OrgResponse: + """Create a new organization. + + This endpoint allows authenticated users with @openhands.dev email to create + a new organization. The user who creates the organization automatically becomes + its owner. + + Args: + org_data: Organization creation data + user_id: Authenticated user ID (injected by dependency) + + Returns: + OrgResponse: The created organization details + + Raises: + HTTPException: 403 if user email domain is not @openhands.dev + HTTPException: 409 if organization name already exists + HTTPException: 500 if creation fails + """ + logger.info( + 'Creating new organization', + extra={ + 'user_id': user_id, + 'org_name': org_data.name, + }, + ) + + try: + # Use service layer to create organization + org = await OrgService.create_org_with_owner( + name=org_data.name, + contact_name=org_data.contact_name, + contact_email=org_data.contact_email, + user_id=user_id, + ) + + # Retrieve credits from LiteLLM + credits = await OrgService.get_org_credits(user_id, org.id) + + return OrgResponse( + id=str(org.id), + name=org.name, + contact_name=org.contact_name, + contact_email=org.contact_email, + conversation_expiration=org.conversation_expiration, + agent=org.agent, + default_max_iterations=org.default_max_iterations, + security_analyzer=org.security_analyzer, + confirmation_mode=org.confirmation_mode, + default_llm_model=org.default_llm_model, + default_llm_base_url=org.default_llm_base_url, + remote_runtime_resource_factor=org.remote_runtime_resource_factor, + enable_default_condenser=org.enable_default_condenser, + billing_margin=org.billing_margin, + enable_proactive_conversation_starters=org.enable_proactive_conversation_starters, + sandbox_base_container_image=org.sandbox_base_container_image, + sandbox_runtime_container_image=org.sandbox_runtime_container_image, + org_version=org.org_version, + mcp_config=org.mcp_config, + max_budget_per_task=org.max_budget_per_task, + enable_solvability_analysis=org.enable_solvability_analysis, + v1_enabled=org.v1_enabled, + credits=credits, + ) + except OrgNameExistsError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e), + ) + except LiteLLMIntegrationError as e: + logger.error( + 'LiteLLM integration failed', + extra={'user_id': user_id, 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to create LiteLLM integration', + ) + except OrgDatabaseError as e: + logger.error( + 'Database operation failed', + extra={'user_id': user_id, 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to create organization', + ) + except Exception as e: + logger.exception( + 'Unexpected error creating organization', + extra={'user_id': user_id, 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='An unexpected error occurred', + ) diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index 09a7a5147c..b23e3c28d2 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -38,6 +38,7 @@ class LiteLlmManager: org_id: str, keycloak_user_id: str, oss_settings: Settings, + create_user: bool, ) -> Settings | None: logger.info( 'SettingsStore:update_settings_with_litellm_default:start', @@ -64,9 +65,10 @@ class LiteLlmManager: client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET ) - await LiteLlmManager._create_user( - client, keycloak_user_info.get('email'), keycloak_user_id - ) + if create_user: + await LiteLlmManager._create_user( + client, keycloak_user_info.get('email'), keycloak_user_id + ) await LiteLlmManager._add_user_to_team( client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET @@ -76,7 +78,7 @@ class LiteLlmManager: client, keycloak_user_id, org_id, - f'OpenHands Cloud - user {keycloak_user_id}', + f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}', None, ) @@ -474,6 +476,41 @@ class LiteLlmManager: ) response.raise_for_status() + @staticmethod + async def _delete_team( + client: httpx.AsyncClient, + team_id: str, + ): + if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: + logger.warning('LiteLLM API configuration not found') + return + response = await client.post( + f'{LITE_LLM_API_URL}/team/delete', + json={'team_ids': [team_id]}, + ) + + if not response.is_success: + if response.status_code == 404: + # Team doesn't exist, that's fine + logger.info( + 'Team already deleted or does not exist', + extra={'team_id': team_id}, + ) + return + logger.error( + 'error_deleting_litellm_team', + extra={ + 'status_code': response.status_code, + 'text': response.text, + 'team_id': team_id, + }, + ) + response.raise_for_status() + logger.info( + 'LiteLlmManager:_delete_team:team_deleted', + extra={'team_id': team_id}, + ) + @staticmethod async def _add_user_to_team( client: httpx.AsyncClient, @@ -817,6 +854,7 @@ class LiteLlmManager: get_user = staticmethod(with_http_client(_get_user)) update_user = staticmethod(with_http_client(_update_user)) delete_user = staticmethod(with_http_client(_delete_user)) + delete_team = staticmethod(with_http_client(_delete_team)) add_user_to_team = staticmethod(with_http_client(_add_user_to_team)) get_user_team_info = staticmethod(with_http_client(_get_user_team_info)) update_user_in_team = staticmethod(with_http_client(_update_user_in_team)) diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py new file mode 100644 index 0000000000..1bd3afb17d --- /dev/null +++ b/enterprise/storage/org_service.py @@ -0,0 +1,443 @@ +""" +Service class for managing organization operations. +Separates business logic from route handlers. +""" + +from uuid import UUID, uuid4 +from uuid import UUID as parse_uuid + +from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model +from server.routes.org_models import ( + LiteLLMIntegrationError, + OrgDatabaseError, + OrgNameExistsError, +) +from storage.lite_llm_manager import LiteLlmManager +from storage.org import Org +from storage.org_member import OrgMember +from storage.org_member_store import OrgMemberStore +from storage.org_store import OrgStore +from storage.role_store import RoleStore +from storage.user_store import UserStore + +from openhands.core.logger import openhands_logger as logger + + +class OrgService: + """Service for handling organization-related operations.""" + + @staticmethod + def validate_name_uniqueness(name: str) -> None: + """ + Validate that organization name is unique. + + Args: + name: Organization name to validate + + Raises: + OrgNameExistsError: If organization name already exists + """ + existing_org = OrgStore.get_org_by_name(name) + if existing_org is not None: + raise OrgNameExistsError(name) + + @staticmethod + async def create_litellm_integration(org_id: UUID, user_id: str) -> dict: + """ + Create LiteLLM team integration for the organization. + + Args: + org_id: Organization ID + user_id: User ID who will own the organization + + Returns: + dict: LiteLLM settings object + + Raises: + LiteLLMIntegrationError: If LiteLLM integration fails + """ + try: + settings = await UserStore.create_default_settings( + org_id=str(org_id), user_id=user_id, create_user=False + ) + + if not settings: + logger.error( + 'Failed to create LiteLLM settings', + extra={'org_id': str(org_id), 'user_id': user_id}, + ) + raise LiteLLMIntegrationError('Failed to create LiteLLM settings') + + logger.debug( + 'LiteLLM integration created', + extra={'org_id': str(org_id), 'user_id': user_id}, + ) + return settings + + except LiteLLMIntegrationError: + raise + except Exception as e: + logger.exception( + 'Error creating LiteLLM integration', + extra={'org_id': str(org_id), 'user_id': user_id, 'error': str(e)}, + ) + raise LiteLLMIntegrationError(f'LiteLLM integration failed: {str(e)}') + + @staticmethod + def create_org_entity( + org_id: UUID, + name: str, + contact_name: str, + contact_email: str, + ) -> Org: + """ + Create an organization entity with basic information. + + Args: + org_id: Organization UUID + name: Organization name + contact_name: Contact person name + contact_email: Contact email address + + Returns: + Org: New organization entity (not yet persisted) + """ + return Org( + id=org_id, + name=name, + contact_name=contact_name, + contact_email=contact_email, + org_version=ORG_SETTINGS_VERSION, + default_llm_model=get_default_litellm_model(), + ) + + @staticmethod + def apply_litellm_settings_to_org(org: Org, settings: dict) -> None: + """ + Apply LiteLLM settings to organization entity. + + Args: + org: Organization entity to update + settings: LiteLLM settings object + """ + org_kwargs = OrgStore.get_kwargs_from_settings(settings) + for key, value in org_kwargs.items(): + if hasattr(org, key): + setattr(org, key, value) + + @staticmethod + def get_owner_role(): + """ + Get the owner role from the database. + + Returns: + Role: The owner role object + + Raises: + Exception: If owner role not found + """ + owner_role = RoleStore.get_role_by_name('owner') + if not owner_role: + raise Exception('Owner role not found in database') + return owner_role + + @staticmethod + def create_org_member_entity( + org_id: UUID, + user_id: str, + role_id: int, + settings: dict, + ) -> OrgMember: + """ + Create an organization member entity. + + Args: + org_id: Organization UUID + user_id: User ID (string that will be converted to UUID) + role_id: Role ID + settings: LiteLLM settings object + + Returns: + OrgMember: New organization member entity (not yet persisted) + """ + org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings) + return OrgMember( + org_id=org_id, + user_id=parse_uuid(user_id), + role_id=role_id, + status='active', + **org_member_kwargs, + ) + + @staticmethod + async def create_org_with_owner( + name: str, + contact_name: str, + contact_email: str, + user_id: str, + ) -> Org: + """ + Create a new organization with the specified user as owner. + + This method orchestrates the complete organization creation workflow: + 1. Validates that the organization name doesn't already exist + 2. Generates a unique organization ID + 3. Creates LiteLLM team integration + 4. Creates the organization entity + 5. Applies LiteLLM settings + 6. Creates owner membership + 7. Persists everything in a transaction + + If database persistence fails, LiteLLM resources are cleaned up (compensation). + + Args: + name: Organization name (must be unique) + contact_name: Contact person name + contact_email: Contact email address + user_id: ID of the user who will be the owner + + Returns: + Org: The created organization object + + Raises: + OrgNameExistsError: If organization name already exists + LiteLLMIntegrationError: If LiteLLM integration fails + OrgDatabaseError: If database operations fail + """ + logger.info( + 'Starting organization creation', + extra={'user_id': user_id, 'org_name': name}, + ) + + # Step 1: Validate name uniqueness (fails early, no cleanup needed) + OrgService.validate_name_uniqueness(name) + + # Step 2: Generate organization ID + org_id = uuid4() + + # Step 3: Create LiteLLM integration (external state created) + settings = await OrgService.create_litellm_integration(org_id, user_id) + + # Steps 4-7: Create entities and persist with compensation + # If any of these fail, we need to clean up LiteLLM resources + try: + # Step 4: Create organization entity + org = OrgService.create_org_entity( + org_id=org_id, + name=name, + contact_name=contact_name, + contact_email=contact_email, + ) + + # Step 5: Apply LiteLLM settings + OrgService.apply_litellm_settings_to_org(org, settings) + + # Step 6: Get owner role and create member entity + owner_role = OrgService.get_owner_role() + org_member = OrgService.create_org_member_entity( + org_id=org_id, + user_id=user_id, + role_id=owner_role.id, + settings=settings, + ) + + # Step 7: Persist in transaction (critical section) + persisted_org = await OrgService._persist_with_compensation( + org, org_member, org_id, user_id + ) + + logger.info( + 'Successfully created organization', + extra={ + 'org_id': str(persisted_org.id), + 'org_name': persisted_org.name, + 'user_id': user_id, + 'role': 'owner', + }, + ) + + return persisted_org + + except OrgDatabaseError: + # Already handled by _persist_with_compensation, just re-raise + raise + except Exception as e: + # Unexpected error in steps 4-6, need to clean up LiteLLM + logger.error( + 'Unexpected error during organization creation, initiating cleanup', + extra={ + 'org_id': str(org_id), + 'user_id': user_id, + 'error': str(e), + }, + ) + await OrgService._handle_failure_with_cleanup( + org_id, user_id, e, 'Failed to create organization' + ) + + @staticmethod + async def _persist_with_compensation( + org: Org, + org_member: OrgMember, + org_id: UUID, + user_id: str, + ) -> Org: + """ + Persist organization with compensation on failure. + + If database persistence fails, cleans up LiteLLM resources. + + Args: + org: Organization entity to persist + org_member: Organization member entity to persist + org_id: Organization ID (for cleanup) + user_id: User ID (for cleanup) + + Returns: + Org: The persisted organization object + + Raises: + OrgDatabaseError: If database operations fail + """ + try: + persisted_org = OrgStore.persist_org_with_owner(org, org_member) + return persisted_org + + except Exception as e: + logger.error( + 'Database persistence failed, initiating LiteLLM cleanup', + extra={ + 'org_id': str(org_id), + 'user_id': user_id, + 'error': str(e), + }, + ) + await OrgService._handle_failure_with_cleanup( + org_id, user_id, e, 'Failed to create organization' + ) + + @staticmethod + async def _handle_failure_with_cleanup( + org_id: UUID, + user_id: str, + original_error: Exception, + error_message: str, + ) -> None: + """ + Handle failure by cleaning up LiteLLM resources and raising appropriate error. + + This method performs compensating transaction and raises OrgDatabaseError. + + Args: + org_id: Organization ID + user_id: User ID + original_error: The original exception that caused the failure + error_message: Base error message for the exception + + Raises: + OrgDatabaseError: Always raises with details about the failure + """ + cleanup_error = await OrgService._cleanup_litellm_resources(org_id, user_id) + + if cleanup_error: + logger.error( + 'Both operation and cleanup failed', + extra={ + 'org_id': str(org_id), + 'user_id': user_id, + 'original_error': str(original_error), + 'cleanup_error': str(cleanup_error), + }, + ) + raise OrgDatabaseError( + f'{error_message}: {str(original_error)}. ' + f'Cleanup also failed: {str(cleanup_error)}' + ) + + raise OrgDatabaseError(f'{error_message}: {str(original_error)}') + + @staticmethod + async def _cleanup_litellm_resources( + org_id: UUID, user_id: str + ) -> Exception | None: + """ + Compensating transaction: Clean up LiteLLM resources. + + Deletes the team which should cascade to remove keys and memberships. + This is a best-effort operation - errors are logged but not raised. + + Args: + org_id: Organization ID + user_id: User ID + + Returns: + Exception | None: Exception if cleanup failed, None if successful + """ + try: + await LiteLlmManager.delete_team(str(org_id)) + + logger.info( + 'Successfully cleaned up LiteLLM team', + extra={'org_id': str(org_id), 'user_id': user_id}, + ) + return None + + except Exception as e: + logger.error( + 'Failed to cleanup LiteLLM team (resources may be orphaned)', + extra={ + 'org_id': str(org_id), + 'user_id': user_id, + 'error': str(e), + }, + ) + return e + + @staticmethod + async def get_org_credits(user_id: str, org_id: UUID) -> float | None: + """ + Get organization credits from LiteLLM team. + + Args: + user_id: User ID + org_id: Organization ID + + Returns: + float | None: Credits (max_budget - spend) or None if LiteLLM not configured + """ + try: + user_team_info = await LiteLlmManager.get_user_team_info( + user_id, str(org_id) + ) + if not user_team_info: + logger.warning( + 'No team info available from LiteLLM', + extra={'user_id': user_id, 'org_id': str(org_id)}, + ) + return None + + max_budget = (user_team_info.get('litellm_budget_table') or {}).get( + 'max_budget', 0 + ) + spend = user_team_info.get('spend', 0) + credits = max(max_budget - spend, 0) + + logger.debug( + 'Retrieved organization credits', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'credits': credits, + 'max_budget': max_budget, + 'spend': spend, + }, + ) + + return credits + + except Exception as e: + logger.warning( + 'Failed to retrieve organization credits', + extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)}, + ) + return None diff --git a/enterprise/storage/org_store.py b/enterprise/storage/org_store.py index f7a706349b..87e67865ed 100644 --- a/enterprise/storage/org_store.py +++ b/enterprise/storage/org_store.py @@ -13,6 +13,7 @@ from server.constants import ( from sqlalchemy.orm import joinedload from storage.database import session_maker from storage.org import Org +from storage.org_member import OrgMember from storage.user import User from storage.user_settings import UserSettings @@ -160,3 +161,28 @@ class OrgStore: kwargs['org_version'] = user_settings.user_version return kwargs + + @staticmethod + def persist_org_with_owner( + org: Org, + org_member: OrgMember, + ) -> Org: + """ + Persist organization and owner membership in a single transaction. + + Args: + org: Organization entity to persist + org_member: Organization member entity to persist + + Returns: + Org: The persisted organization object + + Raises: + Exception: If database operations fail + """ + with session_maker() as session: + session.add(org) + session.add(org_member) + session.commit() + session.refresh(org) + return org diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 847cd19e0d..ca30612285 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -435,7 +435,7 @@ class UserStore: @staticmethod async def create_default_settings( - org_id: str, user_id: str + org_id: str, user_id: str, create_user: bool = True ) -> Optional['Settings']: logger.info( 'UserStore:create_default_settings:start', @@ -451,7 +451,9 @@ class UserStore: from storage.lite_llm_manager import LiteLlmManager - settings = await LiteLlmManager.create_entries(org_id, user_id, settings) + settings = await LiteLlmManager.create_entries( + org_id, user_id, settings, create_user + ) if not settings: logger.info( 'UserStore:create_default_settings:litellm_create_failed', diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py new file mode 100644 index 0000000000..1ec450b4de --- /dev/null +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -0,0 +1,377 @@ +""" +Integration tests for organization API routes. + +Tests the POST /api/organizations endpoint with various scenarios. +""" + +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI, HTTPException, status +from fastapi.testclient import TestClient + +# Mock database before imports +with patch('storage.database.engine', create=True), patch( + 'storage.database.a_engine', create=True +): + from server.email_validation import get_admin_user_id + from server.routes.org_models import ( + LiteLLMIntegrationError, + OrgDatabaseError, + OrgNameExistsError, + ) + from server.routes.orgs import org_router + from storage.org import Org + + +@pytest.fixture +def mock_app(): + """Create a test FastAPI app with organization routes and mocked auth.""" + app = FastAPI() + app.include_router(org_router) + + # Override the auth dependency to return a test user + def mock_get_openhands_user_id(): + return 'test-user-123' + + app.dependency_overrides[get_admin_user_id] = mock_get_openhands_user_id + + return app + + +@pytest.mark.asyncio +async def test_create_org_success(mock_app): + """ + GIVEN: Valid organization creation request + WHEN: POST /api/organizations is called + THEN: Organization is created and returned with 201 status + """ + # Arrange + org_id = uuid.uuid4() + mock_org = Org( + id=org_id, + name='Test Organization', + contact_name='John Doe', + contact_email='john@example.com', + org_version=5, + default_llm_model='claude-opus-4-5-20251101', + enable_default_condenser=True, + enable_proactive_conversation_starters=True, + ) + + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with ( + patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(return_value=mock_org), + ), + patch( + 'server.routes.orgs.OrgService.get_org_credits', + AsyncMock(return_value=100.0), + ), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_201_CREATED + response_data = response.json() + assert response_data['name'] == 'Test Organization' + assert response_data['contact_name'] == 'John Doe' + assert response_data['contact_email'] == 'john@example.com' + assert response_data['credits'] == 100.0 + assert response_data['org_version'] == 5 + assert response_data['default_llm_model'] == 'claude-opus-4-5-20251101' + + +@pytest.mark.asyncio +async def test_create_org_invalid_email(mock_app): + """ + GIVEN: Request with invalid email format + WHEN: POST /api/organizations is called + THEN: 422 validation error is returned + """ + # Arrange + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'invalid-email', # Missing @ + } + + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_create_org_empty_name(mock_app): + """ + GIVEN: Request with empty organization name + WHEN: POST /api/organizations is called + THEN: 422 validation error is returned + """ + # Arrange + request_data = { + 'name': '', # Empty string (after whitespace stripping) + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_create_org_duplicate_name(mock_app): + """ + GIVEN: Organization name already exists + WHEN: POST /api/organizations is called + THEN: 409 Conflict error is returned + """ + # Arrange + request_data = { + 'name': 'Existing Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(side_effect=OrgNameExistsError('Existing Organization')), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_409_CONFLICT + assert 'already exists' in response.json()['detail'].lower() + + +@pytest.mark.asyncio +async def test_create_org_litellm_failure(mock_app): + """ + GIVEN: LiteLLM integration fails + WHEN: POST /api/organizations is called + THEN: 500 Internal Server Error is returned + """ + # Arrange + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(side_effect=LiteLLMIntegrationError('LiteLLM API unavailable')), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert 'LiteLLM integration' in response.json()['detail'] + + +@pytest.mark.asyncio +async def test_create_org_database_failure(mock_app): + """ + GIVEN: Database operation fails + WHEN: POST /api/organizations is called + THEN: 500 Internal Server Error is returned + """ + # Arrange + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(side_effect=OrgDatabaseError('Database connection failed')), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert 'Failed to create organization' in response.json()['detail'] + + +@pytest.mark.asyncio +async def test_create_org_unexpected_error(mock_app): + """ + GIVEN: Unexpected error occurs + WHEN: POST /api/organizations is called + THEN: 500 Internal Server Error is returned with generic message + """ + # Arrange + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(side_effect=RuntimeError('Unexpected system error')), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert 'unexpected error' in response.json()['detail'].lower() + + +@pytest.mark.asyncio +async def test_create_org_unauthorized(): + """ + GIVEN: User is not authenticated + WHEN: POST /api/organizations is called + THEN: 401 Unauthorized error is returned + """ + # Arrange + app = FastAPI() + app.include_router(org_router) + + # Override to simulate unauthenticated user + async def mock_unauthenticated(): + raise HTTPException(status_code=401, detail='User not authenticated') + + app.dependency_overrides[get_admin_user_id] = mock_unauthenticated + + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + client = TestClient(app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_create_org_forbidden_non_openhands_email(): + """ + GIVEN: User email is not @openhands.dev + WHEN: POST /api/organizations is called + THEN: 403 Forbidden error is returned + """ + # Arrange + app = FastAPI() + app.include_router(org_router) + + # Override to simulate non-@openhands.dev user + async def mock_forbidden(): + raise HTTPException( + status_code=403, detail='Access restricted to @openhands.dev users' + ) + + app.dependency_overrides[get_admin_user_id] = mock_forbidden + + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + client = TestClient(app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_403_FORBIDDEN + assert 'openhands.dev' in response.json()['detail'].lower() + + +@pytest.mark.asyncio +async def test_create_org_sensitive_fields_not_exposed(mock_app): + """ + GIVEN: Organization is created successfully + WHEN: Response is returned + THEN: Sensitive fields (API keys) are not exposed + """ + # Arrange + org_id = uuid.uuid4() + mock_org = Org( + id=org_id, + name='Test Organization', + contact_name='John Doe', + contact_email='john@example.com', + org_version=5, + default_llm_model='claude-opus-4-5-20251101', + enable_default_condenser=True, + enable_proactive_conversation_starters=True, + ) + + request_data = { + 'name': 'Test Organization', + 'contact_name': 'John Doe', + 'contact_email': 'john@example.com', + } + + with ( + patch( + 'server.routes.orgs.OrgService.create_org_with_owner', + AsyncMock(return_value=mock_org), + ), + patch( + 'server.routes.orgs.OrgService.get_org_credits', + AsyncMock(return_value=100.0), + ), + ): + client = TestClient(mock_app) + + # Act + response = client.post('/api/organizations', json=request_data) + + # Assert + assert response.status_code == status.HTTP_201_CREATED + response_data = response.json() + + # Verify sensitive fields are not in response or are None + assert ( + 'default_llm_api_key_for_byor' not in response_data + or response_data.get('default_llm_api_key_for_byor') is None + ) + assert ( + 'search_api_key' not in response_data + or response_data.get('search_api_key') is None + ) + assert ( + 'sandbox_api_key' not in response_data + or response_data.get('sandbox_api_key') is None + ) diff --git a/enterprise/tests/unit/test_email_validation_dependency.py b/enterprise/tests/unit/test_email_validation_dependency.py new file mode 100644 index 0000000000..8ac481a6e2 --- /dev/null +++ b/enterprise/tests/unit/test_email_validation_dependency.py @@ -0,0 +1,272 @@ +""" +Unit tests for email validation dependency (get_admin_user_id). + +Tests the FastAPI dependency that validates @openhands.dev email domain. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException, Request +from server.email_validation import get_admin_user_id + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + return MagicMock(spec=Request) + + +@pytest.fixture +def mock_user_auth(): + """Create a mock user auth object.""" + mock_auth = AsyncMock() + mock_auth.get_user_email = AsyncMock() + return mock_auth + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_success(mock_request, mock_user_auth): + """ + GIVEN: Valid user ID and @openhands.dev email + WHEN: get_admin_user_id is called + THEN: User ID is returned successfully + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test@openhands.dev' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act + result = await get_admin_user_id(mock_request, user_id) + + # Assert + assert result == user_id + mock_user_auth.get_user_email.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_no_user_id(mock_request): + """ + GIVEN: No user ID provided (None) + WHEN: get_admin_user_id is called + THEN: 401 Unauthorized is raised + """ + # Arrange + user_id = None + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 401 + assert 'not authenticated' in exc_info.value.detail.lower() + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_no_email(mock_request, mock_user_auth): + """ + GIVEN: User ID provided but email is None + WHEN: get_admin_user_id is called + THEN: 401 Unauthorized is raised + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = None + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 401 + assert 'email not available' in exc_info.value.detail.lower() + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_invalid_domain(mock_request, mock_user_auth): + """ + GIVEN: User ID and email with non-@openhands.dev domain + WHEN: get_admin_user_id is called + THEN: 403 Forbidden is raised + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test@external.com' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 403 + assert 'openhands.dev' in exc_info.value.detail.lower() + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_empty_string_user_id(mock_request): + """ + GIVEN: Empty string user ID + WHEN: get_admin_user_id is called + THEN: 401 Unauthorized is raised + """ + # Arrange + user_id = '' + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 401 + assert 'not authenticated' in exc_info.value.detail.lower() + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_case_sensitivity(mock_request, mock_user_auth): + """ + GIVEN: Email with uppercase @OPENHANDS.DEV domain + WHEN: get_admin_user_id is called + THEN: 403 Forbidden is raised (case-sensitive check) + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test@OPENHANDS.DEV' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_subdomain_not_allowed( + mock_request, mock_user_auth +): + """ + GIVEN: Email with subdomain like @test.openhands.dev + WHEN: get_admin_user_id is called + THEN: 403 Forbidden is raised + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test@test.openhands.dev' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_similar_domain_not_allowed( + mock_request, mock_user_auth +): + """ + GIVEN: Email with similar but different domain like @openhands.dev.fake.com + WHEN: get_admin_user_id is called + THEN: 403 Forbidden is raised + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test@openhands.dev.fake.com' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_logs_warning_on_invalid_domain( + mock_request, mock_user_auth +): + """ + GIVEN: User with invalid email domain + WHEN: get_admin_user_id is called + THEN: Warning is logged with user_id and email_domain + """ + # Arrange + user_id = 'test-user-123' + invalid_email = 'test@external.com' + mock_user_auth.get_user_email.return_value = invalid_email + + with ( + patch('server.email_validation.get_user_auth', return_value=mock_user_auth), + patch('server.email_validation.logger') as mock_logger, + ): + # Act & Assert + with pytest.raises(HTTPException): + await get_admin_user_id(mock_request, user_id) + + # Verify warning was logged + mock_logger.warning.assert_called_once() + call_args = mock_logger.warning.call_args + assert 'Access denied' in call_args[0][0] + assert call_args[1]['extra']['user_id'] == user_id + assert call_args[1]['extra']['email_domain'] == 'external.com' + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_with_plus_addressing(mock_request, mock_user_auth): + """ + GIVEN: Email with plus addressing (test+tag@openhands.dev) + WHEN: get_admin_user_id is called + THEN: User ID is returned successfully + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'test+tag@openhands.dev' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act + result = await get_admin_user_id(mock_request, user_id) + + # Assert + assert result == user_id + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_with_dots_in_local_part( + mock_request, mock_user_auth +): + """ + GIVEN: Email with dots in local part (first.last@openhands.dev) + WHEN: get_admin_user_id is called + THEN: User ID is returned successfully + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = 'first.last@openhands.dev' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act + result = await get_admin_user_id(mock_request, user_id) + + # Assert + assert result == user_id + + +@pytest.mark.asyncio +async def test_get_openhands_user_id_empty_email(mock_request, mock_user_auth): + """ + GIVEN: Empty string email + WHEN: get_admin_user_id is called + THEN: 401 Unauthorized is raised + """ + # Arrange + user_id = 'test-user-123' + mock_user_auth.get_user_email.return_value = '' + + with patch('server.email_validation.get_user_auth', return_value=mock_user_auth): + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_admin_user_id(mock_request, user_id) + + assert exc_info.value.status_code == 401 + assert 'email not available' in exc_info.value.detail.lower() diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index 986869f81c..daf4166de4 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -113,7 +113,7 @@ class TestLiteLlmManager: with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None): with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None): result = await LiteLlmManager.create_entries( - 'test-org-id', 'test-user-id', mock_settings + 'test-org-id', 'test-user-id', mock_settings, create_user=True ) assert result is None @@ -126,7 +126,7 @@ class TestLiteLlmManager: 'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com' ): result = await LiteLlmManager.create_entries( - 'test-org-id', 'test-user-id', mock_settings + 'test-org-id', 'test-user-id', mock_settings, create_user=True ) assert result is not None @@ -158,7 +158,10 @@ class TestLiteLlmManager: mock_client.post.return_value = mock_response result = await LiteLlmManager.create_entries( - 'test-org-id', 'test-user-id', mock_settings + 'test-org-id', + 'test-user-id', + mock_settings, + create_user=False, ) assert result is not None @@ -171,7 +174,7 @@ class TestLiteLlmManager: # Verify API calls were made assert ( - mock_client.post.call_count == 4 + mock_client.post.call_count == 3 ) # create_team, create_user, add_user_to_team, generate_key @pytest.mark.asyncio @@ -988,3 +991,138 @@ class TestLiteLlmManager: # Verify no HTTP calls were made mock_client.get.assert_not_called() mock_client.post.assert_not_called() + + @pytest.mark.asyncio + async def test_delete_team_success(self, mock_http_client, mock_response): + """ + GIVEN: Valid team_id and configured LiteLLM API + WHEN: delete_team is called + THEN: Team is deleted successfully via POST /team/delete + """ + # Arrange + team_id = 'test-team-123' + mock_response.is_success = True + mock_response.status_code = 200 + mock_http_client.post.return_value = mock_response + + with ( + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'), + patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test-team'), + ): + # Act + await LiteLlmManager._delete_team(mock_http_client, team_id) + + # Assert + mock_http_client.post.assert_called_once_with( + 'http://test.url/team/delete', + json={'team_ids': [team_id]}, + ) + + @pytest.mark.asyncio + async def test_delete_team_not_found_is_idempotent( + self, mock_http_client, mock_response + ): + """ + GIVEN: Team does not exist (404 response) + WHEN: delete_team is called + THEN: Operation succeeds without raising exception (idempotent) + """ + # Arrange + team_id = 'non-existent-team' + mock_response.is_success = False + mock_response.status_code = 404 + mock_http_client.post.return_value = mock_response + + with ( + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'), + patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test-team'), + ): + # Act - should not raise + await LiteLlmManager._delete_team(mock_http_client, team_id) + + # Assert + mock_http_client.post.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_team_api_error_raises_exception( + self, mock_http_client, mock_response + ): + """ + GIVEN: LiteLLM API returns error (non-404) + WHEN: delete_team is called + THEN: HTTPStatusError is raised + """ + # Arrange + team_id = 'test-team-123' + mock_response.is_success = False + mock_response.status_code = 500 + mock_response.text = 'Internal Server Error' + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + 'Server error', request=MagicMock(), response=mock_response + ) + ) + mock_http_client.post.return_value = mock_response + + with ( + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'), + patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test-team'), + ): + # Act & Assert + with pytest.raises(httpx.HTTPStatusError): + await LiteLlmManager._delete_team(mock_http_client, team_id) + + @pytest.mark.asyncio + async def test_delete_team_no_config_returns_early(self, mock_http_client): + """ + GIVEN: LiteLLM API is not configured + WHEN: delete_team is called + THEN: Function returns early without making API call + """ + # Arrange + team_id = 'test-team-123' + + with ( + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', None), + ): + # Act + await LiteLlmManager._delete_team(mock_http_client, team_id) + + # Assert + mock_http_client.post.assert_not_called() + + @pytest.mark.asyncio + async def test_delete_team_public_method(self): + """ + GIVEN: Valid team_id + WHEN: Public delete_team method is called + THEN: HTTP client is created and team is deleted + """ + # Arrange + team_id = 'test-team-123' + mock_response = AsyncMock() + mock_response.is_success = True + mock_response.status_code = 200 + + with ( + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'), + patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test-team'), + patch('httpx.AsyncClient') as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Act + await LiteLlmManager.delete_team(team_id) + + # Assert + mock_client.post.assert_called_once_with( + 'http://test.url/team/delete', + json={'team_ids': [team_id]}, + ) diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py new file mode 100644 index 0000000000..8b9f214392 --- /dev/null +++ b/enterprise/tests/unit/test_org_service.py @@ -0,0 +1,564 @@ +""" +Unit tests for OrgService. + +Tests the organization creation workflow with compensation pattern, +including LiteLLM integration and cleanup on failures. +""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock the database module before importing OrgService +with patch('storage.database.engine', create=True), patch( + 'storage.database.a_engine', create=True +): + from server.routes.org_models import ( + LiteLLMIntegrationError, + OrgDatabaseError, + OrgNameExistsError, + ) + from storage.org import Org + from storage.org_member import OrgMember + from storage.org_service import OrgService + from storage.role import Role + from storage.user import User + + +@pytest.fixture +def mock_litellm_api(): + """Mock LiteLLM API for testing.""" + api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key') + api_url_patch = patch( + 'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url' + ) + team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team') + client_patch = patch('httpx.AsyncClient') + + with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client: + mock_response = AsyncMock() + mock_response.is_success = True + mock_response.status_code = 200 + mock_response.json = MagicMock( + return_value={ + 'team_id': 'test-team-id', + 'user_id': 'test-user-id', + 'key': 'test-api-key', + } + ) + mock_client.return_value.__aenter__.return_value.post.return_value = ( + mock_response + ) + mock_client.return_value.__aenter__.return_value.get.return_value = ( + mock_response + ) + yield mock_client + + +@pytest.fixture +def owner_role(session_maker): + """Create owner role in database.""" + with session_maker() as session: + role = Role(id=1, name='owner', rank=1) + session.add(role) + session.commit() + return role + + +def test_validate_name_uniqueness_with_unique_name(session_maker): + """ + GIVEN: A unique organization name + WHEN: validate_name_uniqueness is called + THEN: No exception is raised + """ + # Arrange + unique_name = 'unique-org-name' + + # Act & Assert - should not raise + with patch('storage.org_store.session_maker', session_maker): + OrgService.validate_name_uniqueness(unique_name) + + +def test_validate_name_uniqueness_with_duplicate_name(session_maker): + """ + GIVEN: An organization name that already exists + WHEN: validate_name_uniqueness is called + THEN: OrgNameExistsError is raised + """ + # Arrange + existing_name = 'existing-org' + existing_org = Org(name=existing_name) + + # Mock OrgStore.get_org_by_name to return the existing org + with patch( + 'storage.org_service.OrgStore.get_org_by_name', + return_value=existing_org, + ): + # Act & Assert + with pytest.raises(OrgNameExistsError) as exc_info: + OrgService.validate_name_uniqueness(existing_name) + + assert existing_name in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_create_org_with_owner_success( + session_maker, owner_role, mock_litellm_api +): + """ + GIVEN: Valid organization data and user ID + WHEN: create_org_with_owner is called + THEN: Organization and owner membership are created successfully + """ + # Arrange + org_name = 'test-org' + contact_name = 'John Doe' + contact_email = 'john@example.com' + user_id = uuid.uuid4() + temp_org_id = uuid.uuid4() + + # Create user in database first + with session_maker() as session: + user = User(id=user_id, current_org_id=temp_org_id) + session.add(user) + session.commit() + + mock_settings = {'team_id': 'test-team', 'user_id': str(user_id)} + + with ( + patch('storage.org_store.session_maker', session_maker), + patch('storage.role_store.session_maker', session_maker), + patch( + 'storage.org_service.UserStore.create_default_settings', + AsyncMock(return_value=mock_settings), + ), + patch( + 'storage.org_service.OrgStore.get_kwargs_from_settings', + return_value={}, + ), + patch( + 'storage.org_service.OrgMemberStore.get_kwargs_from_settings', + return_value={'llm_api_key': 'test-key'}, + ), + ): + # Act + result = await OrgService.create_org_with_owner( + name=org_name, + contact_name=contact_name, + contact_email=contact_email, + user_id=str(user_id), + ) + + # Assert + assert result is not None + assert result.name == org_name + assert result.contact_name == contact_name + assert result.contact_email == contact_email + assert result.org_version > 0 # Should be set to ORG_SETTINGS_VERSION + assert result.default_llm_model is not None # Should be set + + # Verify organization was persisted + with session_maker() as session: + persisted_org = session.get(Org, result.id) + assert persisted_org is not None + assert persisted_org.name == org_name + + # Verify owner membership was created + org_member = ( + session.query(OrgMember) + .filter_by(org_id=result.id, user_id=user_id) + .first() + ) + assert org_member is not None + assert org_member.role_id == 1 # owner role id + assert org_member.status == 'active' + + +@pytest.mark.asyncio +async def test_create_org_with_owner_duplicate_name( + session_maker, owner_role, mock_litellm_api +): + """ + GIVEN: An organization name that already exists + WHEN: create_org_with_owner is called + THEN: OrgNameExistsError is raised without creating LiteLLM resources + """ + # Arrange + existing_name = 'existing-org' + with session_maker() as session: + org = Org(name=existing_name) + session.add(org) + session.commit() + + mock_create_settings = AsyncMock() + + # Act & Assert + with ( + patch('storage.org_store.session_maker', session_maker), + patch('storage.role_store.session_maker', session_maker), + patch( + 'storage.org_service.UserStore.create_default_settings', + mock_create_settings, + ), + ): + with pytest.raises(OrgNameExistsError): + await OrgService.create_org_with_owner( + name=existing_name, + contact_name='John Doe', + contact_email='john@example.com', + user_id='test-user-123', + ) + + # Verify no LiteLLM API calls were made (early exit) + mock_create_settings.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_org_with_owner_litellm_failure( + session_maker, owner_role, mock_litellm_api +): + """ + GIVEN: LiteLLM integration fails + WHEN: create_org_with_owner is called + THEN: LiteLLMIntegrationError is raised and no database records are created + """ + # Arrange + org_name = 'test-org' + + # Mock LiteLLM failure + with ( + patch('storage.org_store.session_maker', session_maker), + patch( + 'storage.org_service.UserStore.create_default_settings', + AsyncMock(return_value=None), + ), + ): + # Act & Assert + with pytest.raises(LiteLLMIntegrationError): + await OrgService.create_org_with_owner( + name=org_name, + contact_name='John Doe', + contact_email='john@example.com', + user_id='test-user-123', + ) + + # Verify no organization was created in database + with session_maker() as session: + org = session.query(Org).filter_by(name=org_name).first() + assert org is None + + +@pytest.mark.asyncio +async def test_create_org_with_owner_database_failure_triggers_cleanup( + session_maker, owner_role, mock_litellm_api +): + """ + GIVEN: Database persistence fails after LiteLLM integration succeeds + WHEN: create_org_with_owner is called + THEN: OrgDatabaseError is raised and LiteLLM cleanup is triggered + """ + # Arrange + org_name = 'test-org' + user_id = str(uuid.uuid4()) + cleanup_called = False + + def mock_cleanup(*args, **kwargs): + nonlocal cleanup_called + cleanup_called = True + return None + + mock_settings = {'team_id': 'test-team', 'user_id': user_id} + + with ( + patch('storage.org_store.session_maker', session_maker), + patch('storage.role_store.session_maker', session_maker), + patch( + 'storage.org_service.UserStore.create_default_settings', + AsyncMock(return_value=mock_settings), + ), + patch( + 'storage.org_service.OrgStore.get_kwargs_from_settings', + return_value={}, + ), + patch( + 'storage.org_service.OrgMemberStore.get_kwargs_from_settings', + return_value={'llm_api_key': 'test-key'}, + ), + patch( + 'storage.org_service.OrgStore.persist_org_with_owner', + side_effect=Exception('Database connection failed'), + ), + patch( + 'storage.org_service.OrgService._cleanup_litellm_resources', + AsyncMock(side_effect=mock_cleanup), + ), + ): + # Act & Assert + with pytest.raises(OrgDatabaseError) as exc_info: + await OrgService.create_org_with_owner( + name=org_name, + contact_name='John Doe', + contact_email='john@example.com', + user_id=user_id, + ) + + # Verify cleanup was called + assert cleanup_called + assert 'Database connection failed' in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup( + session_maker, owner_role, mock_litellm_api +): + """ + GIVEN: Entity creation fails after LiteLLM integration succeeds + WHEN: create_org_with_owner is called + THEN: OrgDatabaseError is raised and LiteLLM cleanup is triggered + """ + # Arrange + org_name = 'test-org' + user_id = str(uuid.uuid4()) + + mock_settings = {'team_id': 'test-team', 'user_id': user_id} + + with ( + patch('storage.org_store.session_maker', session_maker), + patch( + 'storage.org_service.UserStore.create_default_settings', + AsyncMock(return_value=mock_settings), + ), + patch( + 'storage.org_service.OrgStore.get_kwargs_from_settings', + return_value={}, + ), + patch( + 'storage.org_service.OrgMemberStore.get_kwargs_from_settings', + return_value={'llm_api_key': 'test-key'}, + ), + patch( + 'storage.org_service.OrgService.get_owner_role', + side_effect=Exception('Owner role not found'), + ), + patch( + 'storage.org_service.LiteLlmManager.delete_team', + AsyncMock(), + ) as mock_delete, + ): + # Act & Assert + with pytest.raises(OrgDatabaseError) as exc_info: + await OrgService.create_org_with_owner( + name=org_name, + contact_name='John Doe', + contact_email='john@example.com', + user_id=user_id, + ) + + # Verify cleanup was called + mock_delete.assert_called_once() + assert 'Owner role not found' in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_cleanup_litellm_resources_success(mock_litellm_api): + """ + GIVEN: Valid org_id and user_id + WHEN: _cleanup_litellm_resources is called + THEN: LiteLLM team is deleted successfully and None is returned + """ + # Arrange + org_id = uuid.uuid4() + user_id = 'test-user-123' + + with patch( + 'storage.org_service.LiteLlmManager.delete_team', + AsyncMock(), + ) as mock_delete: + # Act + result = await OrgService._cleanup_litellm_resources(org_id, user_id) + + # Assert + assert result is None + mock_delete.assert_called_once_with(str(org_id)) + + +@pytest.mark.asyncio +async def test_cleanup_litellm_resources_failure_returns_exception(mock_litellm_api): + """ + GIVEN: LiteLLM delete_team fails + WHEN: _cleanup_litellm_resources is called + THEN: Exception is returned (not raised) for logging + """ + # Arrange + org_id = uuid.uuid4() + user_id = 'test-user-123' + expected_error = Exception('LiteLLM API unavailable') + + with patch( + 'storage.org_service.LiteLlmManager.delete_team', + AsyncMock(side_effect=expected_error), + ): + # Act + result = await OrgService._cleanup_litellm_resources(org_id, user_id) + + # Assert + assert result is expected_error + assert 'LiteLLM API unavailable' in str(result) + + +@pytest.mark.asyncio +async def test_handle_failure_with_cleanup_success(): + """ + GIVEN: Original error and successful cleanup + WHEN: _handle_failure_with_cleanup is called + THEN: OrgDatabaseError is raised with original error message + """ + # Arrange + org_id = uuid.uuid4() + user_id = 'test-user-123' + original_error = Exception('Database write failed') + + with patch( + 'storage.org_service.OrgService._cleanup_litellm_resources', + AsyncMock(return_value=None), + ): + # Act & Assert + with pytest.raises(OrgDatabaseError) as exc_info: + await OrgService._handle_failure_with_cleanup( + org_id, user_id, original_error, 'Failed to create organization' + ) + + assert 'Database write failed' in str(exc_info.value) + assert 'Cleanup also failed' not in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_handle_failure_with_cleanup_both_fail(): + """ + GIVEN: Original error and cleanup also fails + WHEN: _handle_failure_with_cleanup is called + THEN: OrgDatabaseError is raised with both error messages + """ + # Arrange + org_id = uuid.uuid4() + user_id = 'test-user-123' + original_error = Exception('Database write failed') + cleanup_error = Exception('LiteLLM API unavailable') + + with patch( + 'storage.org_service.OrgService._cleanup_litellm_resources', + AsyncMock(return_value=cleanup_error), + ): + # Act & Assert + with pytest.raises(OrgDatabaseError) as exc_info: + await OrgService._handle_failure_with_cleanup( + org_id, user_id, original_error, 'Failed to create organization' + ) + + error_message = str(exc_info.value) + assert 'Database write failed' in error_message + assert 'Cleanup also failed' in error_message + assert 'LiteLLM API unavailable' in error_message + + +@pytest.mark.asyncio +async def test_get_org_credits_success(mock_litellm_api): + """ + GIVEN: Valid user_id and org_id with LiteLLM team info + WHEN: get_org_credits is called + THEN: Credits are calculated correctly (max_budget - spend) + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + max_budget = 100.0 + spend = 25.0 + + mock_team_info = { + 'litellm_budget_table': {'max_budget': max_budget}, + 'spend': spend, + } + + with patch( + 'storage.org_service.LiteLlmManager.get_user_team_info', + AsyncMock(return_value=mock_team_info), + ): + # Act + credits = await OrgService.get_org_credits(user_id, org_id) + + # Assert + assert credits == 75.0 # 100 - 25 + + +@pytest.mark.asyncio +async def test_get_org_credits_no_team_info(mock_litellm_api): + """ + GIVEN: LiteLLM returns no team info + WHEN: get_org_credits is called + THEN: None is returned + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + + with patch( + 'storage.org_service.LiteLlmManager.get_user_team_info', + AsyncMock(return_value=None), + ): + # Act + credits = await OrgService.get_org_credits(user_id, org_id) + + # Assert + assert credits is None + + +@pytest.mark.asyncio +async def test_get_org_credits_negative_credits_returns_zero(mock_litellm_api): + """ + GIVEN: Spend exceeds max_budget + WHEN: get_org_credits is called + THEN: Zero credits are returned (not negative) + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + max_budget = 100.0 + spend = 150.0 # Over budget + + mock_team_info = { + 'litellm_budget_table': {'max_budget': max_budget}, + 'spend': spend, + } + + with patch( + 'storage.org_service.LiteLlmManager.get_user_team_info', + AsyncMock(return_value=mock_team_info), + ): + # Act + credits = await OrgService.get_org_credits(user_id, org_id) + + # Assert + assert credits == 0.0 + + +@pytest.mark.asyncio +async def test_get_org_credits_api_failure_returns_none(mock_litellm_api): + """ + GIVEN: LiteLLM API call fails + WHEN: get_org_credits is called + THEN: None is returned and error is logged + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + + with patch( + 'storage.org_service.LiteLlmManager.get_user_team_info', + AsyncMock(side_effect=Exception('API error')), + ): + # Act + credits = await OrgService.get_org_credits(user_id, org_id) + + # Assert + assert credits is None diff --git a/enterprise/tests/unit/test_org_store.py b/enterprise/tests/unit/test_org_store.py index f5257ac26d..3601bbd54b 100644 --- a/enterprise/tests/unit/test_org_store.py +++ b/enterprise/tests/unit/test_org_store.py @@ -3,11 +3,17 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import SecretStr +from sqlalchemy.exc import IntegrityError # Mock the database module before importing OrgStore -with patch('storage.database.engine'), patch('storage.database.a_engine'): +with patch('storage.database.engine', create=True), patch( + 'storage.database.a_engine', create=True +): from storage.org import Org + from storage.org_member import OrgMember from storage.org_store import OrgStore + from storage.role import Role + from storage.user import User from openhands.storage.data_models.settings import Settings @@ -195,3 +201,217 @@ def test_get_kwargs_from_settings(): assert 'llm_api_key' not in kwargs assert 'llm_model' not in kwargs assert 'enable_sound_notifications' not in kwargs + + +def test_persist_org_with_owner_success(session_maker, mock_litellm_api): + """ + GIVEN: Valid org and org_member entities + WHEN: persist_org_with_owner is called + THEN: Both entities are persisted in a single transaction and org is returned + """ + # Arrange + org_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Create user and role first + with session_maker() as session: + user = User(id=user_id, current_org_id=org_id) + role = Role(id=1, name='owner', rank=1) + session.add(user) + session.add(role) + session.commit() + + org = Org( + id=org_id, + name='Test Organization', + contact_name='John Doe', + contact_email='john@example.com', + ) + + org_member = OrgMember( + org_id=org_id, + user_id=user_id, + role_id=1, + status='active', + llm_api_key='test-api-key-123', + ) + + # Act + with patch('storage.org_store.session_maker', session_maker): + result = OrgStore.persist_org_with_owner(org, org_member) + + # Assert + assert result is not None + assert result.id == org_id + assert result.name == 'Test Organization' + + # Verify both entities were persisted + with session_maker() as session: + persisted_org = session.get(Org, org_id) + assert persisted_org is not None + assert persisted_org.name == 'Test Organization' + + persisted_member = ( + session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + ) + assert persisted_member is not None + assert persisted_member.status == 'active' + assert persisted_member.role_id == 1 + + +def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litellm_api): + """ + GIVEN: Valid org and org_member entities + WHEN: persist_org_with_owner is called + THEN: The returned org is refreshed from database with all fields populated + """ + # Arrange + org_id = uuid.uuid4() + user_id = uuid.uuid4() + + with session_maker() as session: + user = User(id=user_id, current_org_id=org_id) + role = Role(id=1, name='owner', rank=1) + session.add(user) + session.add(role) + session.commit() + + org = Org( + id=org_id, + name='Test Org', + contact_name='Jane Doe', + contact_email='jane@example.com', + agent='CodeActAgent', + ) + + org_member = OrgMember( + org_id=org_id, + user_id=user_id, + role_id=1, + status='active', + llm_api_key='test-key', + ) + + # Act + with patch('storage.org_store.session_maker', session_maker): + result = OrgStore.persist_org_with_owner(org, org_member) + + # Assert - verify the returned object has database-generated fields + assert result.id == org_id + assert result.name == 'Test Org' + assert result.agent == 'CodeActAgent' + # Verify org_version was set by create_org logic (if applicable) + assert hasattr(result, 'org_version') + + +def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litellm_api): + """ + GIVEN: Valid org but invalid org_member (missing required field) + WHEN: persist_org_with_owner is called + THEN: Transaction fails and neither entity is persisted + """ + # Arrange + org_id = uuid.uuid4() + user_id = uuid.uuid4() + + with session_maker() as session: + user = User(id=user_id, current_org_id=org_id) + role = Role(id=1, name='owner', rank=1) + session.add(user) + session.add(role) + session.commit() + + org = Org( + id=org_id, + name='Test Org', + contact_name='John Doe', + contact_email='john@example.com', + ) + + # Create invalid org_member (missing required llm_api_key field) + org_member = OrgMember( + org_id=org_id, + user_id=user_id, + role_id=1, + status='active', + # llm_api_key is missing - should cause NOT NULL constraint violation + ) + + # Act & Assert + with patch('storage.org_store.session_maker', session_maker): + with pytest.raises(IntegrityError): # NOT NULL constraint violation + OrgStore.persist_org_with_owner(org, org_member) + + # Verify neither entity was persisted (transaction rolled back) + with session_maker() as session: + persisted_org = session.get(Org, org_id) + assert persisted_org is None + + persisted_member = ( + session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + ) + assert persisted_member is None + + +def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm_api): + """ + GIVEN: Org with multiple optional fields populated + WHEN: persist_org_with_owner is called + THEN: All fields are persisted correctly + """ + # Arrange + org_id = uuid.uuid4() + user_id = uuid.uuid4() + + with session_maker() as session: + user = User(id=user_id, current_org_id=org_id) + role = Role(id=1, name='owner', rank=1) + session.add(user) + session.add(role) + session.commit() + + org = Org( + id=org_id, + name='Complex Org', + contact_name='Alice Smith', + contact_email='alice@example.com', + agent='CodeActAgent', + default_max_iterations=50, + confirmation_mode=True, + billing_margin=0.15, + ) + + org_member = OrgMember( + org_id=org_id, + user_id=user_id, + role_id=1, + status='active', + llm_api_key='test-key', + max_iterations=100, + llm_model='gpt-4', + ) + + # Act + with patch('storage.org_store.session_maker', session_maker): + result = OrgStore.persist_org_with_owner(org, org_member) + + # Assert + assert result.name == 'Complex Org' + assert result.agent == 'CodeActAgent' + assert result.default_max_iterations == 50 + assert result.confirmation_mode is True + assert result.billing_margin == 0.15 + + # Verify persistence + with session_maker() as session: + persisted_org = session.get(Org, org_id) + assert persisted_org.agent == 'CodeActAgent' + assert persisted_org.default_max_iterations == 50 + assert persisted_org.confirmation_mode is True + assert persisted_org.billing_margin == 0.15 + + persisted_member = ( + session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + ) + assert persisted_member.max_iterations == 100 + assert persisted_member.llm_model == 'gpt-4' diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index 6d9ced0057..66b4e45dd0 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -30,15 +30,27 @@ def mock_request(): return request +def create_mock_jwt_tokens(user_id='test_user_id', exp_offset=3600): + """Helper to create valid JWT tokens for mocking.""" + payload = { + 'sub': user_id, + 'exp': int(time.time()) + exp_offset, + 'email': 'test@example.com', + 'email_verified': True, + } + access_token = jwt.encode(payload, 'secret', algorithm='HS256') + refresh_token = jwt.encode( + {'sub': user_id, 'exp': int(time.time()) + exp_offset}, + 'secret', + algorithm='HS256', + ) + return {'access_token': access_token, 'refresh_token': refresh_token} + + @pytest.fixture def mock_token_manager(): with patch('server.auth.saas_user_auth.token_manager') as mock_tm: - mock_tm.refresh = AsyncMock( - return_value={ - 'access_token': 'new_access_token', - 'refresh_token': 'new_refresh_token', - } - ) + mock_tm.refresh = AsyncMock(return_value=create_mock_jwt_tokens()) mock_tm.get_user_info_from_user_id = AsyncMock( return_value={ 'federatedIdentities': [ @@ -108,8 +120,11 @@ async def test_refresh(mock_token_manager): await user_auth.refresh() mock_token_manager.refresh.assert_called_once_with(refresh_token) - assert user_auth.access_token.get_secret_value() == 'new_access_token' - assert user_auth.refresh_token.get_secret_value() == 'new_refresh_token' + # Access token should be a valid JWT + access_token = user_auth.access_token.get_secret_value() + decoded = jwt.decode(access_token, options={'verify_signature': False}) + assert decoded['sub'] == 'test_user_id' + assert decoded['email'] == 'test@example.com' assert user_auth.refreshed is True @@ -159,7 +174,9 @@ async def test_get_access_token_with_expired_token(mock_token_manager): result = await user_auth.get_access_token() - assert result.get_secret_value() == 'new_access_token' + # Verify the returned token is a valid JWT with correct user_id + decoded = jwt.decode(result.get_secret_value(), options={'verify_signature': False}) + assert decoded['sub'] == 'test_user_id' mock_token_manager.refresh.assert_called_once_with(refresh_token) @@ -182,7 +199,9 @@ async def test_get_access_token_with_no_token(mock_token_manager): result = await user_auth.get_access_token() - assert result.get_secret_value() == 'new_access_token' + # Verify the returned token is a valid JWT with correct user_id + decoded = jwt.decode(result.get_secret_value(), options={'verify_signature': False}) + assert decoded['sub'] == 'test_user_id' mock_token_manager.refresh.assert_called_once_with(refresh_token) @@ -339,6 +358,13 @@ async def test_saas_user_auth_from_bearer_success(): mock_request = MagicMock() mock_request.headers = {'Authorization': 'Bearer test_api_key'} + # Create a valid offline token (refresh token) + offline_token = jwt.encode( + {'sub': 'test_user_id', 'exp': int(time.time()) + 3600}, + 'secret', + algorithm='HS256', + ) + with ( patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls, patch('server.auth.saas_user_auth.token_manager') as mock_token_manager, @@ -347,15 +373,18 @@ async def test_saas_user_auth_from_bearer_success(): mock_api_key_store.validate_api_key.return_value = 'test_user_id' mock_api_key_store_cls.get_instance.return_value = mock_api_key_store - mock_token_manager.load_offline_token = AsyncMock(return_value='offline_token') + mock_token_manager.load_offline_token = AsyncMock(return_value=offline_token) + mock_token_manager.refresh = AsyncMock( + return_value=create_mock_jwt_tokens('test_user_id') + ) result = await saas_user_auth_from_bearer(mock_request) assert isinstance(result, SaasUserAuth) assert result.user_id == 'test_user_id' - assert result.refresh_token.get_secret_value() == 'offline_token' mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key') mock_token_manager.load_offline_token.assert_called_once_with('test_user_id') + mock_token_manager.refresh.assert_called_once_with(offline_token) @pytest.mark.asyncio