From 4a3a42c8584821fe769c1f409bb585cf33487018 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 3 Mar 2026 03:47:22 -0700 Subject: [PATCH] refactor(enterprise): make OrgStore fully async (#13154) Co-authored-by: openhands Co-authored-by: OpenHands Bot --- enterprise/integrations/github/github_view.py | 12 +- enterprise/integrations/stripe_service.py | 10 +- enterprise/integrations/utils.py | 5 +- enterprise/server/routes/orgs.py | 2 +- .../server/services/org_invitation_service.py | 4 +- enterprise/storage/org_service.py | 30 +- enterprise/storage/org_store.py | 153 ++++----- .../tests/unit/server/routes/test_orgs.py | 18 +- .../unit/test_get_user_v1_enabled_setting.py | 19 +- enterprise/tests/unit/test_org_service.py | 227 +++++++++----- enterprise/tests/unit/test_org_store.py | 291 ++++++++++-------- .../tests/unit/test_stripe_service_db.py | 31 +- 12 files changed, 465 insertions(+), 337 deletions(-) diff --git a/enterprise/integrations/github/github_view.py b/enterprise/integrations/github/github_view.py index 8861f3884e..e8d6e525b3 100644 --- a/enterprise/integrations/github/github_view.py +++ b/enterprise/integrations/github/github_view.py @@ -72,7 +72,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: This function checks both the global environment variable kill switch AND the user's individual setting. Both must be true for the function to return true. """ - # If no user ID is provided, we can't check user settings if not user_id: return False @@ -81,13 +80,10 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: if not ENABLE_PROACTIVE_CONVERSATION_STARTERS: return False - def _get_setting(): - org = OrgStore.get_current_org_from_keycloak_user_id(user_id) - if not org: - return False - return bool(org.enable_proactive_conversation_starters) - - return await call_sync_from_async(_get_setting) + org = await OrgStore.get_current_org_from_keycloak_user_id(user_id) + if not org: + return False + return bool(org.enable_proactive_conversation_starters) # ================================================= diff --git a/enterprise/integrations/stripe_service.py b/enterprise/integrations/stripe_service.py index cc7f9ef857..ce9e25ce47 100644 --- a/enterprise/integrations/stripe_service.py +++ b/enterprise/integrations/stripe_service.py @@ -9,8 +9,6 @@ from storage.org import Org from storage.org_store import OrgStore from storage.stripe_customer import StripeCustomer -from openhands.utils.async_utils import call_sync_from_async - stripe.api_key = STRIPE_API_KEY @@ -38,9 +36,7 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None: async def find_customer_id_by_user_id(user_id: str) -> str | None: # First search our own DB... - org = await call_sync_from_async( - OrgStore.get_current_org_from_keycloak_user_id, user_id - ) + org = await OrgStore.get_current_org_from_keycloak_user_id(user_id) if not org: logger.warning(f'Org not found for user {user_id}') return None @@ -50,9 +46,7 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None: async def find_or_create_customer_by_user_id(user_id: str) -> dict | None: # Get the current org for the user - org = await call_sync_from_async( - OrgStore.get_current_org_from_keycloak_user_id, user_id - ) + org = await OrgStore.get_current_org_from_keycloak_user_id(user_id) if not org: logger.warning(f'Org not found for user {user_id}') return None diff --git a/enterprise/integrations/utils.py b/enterprise/integrations/utils.py index 02bd57ad55..6b2cbcd042 100644 --- a/enterprise/integrations/utils.py +++ b/enterprise/integrations/utils.py @@ -21,7 +21,6 @@ from openhands.events.event_store_abc import EventStoreABC from openhands.events.observation.agent import AgentStateChangedObservation from openhands.integrations.service_types import Repository from openhands.storage.data_models.conversation_status import ConversationStatus -from openhands.utils.async_utils import call_sync_from_async if TYPE_CHECKING: from openhands.server.conversation_manager.conversation_manager import ( @@ -122,9 +121,7 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool: if not user_id: return False - org = await call_sync_from_async( - OrgStore.get_current_org_from_keycloak_user_id, user_id - ) + org = await OrgStore.get_current_org_from_keycloak_user_id(user_id) if not org or org.v1_enabled is None: return False diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index 6309ba1bb7..b9c39a0925 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -105,7 +105,7 @@ async def list_user_orgs( ) # Fetch organizations from service layer - orgs, next_page_id = OrgService.get_user_orgs_paginated( + orgs, next_page_id = await OrgService.get_user_orgs_paginated( user_id=user_id, page_id=page_id, limit=limit, diff --git a/enterprise/server/services/org_invitation_service.py b/enterprise/server/services/org_invitation_service.py index e31e43c61f..2c1ba62f9b 100644 --- a/enterprise/server/services/org_invitation_service.py +++ b/enterprise/server/services/org_invitation_service.py @@ -73,7 +73,7 @@ class OrgInvitationService: ) # Step 1: Validate organization exists - org = OrgStore.get_org_by_id(org_id) + org = await OrgStore.get_org_by_id(org_id) if not org: raise ValueError(f'Organization {org_id} not found') @@ -187,7 +187,7 @@ class OrgInvitationService: ) # Step 1: Validate permissions upfront (shared for all emails) - org = OrgStore.get_org_by_id(org_id) + org = await OrgStore.get_org_by_id(org_id) if not org: raise ValueError(f'Organization {org_id} not found') diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 22bb4360ce..3d328b3ff6 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -31,7 +31,7 @@ class OrgService: """Service for handling organization-related operations.""" @staticmethod - def validate_name_uniqueness(name: str) -> None: + async def validate_name_uniqueness(name: str) -> None: """ Validate that organization name is unique. @@ -41,7 +41,7 @@ class OrgService: Raises: OrgNameExistsError: If organization name already exists """ - existing_org = OrgStore.get_org_by_name(name) + existing_org = await OrgStore.get_org_by_name(name) if existing_org is not None: raise OrgNameExistsError(name) @@ -214,7 +214,7 @@ class OrgService: ) # Step 1: Validate name uniqueness (fails early, no cleanup needed) - OrgService.validate_name_uniqueness(name) + await OrgService.validate_name_uniqueness(name) # Step 2: Generate organization ID org_id = uuid4() @@ -304,7 +304,7 @@ class OrgService: OrgDatabaseError: If database operations fail """ try: - persisted_org = OrgStore.persist_org_with_owner(org, org_member) + persisted_org = await OrgStore.persist_org_with_owner(org, org_member) return persisted_org except Exception as e: @@ -535,7 +535,7 @@ class OrgService: ) # Validate organization exists - existing_org = OrgStore.get_org_by_id(org_id) + existing_org = await OrgStore.get_org_by_id(org_id) if not existing_org: raise ValueError(f'Organization with ID {org_id} not found') @@ -555,7 +555,7 @@ class OrgService: # Check if name is being updated and validate uniqueness if update_data.name is not None: # Check if new name conflicts with another org - existing_org_with_name = OrgStore.get_org_by_name(update_data.name) + existing_org_with_name = await OrgStore.get_org_by_name(update_data.name) if ( existing_org_with_name is not None and existing_org_with_name.id != org_id @@ -608,7 +608,7 @@ class OrgService: # Perform the update try: - updated_org = OrgStore.update_org(org_id, update_dict) + updated_org = await OrgStore.update_org(org_id, update_dict) if not updated_org: raise OrgDatabaseError('Failed to update organization in database') @@ -683,7 +683,7 @@ class OrgService: return None @staticmethod - def get_user_orgs_paginated( + async def get_user_orgs_paginated( user_id: str, page_id: str | None = None, limit: int = 100 ): """ @@ -706,7 +706,7 @@ class OrgService: user_uuid = parse_uuid(user_id) # Fetch organizations from store - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_uuid, page_id=page_id, limit=limit ) @@ -754,7 +754,7 @@ class OrgService: raise OrgNotFoundError(str(org_id)) # Retrieve organization - org = OrgStore.get_org_by_id(org_id) + org = await OrgStore.get_org_by_id(org_id) if not org: logger.error( 'Organization not found despite valid membership', @@ -774,7 +774,7 @@ class OrgService: return org @staticmethod - def verify_owner_authorization(user_id: str, org_id: UUID) -> None: + async def verify_owner_authorization(user_id: str, org_id: UUID) -> None: """ Verify that the user is the owner of the organization. @@ -787,7 +787,7 @@ class OrgService: OrgAuthorizationError: If user is not authorized to delete """ # Check if organization exists - org = OrgStore.get_org_by_id(org_id) + org = await OrgStore.get_org_by_id(org_id) if not org: raise OrgNotFoundError(str(org_id)) @@ -835,7 +835,7 @@ class OrgService: ) # Step 1: Verify user authorization - OrgService.verify_owner_authorization(user_id, org_id) + await OrgService.verify_owner_authorization(user_id, org_id) # Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction try: @@ -879,7 +879,7 @@ class OrgService: if not user or not user.current_org_id: return False - org = OrgStore.get_org_by_id(user.current_org_id) + org = await OrgStore.get_org_by_id(user.current_org_id) if not org: return False @@ -913,7 +913,7 @@ class OrgService: ) # Step 1: Check if organization exists - org = OrgStore.get_org_by_id(org_id) + org = await OrgStore.get_org_by_id(org_id) if not org: raise OrgNotFoundError(str(org_id)) diff --git a/enterprise/storage/org_store.py b/enterprise/storage/org_store.py index 85b86ab934..76ec3fed35 100644 --- a/enterprise/storage/org_store.py +++ b/enterprise/storage/org_store.py @@ -13,7 +13,7 @@ from server.constants import ( from server.routes.org_models import OrgLLMSettingsUpdate, OrphanedUserError from sqlalchemy import select, text from sqlalchemy.orm import joinedload -from storage.database import a_session_maker, session_maker +from storage.database import a_session_maker from storage.lite_llm_manager import LiteLlmManager from storage.org import Org from storage.org_member import OrgMember @@ -28,61 +28,64 @@ class OrgStore: """Store for managing organizations.""" @staticmethod - def create_org( + async def create_org( kwargs: dict, ) -> Org: """Create a new organization.""" - with session_maker() as session: + async with a_session_maker() as session: org = Org(**kwargs) org.org_version = ORG_SETTINGS_VERSION org.default_llm_model = get_default_litellm_model() session.add(org) - session.commit() - session.refresh(org) + await session.commit() + await session.refresh(org) return org @staticmethod - def get_org_by_id(org_id: UUID) -> Org | None: + async def get_org_by_id(org_id: UUID) -> Org | None: """Get organization by ID.""" - org = None - with session_maker() as session: - org = session.query(Org).filter(Org.id == org_id).first() - return OrgStore._validate_org_version(org) + async with a_session_maker() as session: + result = await session.execute(select(Org).filter(Org.id == org_id)) + org = result.scalars().first() + return await OrgStore._validate_org_version(org) @staticmethod - def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None: - with session_maker() as session: - user = ( - session.query(User) + async def get_current_org_from_keycloak_user_id( + keycloak_user_id: str, + ) -> Org | None: + async with a_session_maker() as session: + result = await session.execute( + select(User) .options(joinedload(User.org_members)) .filter(User.id == UUID(keycloak_user_id)) - .first() ) + user = result.scalars().first() if not user: logger.warning(f'User not found for ID {keycloak_user_id}') return None org_id = user.current_org_id - org = session.query(Org).filter(Org.id == org_id).first() + result = await session.execute(select(Org).filter(Org.id == org_id)) + org = result.scalars().first() if not org: logger.warning( f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}' ) return None - return OrgStore._validate_org_version(org) + return await OrgStore._validate_org_version(org) @staticmethod - def get_org_by_name(name: str) -> Org | None: + async def get_org_by_name(name: str) -> Org | None: """Get organization by name.""" - org = None - with session_maker() as session: - org = session.query(Org).filter(Org.name == name).first() - return OrgStore._validate_org_version(org) + async with a_session_maker() as session: + result = await session.execute(select(Org).filter(Org.name == name)) + org = result.scalars().first() + return await OrgStore._validate_org_version(org) @staticmethod - def _validate_org_version(org: Org) -> Org | None: + async def _validate_org_version(org: Org | None) -> Org | None: """Check if we need to update org version.""" if org and org.org_version < ORG_SETTINGS_VERSION: - org = OrgStore.update_org( + org = await OrgStore.update_org( org.id, { 'org_version': ORG_SETTINGS_VERSION, @@ -93,14 +96,15 @@ class OrgStore: return org @staticmethod - def list_orgs() -> list[Org]: + async def list_orgs() -> list[Org]: """List all organizations.""" - with session_maker() as session: - orgs = session.query(Org).all() - return orgs + async with a_session_maker() as session: + result = await session.execute(select(Org)) + orgs = result.scalars().all() + return list(orgs) @staticmethod - def get_user_orgs_paginated( + async def get_user_orgs_paginated( user_id: UUID, page_id: str | None = None, limit: int = 100 ) -> tuple[list[Org], str | None]: """ @@ -114,10 +118,10 @@ class OrgStore: Returns: Tuple of (list of Org objects, next_page_id or None) """ - with session_maker() as session: + async with a_session_maker() as session: # Build query joining OrgMember with Org query = ( - session.query(Org) + select(Org) .join(OrgMember, Org.id == OrgMember.org_id) .filter(OrgMember.user_id == user_id) .order_by(Org.name) @@ -136,7 +140,8 @@ class OrgStore: # Fetch limit + 1 to check if there are more results query = query.limit(limit + 1) - orgs = query.all() + result = await session.execute(query) + orgs = list(result.scalars().all()) # Check if there are more results has_more = len(orgs) > limit @@ -149,21 +154,24 @@ class OrgStore: next_page_id = str(offset + limit) # Validate org versions - validated_orgs = [ - OrgStore._validate_org_version(org) for org in orgs if org - ] - validated_orgs = [org for org in validated_orgs if org is not None] + validated_orgs = [] + for org in orgs: + if org: + validated = await OrgStore._validate_org_version(org) + if validated is not None: + validated_orgs.append(validated) return validated_orgs, next_page_id @staticmethod - def update_org( + async def update_org( org_id: UUID, kwargs: dict, ) -> Optional[Org]: """Update organization details.""" - with session_maker() as session: - org = session.query(Org).filter(Org.id == org_id).first() + async with a_session_maker() as session: + result = await session.execute(select(Org).filter(Org.id == org_id)) + org = result.scalars().first() if not org: return None @@ -173,8 +181,8 @@ class OrgStore: if hasattr(org, key): setattr(org, key, value) - session.commit() - session.refresh(org) + await session.commit() + await session.refresh(org) return org @staticmethod @@ -223,7 +231,7 @@ class OrgStore: return kwargs @staticmethod - def persist_org_with_owner( + async def persist_org_with_owner( org: Org, org_member: OrgMember, ) -> Org: @@ -240,11 +248,11 @@ class OrgStore: Raises: Exception: If database operations fail """ - with session_maker() as session: + async with a_session_maker() as session: session.add(org) session.add(org_member) - session.commit() - session.refresh(org) + await session.commit() + await session.refresh(org) return org @staticmethod @@ -261,15 +269,16 @@ class OrgStore: Raises: Exception: If database operations or LiteLLM cleanup fail """ - with session_maker() as session: + async with a_session_maker() as session: # First get the organization to return it - org = session.query(Org).filter(Org.id == org_id).first() + result = await session.execute(select(Org).filter(Org.id == org_id)) + org = result.scalars().first() if not org: return None try: # 1. Delete conversation data for organization conversations - session.execute( + await session.execute( text(""" DELETE FROM conversation_metadata WHERE conversation_id IN ( @@ -279,10 +288,10 @@ class OrgStore: {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text(""" DELETE FROM app_conversation_start_task - WHERE app_conversation_id::text IN ( + WHERE app_conversation_id IN ( SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id ) """), @@ -290,40 +299,40 @@ class OrgStore: ) # 2. Delete organization-owned data tables (direct org_id foreign keys) - session.execute( + await session.execute( text('DELETE FROM billing_sessions WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text( 'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id' ), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text('DELETE FROM custom_secrets WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text('DELETE FROM api_keys WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text('DELETE FROM slack_conversation WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text('DELETE FROM slack_users WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) - session.execute( + await session.execute( text('DELETE FROM stripe_customers WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) # 3. Handle users with this as current_org_id BEFORE deleting memberships # Single query to find orphaned users (those with no alternative org) - orphaned_users = session.execute( + orphaned_result = await session.execute( text(""" SELECT u.id FROM "user" u @@ -334,27 +343,28 @@ class OrgStore: ) """), {'org_id': str(org_id)}, - ).fetchall() + ) + orphaned_users = orphaned_result.fetchall() if orphaned_users: raise OrphanedUserError([str(row[0]) for row in orphaned_users]) # Batch update: reassign current_org_id to an alternative org for all affected users - session.execute( + await session.execute( text(""" - UPDATE "user" u + UPDATE user SET current_org_id = ( SELECT om.org_id FROM org_member om - WHERE om.user_id = u.id AND om.org_id != :org_id + WHERE om.user_id = user.id AND om.org_id != :org_id LIMIT 1 ) - WHERE u.current_org_id = :org_id + WHERE user.current_org_id = :org_id """), {'org_id': str(org_id)}, ) # 4. Delete organization memberships (now safe) - session.execute( + await session.execute( text('DELETE FROM org_member WHERE org_id = :org_id'), {'org_id': str(org_id)}, ) @@ -370,7 +380,7 @@ class OrgStore: await LiteLlmManager.delete_team(str(org_id)) # 7. Commit all changes only if everything succeeded - session.commit() + await session.commit() logger.info( 'Successfully deleted organization and all associated data including LiteLLM team', @@ -380,7 +390,7 @@ class OrgStore: return org except Exception as e: - session.rollback() + await session.rollback() logger.error( 'Failed to delete organization - transaction rolled back', extra={'org_id': str(org_id), 'error': str(e)}, @@ -389,11 +399,12 @@ class OrgStore: @staticmethod async def get_org_by_id_async(org_id: UUID) -> Org | None: - """Get organization by ID (async version).""" - async with a_session_maker() as session: - result = await session.execute(select(Org).filter(Org.id == org_id)) - org = result.scalars().first() - return OrgStore._validate_org_version(org) if org else None + """Get organization by ID (async version). + + Note: This method is kept for backwards compatibility but simply + delegates to get_org_by_id which is now async. + """ + return await OrgStore.get_org_by_id(org_id) @staticmethod async def update_org_llm_settings_async( diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py index b7fa82da3a..8462249dde 100644 --- a/enterprise/tests/unit/server/routes/test_orgs.py +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -519,7 +519,7 @@ async def test_list_user_orgs_success(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([mock_org], None), + AsyncMock(return_value=([mock_org], None)), ), ): client = TestClient(mock_app_list) @@ -573,7 +573,7 @@ async def test_list_user_orgs_returns_current_org_id(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([current_org, other_org], None), + AsyncMock(return_value=([current_org, other_org], None)), ), ): client = TestClient(mock_app_list) @@ -618,7 +618,7 @@ async def test_list_user_orgs_with_pagination(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([org1, org2], '2'), + AsyncMock(return_value=([org1, org2], '2')), ), ): client = TestClient(mock_app_list) @@ -653,7 +653,7 @@ async def test_list_user_orgs_empty(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([], None), + AsyncMock(return_value=([], None)), ), ): client = TestClient(mock_app_list) @@ -720,7 +720,7 @@ async def test_list_user_orgs_service_error(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - side_effect=Exception('Database error'), + AsyncMock(side_effect=Exception('Database error')), ), ): client = TestClient(mock_app_list) @@ -786,7 +786,7 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([personal_org], None), + AsyncMock(return_value=([personal_org], None)), ), ): client = TestClient(mock_app_list) @@ -825,7 +825,7 @@ async def test_list_user_orgs_team_org_identified(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([team_org], None), + AsyncMock(return_value=([team_org], None)), ), ): client = TestClient(mock_app_list) @@ -874,7 +874,7 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([personal_org, team_org], None), + AsyncMock(return_value=([personal_org, team_org], None)), ), ): client = TestClient(mock_app_list) @@ -946,7 +946,7 @@ async def test_list_user_orgs_all_fields_present(mock_app_list): ), patch( 'server.routes.orgs.OrgService.get_user_orgs_paginated', - return_value=([mock_org], None), + AsyncMock(return_value=([mock_org], None)), ), ): client = TestClient(mock_app_list) diff --git a/enterprise/tests/unit/test_get_user_v1_enabled_setting.py b/enterprise/tests/unit/test_get_user_v1_enabled_setting.py index b7af93bd17..a91fb9590a 100644 --- a/enterprise/tests/unit/test_get_user_v1_enabled_setting.py +++ b/enterprise/tests/unit/test_get_user_v1_enabled_setting.py @@ -1,7 +1,7 @@ """Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions.""" import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from integrations.github.github_view import ( @@ -22,12 +22,12 @@ def mock_org(): def mock_dependencies(mock_org): """Fixture that patches all the common dependencies.""" with patch( - 'integrations.utils.call_sync_from_async', + 'integrations.utils.OrgStore.get_current_org_from_keycloak_user_id', + new_callable=AsyncMock, return_value=mock_org, - ) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store: + ) as mock_get_org: yield { - 'call_sync': mock_call_sync, - 'org_store': mock_org_store, + 'get_org': mock_get_org, 'org': mock_org, } @@ -102,10 +102,7 @@ class TestGetUserV1EnabledSetting: assert result is True # Verify correct methods were called with correct parameters - mock_dependencies['call_sync'].assert_called_once_with( - mock_dependencies['org_store'].get_current_org_from_keycloak_user_id, - 'test_user_123', - ) + mock_dependencies['get_org'].assert_called_once_with('test_user_123') @pytest.mark.asyncio async def test_returns_user_setting_true(self, mock_dependencies): @@ -124,8 +121,8 @@ class TestGetUserV1EnabledSetting: @pytest.mark.asyncio async def test_no_org_returns_false(self, mock_dependencies): """Test that the function returns False when no org is found.""" - # Mock call_sync_from_async to return None (no org found) - mock_dependencies['call_sync'].return_value = None + # Mock get_current_org_from_keycloak_user_id to return None (no org found) + mock_dependencies['get_org'].return_value = None result = await get_user_v1_enabled_setting('test_user_123') assert result is False diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index 64b43ff6e3..3b823d9994 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -63,7 +63,8 @@ def owner_role(session_maker): return role -def test_validate_name_uniqueness_with_unique_name(session_maker): +@pytest.mark.asyncio +async def test_validate_name_uniqueness_with_unique_name(async_session_maker): """ GIVEN: A unique organization name WHEN: validate_name_uniqueness is called @@ -74,14 +75,15 @@ def test_validate_name_uniqueness_with_unique_name(session_maker): # Act & Assert - should not raise with ( - patch('storage.org_store.session_maker', session_maker), - patch('storage.org_member_store.session_maker', session_maker), - patch('storage.role_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), + patch('storage.org_member_store.session_maker'), + patch('storage.role_store.session_maker'), ): - OrgService.validate_name_uniqueness(unique_name) + await OrgService.validate_name_uniqueness(unique_name) -def test_validate_name_uniqueness_with_duplicate_name(session_maker): +@pytest.mark.asyncio +async def test_validate_name_uniqueness_with_duplicate_name(): """ GIVEN: An organization name that already exists WHEN: validate_name_uniqueness is called @@ -94,18 +96,19 @@ def test_validate_name_uniqueness_with_duplicate_name(session_maker): # Mock OrgStore.get_org_by_name to return the existing org with patch( 'storage.org_service.OrgStore.get_org_by_name', + new_callable=AsyncMock, return_value=existing_org, ): # Act & Assert with pytest.raises(OrgNameExistsError) as exc_info: - OrgService.validate_name_uniqueness(existing_name) + await OrgService.validate_name_uniqueness(existing_name) assert existing_name in str(exc_info.value) @pytest.mark.asyncio async def test_create_org_with_owner_success( - session_maker, owner_role, mock_litellm_api + session_maker, async_session_maker, owner_role, mock_litellm_api ): """ GIVEN: Valid organization data and user ID @@ -128,7 +131,7 @@ async def test_create_org_with_owner_success( mock_settings = {'team_id': 'test-team', 'user_id': str(user_id)} with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.UserStore.create_default_settings', @@ -178,7 +181,7 @@ async def test_create_org_with_owner_success( @pytest.mark.asyncio async def test_create_org_with_owner_duplicate_name( - session_maker, owner_role, mock_litellm_api + session_maker, async_session_maker, owner_role, mock_litellm_api ): """ GIVEN: An organization name that already exists @@ -196,7 +199,7 @@ async def test_create_org_with_owner_duplicate_name( # Act & Assert with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.UserStore.create_default_settings', @@ -217,7 +220,7 @@ async def test_create_org_with_owner_duplicate_name( @pytest.mark.asyncio async def test_create_org_with_owner_litellm_failure( - session_maker, owner_role, mock_litellm_api + session_maker, async_session_maker, owner_role, mock_litellm_api ): """ GIVEN: LiteLLM integration fails @@ -229,7 +232,7 @@ async def test_create_org_with_owner_litellm_failure( # Mock LiteLLM failure with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch( 'storage.org_service.UserStore.create_default_settings', AsyncMock(return_value=None), @@ -252,7 +255,7 @@ async def test_create_org_with_owner_litellm_failure( @pytest.mark.asyncio async def test_create_org_with_owner_database_failure_triggers_cleanup( - session_maker, owner_role, mock_litellm_api + session_maker, async_session_maker, owner_role, mock_litellm_api ): """ GIVEN: Database persistence fails after LiteLLM integration succeeds @@ -272,7 +275,7 @@ async def test_create_org_with_owner_database_failure_triggers_cleanup( mock_settings = {'team_id': 'test-team', 'user_id': user_id} with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.UserStore.create_default_settings', @@ -311,7 +314,7 @@ async def test_create_org_with_owner_database_failure_triggers_cleanup( @pytest.mark.asyncio async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup( - session_maker, owner_role, mock_litellm_api + session_maker, async_session_maker, owner_role, mock_litellm_api ): """ GIVEN: Entity creation fails after LiteLLM integration succeeds @@ -325,7 +328,7 @@ async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup( mock_settings = {'team_id': 'test-team', 'user_id': user_id} with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch( 'storage.org_service.UserStore.create_default_settings', AsyncMock(return_value=mock_settings), @@ -589,7 +592,10 @@ async def test_get_org_by_id_success(session_maker, owner_role): with ( patch('storage.org_service.OrgMemberStore.get_org_member') as mock_get_member, - patch('storage.org_service.OrgStore.get_org_by_id') as mock_get_org, + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + ) as mock_get_org, ): mock_get_member.return_value = mock_org_member mock_get_org.return_value = mock_org @@ -652,7 +658,11 @@ async def test_get_org_by_id_org_not_found(): 'storage.org_service.OrgMemberStore.get_org_member', return_value=mock_org_member, ), - patch('storage.org_service.OrgStore.get_org_by_id', return_value=None), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=None, + ), ): # Act & Assert with pytest.raises(OrgNotFoundError) as exc_info: @@ -661,7 +671,10 @@ async def test_get_org_by_id_org_not_found(): assert str(org_id) in str(exc_info.value) -def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_success( + session_maker, async_session_maker, mock_litellm_api +): """ GIVEN: User has organizations in database WHEN: get_user_orgs_paginated is called with valid user_id @@ -685,8 +698,8 @@ def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api): session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgService.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgService.get_user_orgs_paginated( user_id=str(user_id), page_id=None, limit=10 ) @@ -696,7 +709,10 @@ def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api): assert next_page_id is None -def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_with_pagination( + session_maker, async_session_maker, mock_litellm_api +): """ GIVEN: User has multiple organizations WHEN: get_user_orgs_paginated is called with page_id and limit @@ -730,8 +746,8 @@ def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgService.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgService.get_user_orgs_paginated( user_id=str(user_id), page_id='0', limit=2 ) @@ -742,7 +758,8 @@ def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api assert next_page_id == '2' -def test_get_user_orgs_paginated_empty_results(session_maker): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_empty_results(async_session_maker): """ GIVEN: User has no organizations WHEN: get_user_orgs_paginated is called @@ -752,8 +769,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker): user_id = str(uuid.uuid4()) # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgService.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgService.get_user_orgs_paginated( user_id=user_id, page_id=None, limit=10 ) @@ -762,7 +779,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker): assert next_page_id is None -def test_get_user_orgs_paginated_invalid_user_id_format(): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_invalid_user_id_format(): """ GIVEN: Invalid user_id format (not a valid UUID string) WHEN: get_user_orgs_paginated is called @@ -773,12 +791,13 @@ def test_get_user_orgs_paginated_invalid_user_id_format(): # Act & Assert with pytest.raises(ValueError): - OrgService.get_user_orgs_paginated( + await OrgService.get_user_orgs_paginated( user_id=invalid_user_id, page_id=None, limit=10 ) -def test_verify_owner_authorization_success(session_maker, owner_role): +@pytest.mark.asyncio +async def test_verify_owner_authorization_success(session_maker, owner_role): """ GIVEN: User is owner of the organization WHEN: verify_owner_authorization is called @@ -808,7 +827,11 @@ def test_verify_owner_authorization_success(session_maker, owner_role): mock_owner_role.id = 1 with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch( 'storage.org_service.OrgMemberStore.get_org_member', return_value=mock_org_member, @@ -818,10 +841,11 @@ def test_verify_owner_authorization_success(session_maker, owner_role): ), ): # Act & Assert - should not raise - OrgService.verify_owner_authorization(user_id, org_id) + await OrgService.verify_owner_authorization(user_id, org_id) -def test_verify_owner_authorization_org_not_found(): +@pytest.mark.asyncio +async def test_verify_owner_authorization_org_not_found(): """ GIVEN: Organization does not exist WHEN: verify_owner_authorization is called @@ -831,15 +855,20 @@ def test_verify_owner_authorization_org_not_found(): org_id = uuid.uuid4() user_id = str(uuid.uuid4()) - with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None): + with patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=None, + ): # Act & Assert with pytest.raises(OrgNotFoundError) as exc_info: - OrgService.verify_owner_authorization(user_id, org_id) + await OrgService.verify_owner_authorization(user_id, org_id) assert str(org_id) in str(exc_info.value) -def test_verify_owner_authorization_user_not_member(session_maker, owner_role): +@pytest.mark.asyncio +async def test_verify_owner_authorization_user_not_member(session_maker, owner_role): """ GIVEN: User is not a member of the organization WHEN: verify_owner_authorization is called @@ -857,17 +886,22 @@ def test_verify_owner_authorization_user_not_member(session_maker, owner_role): ) with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch('storage.org_service.OrgMemberStore.get_org_member', return_value=None), ): # Act & Assert with pytest.raises(OrgAuthorizationError) as exc_info: - OrgService.verify_owner_authorization(user_id, org_id) + await OrgService.verify_owner_authorization(user_id, org_id) assert 'not a member' in str(exc_info.value) -def test_verify_owner_authorization_user_not_owner(session_maker): +@pytest.mark.asyncio +async def test_verify_owner_authorization_user_not_owner(session_maker): """ GIVEN: User is member but not owner (admin role) WHEN: verify_owner_authorization is called @@ -893,7 +927,11 @@ def test_verify_owner_authorization_user_not_owner(session_maker): admin_role = Role(id=2, name='admin', rank=20) with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch( 'storage.org_service.OrgMemberStore.get_org_member', return_value=mock_org_member, @@ -902,7 +940,7 @@ def test_verify_owner_authorization_user_not_owner(session_maker): ): # Act & Assert with pytest.raises(OrgAuthorizationError) as exc_info: - OrgService.verify_owner_authorization(user_id, org_id) + await OrgService.verify_owner_authorization(user_id, org_id) assert 'Only organization owners' in str(exc_info.value) @@ -1034,7 +1072,9 @@ async def test_delete_org_with_cleanup_unexpected_none_result( @pytest.mark.asyncio -async def test_update_org_with_permissions_success_non_llm_fields(session_maker): +async def test_update_org_with_permissions_success_non_llm_fields( + async_session_maker, session_maker +): """ GIVEN: Valid organization update with non-LLM fields and user is a member WHEN: update_org_with_permissions is called @@ -1077,7 +1117,7 @@ async def test_update_org_with_permissions_success_non_llm_fields(session_maker) ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1096,7 +1136,9 @@ async def test_update_org_with_permissions_success_non_llm_fields(session_maker) @pytest.mark.asyncio -async def test_update_org_with_permissions_success_llm_fields_admin(session_maker): +async def test_update_org_with_permissions_success_llm_fields_admin( + async_session_maker, session_maker +): """ GIVEN: Valid organization update with LLM fields and user has admin role WHEN: update_org_with_permissions is called @@ -1138,7 +1180,7 @@ async def test_update_org_with_permissions_success_llm_fields_admin(session_make ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1156,7 +1198,9 @@ async def test_update_org_with_permissions_success_llm_fields_admin(session_make @pytest.mark.asyncio -async def test_update_org_with_permissions_success_llm_fields_owner(session_maker): +async def test_update_org_with_permissions_success_llm_fields_owner( + async_session_maker, session_maker +): """ GIVEN: Valid organization update with LLM fields and user has owner role WHEN: update_org_with_permissions is called @@ -1198,7 +1242,7 @@ async def test_update_org_with_permissions_success_llm_fields_owner(session_make ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1216,7 +1260,9 @@ async def test_update_org_with_permissions_success_llm_fields_owner(session_make @pytest.mark.asyncio -async def test_update_org_with_permissions_success_mixed_fields_admin(session_maker): +async def test_update_org_with_permissions_success_mixed_fields_admin( + async_session_maker, session_maker +): """ GIVEN: Valid organization update with both LLM and non-LLM fields and user has admin role WHEN: update_org_with_permissions is called @@ -1259,7 +1305,7 @@ async def test_update_org_with_permissions_success_mixed_fields_admin(session_ma ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1278,7 +1324,9 @@ async def test_update_org_with_permissions_success_mixed_fields_admin(session_ma @pytest.mark.asyncio -async def test_update_org_with_permissions_empty_update(session_maker): +async def test_update_org_with_permissions_empty_update( + async_session_maker, session_maker +): """ GIVEN: Update request with no fields (all None) WHEN: update_org_with_permissions is called @@ -1317,7 +1365,7 @@ async def test_update_org_with_permissions_empty_update(session_maker): update_data = OrgUpdate() # All fields None with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1335,7 +1383,9 @@ async def test_update_org_with_permissions_empty_update(session_maker): @pytest.mark.asyncio -async def test_update_org_with_permissions_org_not_found(session_maker): +async def test_update_org_with_permissions_org_not_found( + session_maker, async_session_maker +): """ GIVEN: Organization ID does not exist WHEN: update_org_with_permissions is called @@ -1350,7 +1400,7 @@ async def test_update_org_with_permissions_org_not_found(session_maker): update_data = OrgUpdate(contact_name='Jane Doe') with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1366,7 +1416,9 @@ async def test_update_org_with_permissions_org_not_found(session_maker): @pytest.mark.asyncio -async def test_update_org_with_permissions_non_member(session_maker): +async def test_update_org_with_permissions_non_member( + session_maker, async_session_maker +): """ GIVEN: User is not a member of the organization WHEN: update_org_with_permissions is called @@ -1396,7 +1448,7 @@ async def test_update_org_with_permissions_non_member(session_maker): update_data = OrgUpdate(contact_name='Jane Doe') with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1413,6 +1465,7 @@ async def test_update_org_with_permissions_non_member(session_maker): @pytest.mark.asyncio async def test_update_org_with_permissions_llm_fields_insufficient_permission( + async_session_maker, session_maker, ): """ @@ -1453,7 +1506,7 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission( update_data = OrgUpdate(default_llm_model='claude-opus-4-5-20251101') with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1472,7 +1525,9 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission( @pytest.mark.asyncio -async def test_update_org_with_permissions_database_error(session_maker): +async def test_update_org_with_permissions_database_error( + async_session_maker, session_maker +): """ GIVEN: Database update operation fails WHEN: update_org_with_permissions is called @@ -1511,11 +1566,12 @@ async def test_update_org_with_permissions_database_error(session_maker): update_data = OrgUpdate(contact_name='Jane Doe') with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.update_org', + new_callable=AsyncMock, return_value=None, # Simulate database failure ), ): @@ -1532,6 +1588,7 @@ async def test_update_org_with_permissions_database_error(session_maker): @pytest.mark.asyncio async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists_error( + async_session_maker, session_maker, ): """ @@ -1564,16 +1621,18 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists update_data = OrgUpdate(name=duplicate_name) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, return_value=mock_current_org, ), patch('storage.org_service.OrgService.is_org_member', return_value=True), patch( 'storage.org_service.OrgStore.get_org_by_name', + new_callable=AsyncMock, return_value=mock_org_with_name, ), ): @@ -1589,7 +1648,9 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists @pytest.mark.asyncio -async def test_update_org_with_permissions_same_name_allowed(session_maker): +async def test_update_org_with_permissions_same_name_allowed( + session_maker, async_session_maker +): """ GIVEN: User updates org with name unchanged (same as current org name) WHEN: update_org_with_permissions is called @@ -1613,20 +1674,23 @@ async def test_update_org_with_permissions_same_name_allowed(session_maker): update_data = OrgUpdate(name=current_name) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), patch( 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, return_value=mock_org, ), patch('storage.org_service.OrgService.is_org_member', return_value=True), patch( 'storage.org_service.OrgStore.get_org_by_name', + new_callable=AsyncMock, return_value=mock_org, ), patch( 'storage.org_service.OrgStore.update_org', + new_callable=AsyncMock, return_value=mock_org, ), ): @@ -1643,7 +1707,9 @@ async def test_update_org_with_permissions_same_name_allowed(session_maker): @pytest.mark.asyncio -async def test_update_org_with_permissions_only_llm_fields(session_maker): +async def test_update_org_with_permissions_only_llm_fields( + async_session_maker, session_maker +): """ GIVEN: Update request contains only LLM fields and user has admin role WHEN: update_org_with_permissions is called @@ -1686,7 +1752,7 @@ async def test_update_org_with_permissions_only_llm_fields(session_maker): ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1705,7 +1771,9 @@ async def test_update_org_with_permissions_only_llm_fields(session_maker): @pytest.mark.asyncio -async def test_update_org_with_permissions_only_non_llm_fields(session_maker): +async def test_update_org_with_permissions_only_non_llm_fields( + async_session_maker, session_maker +): """ GIVEN: Update request contains only non-LLM fields and user is a member WHEN: update_org_with_permissions is called @@ -1748,7 +1816,7 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker): ) with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), patch('storage.org_member_store.session_maker', session_maker), patch('storage.role_store.session_maker', session_maker), ): @@ -1790,6 +1858,7 @@ async def test_check_byor_export_enabled_returns_true_when_enabled(): ), patch( 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, return_value=mock_org, ), ): @@ -1824,6 +1893,7 @@ async def test_check_byor_export_enabled_returns_false_when_disabled(): ), patch( 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, return_value=mock_org, ), ): @@ -1900,6 +1970,7 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found(): ), patch( 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, return_value=None, ), ): @@ -1929,7 +2000,11 @@ async def test_switch_org_success(): mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id) with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch('storage.org_service.OrgService.is_org_member', return_value=True), patch( 'storage.org_service.UserStore.update_current_org', @@ -1956,7 +2031,11 @@ async def test_switch_org_org_not_found(): org_id = uuid.uuid4() user_id = str(uuid.uuid4()) - with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None): + with patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=None, + ): # Act & Assert with pytest.raises(OrgNotFoundError) as exc_info: await OrgService.switch_org(user_id, org_id) @@ -1982,7 +2061,11 @@ async def test_switch_org_user_not_member(): ) with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch('storage.org_service.OrgService.is_org_member', return_value=False), ): # Act & Assert @@ -2010,7 +2093,11 @@ async def test_switch_org_user_not_found(): ) with ( - patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + new_callable=AsyncMock, + return_value=mock_org, + ), patch('storage.org_service.OrgService.is_org_member', return_value=True), patch('storage.org_service.UserStore.update_current_org', return_value=None), ): diff --git a/enterprise/tests/unit/test_org_store.py b/enterprise/tests/unit/test_org_store.py index 75dd1cffa1..2ef7619f33 100644 --- a/enterprise/tests/unit/test_org_store.py +++ b/enterprise/tests/unit/test_org_store.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import SecretStr +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from storage.org import Org from storage.org_invitation import OrgInvitation @@ -40,67 +41,73 @@ def mock_litellm_api(): yield mock_client -def test_get_org_by_id(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_org_by_id(async_session_maker, mock_litellm_api): # Test getting org by ID - with session_maker() as session: + async with async_session_maker() as session: # Create a test org org = Org(name='test-org') session.add(org) - session.commit() + await session.commit() + await session.refresh(org) org_id = org.id # Test retrieval with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - retrieved_org = OrgStore.get_org_by_id(org_id) + retrieved_org = await OrgStore.get_org_by_id(org_id) assert retrieved_org is not None assert retrieved_org.id == org_id assert retrieved_org.name == 'test-org' -def test_get_org_by_id_not_found(session_maker): +@pytest.mark.asyncio +async def test_get_org_by_id_not_found(async_session_maker): # Test getting org by ID when it doesn't exist - with patch('storage.org_store.session_maker', session_maker): + with patch('storage.org_store.a_session_maker', async_session_maker): non_existent_id = uuid.uuid4() - retrieved_org = OrgStore.get_org_by_id(non_existent_id) + retrieved_org = await OrgStore.get_org_by_id(non_existent_id) assert retrieved_org is None -def test_list_orgs(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_list_orgs(async_session_maker, mock_litellm_api): # Test listing all orgs - with session_maker() as session: + async with async_session_maker() as session: # Create test orgs org1 = Org(name='test-org-1') org2 = Org(name='test-org-2') session.add_all([org1, org2]) - session.commit() + await session.commit() # Test listing with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - orgs = OrgStore.list_orgs() + orgs = await OrgStore.list_orgs() assert len(orgs) >= 2 org_names = [org.name for org in orgs] assert 'test-org-1' in org_names assert 'test-org-2' in org_names -def test_update_org(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_update_org(async_session_maker, mock_litellm_api): # Test updating org details - with session_maker() as session: + async with async_session_maker() as session: # Create a test org org = Org(name='test-org', agent='CodeActAgent') session.add(org) - session.commit() + await session.commit() + await session.refresh(org) org_id = org.id # Test update with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - updated_org = OrgStore.update_org( + updated_org = await OrgStore.update_org( org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'} ) @@ -109,23 +116,27 @@ def test_update_org(session_maker, mock_litellm_api): assert updated_org.agent == 'PlannerAgent' -def test_update_org_not_found(session_maker): +@pytest.mark.asyncio +async def test_update_org_not_found(async_session_maker): # Test updating org that doesn't exist - with patch('storage.org_store.session_maker', session_maker): + with patch('storage.org_store.a_session_maker', async_session_maker): from uuid import uuid4 - updated_org = OrgStore.update_org( + updated_org = await OrgStore.update_org( org_id=uuid4(), kwargs={'name': 'updated-org'} ) assert updated_org is None -def test_create_org(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_create_org(async_session_maker, mock_litellm_api): # Test creating a new org with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'}) + org = await OrgStore.create_org( + kwargs={'name': 'new-org', 'agent': 'CodeActAgent'} + ) assert org is not None assert org.name == 'new-org' @@ -133,43 +144,48 @@ def test_create_org(session_maker, mock_litellm_api): assert org.id is not None -def test_get_org_by_name(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_org_by_name(async_session_maker, mock_litellm_api): # Test getting org by name - with session_maker() as session: + async with async_session_maker() as session: # Create a test org org = Org(name='test-org-by-name') session.add(org) - session.commit() + await session.commit() # Test retrieval with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - retrieved_org = OrgStore.get_org_by_name('test-org-by-name') + retrieved_org = await OrgStore.get_org_by_name('test-org-by-name') assert retrieved_org is not None assert retrieved_org.name == 'test-org-by-name' -def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_current_org_from_keycloak_user_id( + async_session_maker, mock_litellm_api +): # Test getting current org from user ID test_user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: # Create test data org = Org(name='test-org') session.add(org) - session.flush() + await session.flush() from storage.user import User user = User(id=test_user_id, current_org_id=org.id) session.add(user) - session.commit() + await session.commit() + await session.refresh(org) # Test retrieval with ( - patch('storage.org_store.session_maker', session_maker), + patch('storage.org_store.a_session_maker', async_session_maker), ): - retrieved_org = OrgStore.get_current_org_from_keycloak_user_id( + retrieved_org = await OrgStore.get_current_org_from_keycloak_user_id( str(test_user_id) ) assert retrieved_org is not None @@ -200,7 +216,8 @@ def test_get_kwargs_from_settings(): assert 'enable_sound_notifications' not in kwargs -def test_persist_org_with_owner_success(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_persist_org_with_owner_success(async_session_maker, mock_litellm_api): """ GIVEN: Valid org and org_member entities WHEN: persist_org_with_owner is called @@ -211,12 +228,12 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api): user_id = uuid.uuid4() # Create user and role first - with session_maker() as session: + async with async_session_maker() as session: user = User(id=user_id, current_org_id=org_id) role = Role(id=1, name='owner', rank=1) session.add(user) session.add(role) - session.commit() + await session.commit() org = Org( id=org_id, @@ -234,8 +251,8 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api): ) # Act - with patch('storage.org_store.session_maker', session_maker): - result = OrgStore.persist_org_with_owner(org, org_member) + with patch('storage.org_store.a_session_maker', async_session_maker): + result = await OrgStore.persist_org_with_owner(org, org_member) # Assert assert result is not None @@ -243,20 +260,24 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api): assert result.name == 'Test Organization' # Verify both entities were persisted - with session_maker() as session: - persisted_org = session.get(Org, org_id) + async with async_session_maker() as session: + persisted_org = await session.get(Org, org_id) assert persisted_org is not None assert persisted_org.name == 'Test Organization' - persisted_member = ( - session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + result = await session.execute( + select(OrgMember).filter_by(org_id=org_id, user_id=user_id) ) + persisted_member = result.scalars().first() assert persisted_member is not None assert persisted_member.status == 'active' assert persisted_member.role_id == 1 -def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_persist_org_with_owner_returns_refreshed_org( + async_session_maker, mock_litellm_api +): """ GIVEN: Valid org and org_member entities WHEN: persist_org_with_owner is called @@ -266,12 +287,12 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell org_id = uuid.uuid4() user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: user = User(id=user_id, current_org_id=org_id) role = Role(id=1, name='owner', rank=1) session.add(user) session.add(role) - session.commit() + await session.commit() org = Org( id=org_id, @@ -290,8 +311,8 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell ) # Act - with patch('storage.org_store.session_maker', session_maker): - result = OrgStore.persist_org_with_owner(org, org_member) + with patch('storage.org_store.a_session_maker', async_session_maker): + result = await OrgStore.persist_org_with_owner(org, org_member) # Assert - verify the returned object has database-generated fields assert result.id == org_id @@ -301,7 +322,10 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell assert hasattr(result, 'org_version') -def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_persist_org_with_owner_transaction_atomicity( + async_session_maker, mock_litellm_api +): """ GIVEN: Valid org but invalid org_member (missing required field) WHEN: persist_org_with_owner is called @@ -311,12 +335,12 @@ def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litell org_id = uuid.uuid4() user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: user = User(id=user_id, current_org_id=org_id) role = Role(id=1, name='owner', rank=1) session.add(user) session.add(role) - session.commit() + await session.commit() org = Org( id=org_id, @@ -335,22 +359,26 @@ def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litell ) # Act & Assert - with patch('storage.org_store.session_maker', session_maker): + with patch('storage.org_store.a_session_maker', async_session_maker): with pytest.raises(IntegrityError): # NOT NULL constraint violation - OrgStore.persist_org_with_owner(org, org_member) + await OrgStore.persist_org_with_owner(org, org_member) # Verify neither entity was persisted (transaction rolled back) - with session_maker() as session: - persisted_org = session.get(Org, org_id) + async with async_session_maker() as session: + persisted_org = await session.get(Org, org_id) assert persisted_org is None - persisted_member = ( - session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + result = await session.execute( + select(OrgMember).filter_by(org_id=org_id, user_id=user_id) ) + persisted_member = result.scalars().first() assert persisted_member is None -def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_persist_org_with_owner_with_multiple_fields( + async_session_maker, mock_litellm_api +): """ GIVEN: Org with multiple optional fields populated WHEN: persist_org_with_owner is called @@ -360,12 +388,12 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm org_id = uuid.uuid4() user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: user = User(id=user_id, current_org_id=org_id) role = Role(id=1, name='owner', rank=1) session.add(user) session.add(role) - session.commit() + await session.commit() org = Org( id=org_id, @@ -389,8 +417,8 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm ) # Act - with patch('storage.org_store.session_maker', session_maker): - result = OrgStore.persist_org_with_owner(org, org_member) + with patch('storage.org_store.a_session_maker', async_session_maker): + result = await OrgStore.persist_org_with_owner(org, org_member) # Assert assert result.name == 'Complex Org' @@ -400,22 +428,23 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm assert result.billing_margin == 0.15 # Verify persistence - with session_maker() as session: - persisted_org = session.get(Org, org_id) + async with async_session_maker() as session: + persisted_org = await session.get(Org, org_id) assert persisted_org.agent == 'CodeActAgent' assert persisted_org.default_max_iterations == 50 assert persisted_org.confirmation_mode is True assert persisted_org.billing_margin == 0.15 - persisted_member = ( - session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first() + result_query = await session.execute( + select(OrgMember).filter_by(org_id=org_id, user_id=user_id) ) + persisted_member = result_query.scalars().first() assert persisted_member.max_iterations == 100 assert persisted_member.llm_model == 'gpt-4' @pytest.mark.asyncio -async def test_delete_org_cascade_success(session_maker, mock_litellm_api): +async def test_delete_org_cascade_success(async_session_maker, mock_litellm_api): """ GIVEN: Valid organization with associated data WHEN: delete_org_cascade is called @@ -431,18 +460,11 @@ async def test_delete_org_cascade_success(session_maker, mock_litellm_api): contact_name='John Doe', contact_email='john@example.com', ) + async with async_session_maker() as session: + session.add(expected_org) + await session.commit() - # Mock delete_org_cascade to avoid database schema constraints - async def mock_delete_org_cascade(org_id_param): - # Verify the method was called with correct parameter - assert org_id_param == org_id - - # Return the organization object (simulating successful deletion) - return expected_org - - with patch( - 'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade - ): + with patch('storage.org_store.a_session_maker', async_session_maker): # Act result = await OrgStore.delete_org_cascade(org_id) @@ -455,7 +477,7 @@ async def test_delete_org_cascade_success(session_maker, mock_litellm_api): @pytest.mark.asyncio -async def test_delete_org_cascade_not_found(session_maker): +async def test_delete_org_cascade_not_found(async_session_maker): """ GIVEN: Organization ID that doesn't exist WHEN: delete_org_cascade is called @@ -464,7 +486,7 @@ async def test_delete_org_cascade_not_found(session_maker): # Arrange non_existent_id = uuid.uuid4() - with patch('storage.org_store.session_maker', session_maker): + with patch('storage.org_store.a_session_maker', async_session_maker): # Act result = await OrgStore.delete_org_cascade(non_existent_id) @@ -474,7 +496,7 @@ async def test_delete_org_cascade_not_found(session_maker): @pytest.mark.asyncio async def test_delete_org_cascade_litellm_failure_causes_rollback( - session_maker, mock_litellm_api + async_session_maker, mock_litellm_api ): """ GIVEN: Organization exists but LiteLLM cleanup fails @@ -485,7 +507,7 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback( org_id = uuid.uuid4() user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: role = Role(id=1, name='owner', rank=1) user = User(id=user_id, current_org_id=org_id) org = Org( @@ -502,15 +524,15 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback( llm_api_key='test-key', ) session.add_all([role, user, org, org_member]) - session.commit() + await session.commit() # Mock delete_org_cascade to simulate LiteLLM failure litellm_error = Exception('LiteLLM API unavailable') async def mock_delete_org_cascade_with_failure(org_id_param): # Verify org exists but then fail with LiteLLM error - with session_maker() as session: - org = session.get(Org, org_id_param) + async with async_session_maker() as session: + org = await session.get(Org, org_id_param) if not org: return None # Simulate the failure during LiteLLM cleanup @@ -527,17 +549,21 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback( assert 'LiteLLM API unavailable' in str(exc_info.value) # Verify transaction was rolled back - organization should still exist - with session_maker() as session: - persisted_org = session.get(Org, org_id) + async with async_session_maker() as session: + persisted_org = await session.get(Org, org_id) assert persisted_org is not None assert persisted_org.name == 'Test Organization' # Org member should still exist - persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first() + result = await session.execute(select(OrgMember).filter_by(org_id=org_id)) + persisted_member = result.scalars().first() assert persisted_member is not None -def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_first_page( + async_session_maker, mock_litellm_api +): """ GIVEN: User is member of multiple organizations WHEN: get_user_orgs_paginated is called without page_id @@ -547,7 +573,7 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api): user_id = uuid.uuid4() other_user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: # Create orgs for the user org1 = Org(name='Alpha Org') org2 = Org(name='Beta Org') @@ -555,14 +581,14 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api): # Create org for another user (should not be included) org4 = Org(name='Other Org') session.add_all([org1, org2, org3, org4]) - session.flush() + await session.flush() # Create user and role user = User(id=user_id, current_org_id=org1.id) other_user = User(id=other_user_id, current_org_id=org4.id) role = Role(id=1, name='member', rank=2) session.add_all([user, other_user, role]) - session.flush() + await session.flush() # Create memberships member1 = OrgMember( @@ -578,11 +604,11 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api): org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4' ) session.add_all([member1, member2, member3, other_member]) - session.commit() + await session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id=None, limit=2 ) @@ -596,7 +622,10 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api): assert 'Other Org' not in org_names -def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_with_page_id( + async_session_maker, mock_litellm_api +): """ GIVEN: User has multiple organizations and page_id is provided WHEN: get_user_orgs_paginated is called with page_id @@ -605,17 +634,17 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api): # Arrange user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: org1 = Org(name='Alpha Org') org2 = Org(name='Beta Org') org3 = Org(name='Gamma Org') session.add_all([org1, org2, org3]) - session.flush() + await session.flush() user = User(id=user_id, current_org_id=org1.id) role = Role(id=1, name='member', rank=2) session.add_all([user, role]) - session.flush() + await session.flush() member1 = OrgMember( org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1' @@ -627,11 +656,11 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api): org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3' ) session.add_all([member1, member2, member3]) - session.commit() + await session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id='1', limit=1 ) @@ -641,7 +670,10 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api): assert next_page_id == '2' # Has more results -def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_no_more_results( + async_session_maker, mock_litellm_api +): """ GIVEN: User has organizations but fewer than limit WHEN: get_user_orgs_paginated is called @@ -650,16 +682,16 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api # Arrange user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: org1 = Org(name='Alpha Org') org2 = Org(name='Beta Org') session.add_all([org1, org2]) - session.flush() + await session.flush() user = User(id=user_id, current_org_id=org1.id) role = Role(id=1, name='member', rank=2) session.add_all([user, role]) - session.flush() + await session.flush() member1 = OrgMember( org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1' @@ -668,11 +700,11 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2' ) session.add_all([member1, member2]) - session.commit() + await session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id=None, limit=10 ) @@ -681,7 +713,10 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api assert next_page_id is None -def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_invalid_page_id( + async_session_maker, mock_litellm_api +): """ GIVEN: Invalid page_id (non-numeric string) WHEN: get_user_orgs_paginated is called @@ -690,25 +725,25 @@ def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api # Arrange user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: org1 = Org(name='Alpha Org') session.add(org1) - session.flush() + await session.flush() user = User(id=user_id, current_org_id=org1.id) role = Role(id=1, name='member', rank=2) session.add_all([user, role]) - session.flush() + await session.flush() member1 = OrgMember( org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1' ) session.add(member1) - session.commit() + await session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id='invalid', limit=10 ) @@ -718,7 +753,8 @@ def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api assert next_page_id is None -def test_get_user_orgs_paginated_empty_results(session_maker): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_empty_results(async_session_maker): """ GIVEN: User has no organizations WHEN: get_user_orgs_paginated is called @@ -728,8 +764,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker): user_id = uuid.uuid4() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, next_page_id = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, next_page_id = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id=None, limit=10 ) @@ -738,7 +774,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker): assert next_page_id is None -def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api): +@pytest.mark.asyncio +async def test_get_user_orgs_paginated_ordering(async_session_maker, mock_litellm_api): """ GIVEN: User has organizations with different names WHEN: get_user_orgs_paginated is called @@ -747,18 +784,18 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api): # Arrange user_id = uuid.uuid4() - with session_maker() as session: + async with async_session_maker() as session: # Create orgs in non-alphabetical order org3 = Org(name='Zebra Org') org1 = Org(name='Apple Org') org2 = Org(name='Banana Org') session.add_all([org3, org1, org2]) - session.flush() + await session.flush() user = User(id=user_id, current_org_id=org1.id) role = Role(id=1, name='member', rank=2) session.add_all([user, role]) - session.flush() + await session.flush() member1 = OrgMember( org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1' @@ -770,11 +807,11 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api): org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3' ) session.add_all([member1, member2, member3]) - session.commit() + await session.commit() # Act - with patch('storage.org_store.session_maker', session_maker): - orgs, _ = OrgStore.get_user_orgs_paginated( + with patch('storage.org_store.a_session_maker', async_session_maker): + orgs, _ = await OrgStore.get_user_orgs_paginated( user_id=user_id, page_id=None, limit=10 ) diff --git a/enterprise/tests/unit/test_stripe_service_db.py b/enterprise/tests/unit/test_stripe_service_db.py index b4f606dae7..026be727c4 100644 --- a/enterprise/tests/unit/test_stripe_service_db.py +++ b/enterprise/tests/unit/test_stripe_service_db.py @@ -84,10 +84,13 @@ async def test_find_customer_id_by_user_id_checks_db_first( with ( patch('integrations.stripe_service.a_session_maker', async_session_maker), patch('storage.org_store.a_session_maker', async_session_maker), - patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync, + patch( + 'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id', + new_callable=AsyncMock, + ) as mock_get_org, ): - # Mock the call_sync_from_async to return the org - mock_call_sync.return_value = mock_org + # Mock the async method to return the org + mock_get_org.return_value = mock_org # Call the function result = await find_customer_id_by_user_id(str(test_user_id)) @@ -95,8 +98,8 @@ async def test_find_customer_id_by_user_id_checks_db_first( # Verify the result assert result == 'cus_test123' - # Verify that call_sync_from_async was called with the correct function - mock_call_sync.assert_called_once() + # Verify that OrgStore.get_current_org_from_keycloak_user_id was called + mock_get_org.assert_called_once_with(str(test_user_id)) @pytest.mark.asyncio @@ -122,10 +125,13 @@ async def test_find_customer_id_by_user_id_falls_back_to_stripe( patch('integrations.stripe_service.a_session_maker', async_session_maker), patch('storage.org_store.a_session_maker', async_session_maker), patch('stripe.Customer.search_async', mock_search), - patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync, + patch( + 'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id', + new_callable=AsyncMock, + ) as mock_get_org, ): - # Mock the call_sync_from_async to return the org - mock_call_sync.return_value = mock_org + # Mock the async method to return the org + mock_get_org.return_value = mock_org # Call the function result = await find_customer_id_by_user_id(str(test_user_id)) @@ -165,14 +171,17 @@ async def test_create_customer_stores_id_in_db( patch('storage.org_store.a_session_maker', async_session_maker), patch('stripe.Customer.search_async', mock_search), patch('stripe.Customer.create_async', mock_create_async), - patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync, + patch( + 'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id', + new_callable=AsyncMock, + ) as mock_get_org, patch( 'integrations.stripe_service.find_customer_id_by_org_id', new_callable=AsyncMock, ) as mock_find_customer, ): - # Mock the call_sync_from_async to return the org - mock_call_sync.return_value = mock_org + # Mock the async method to return the org + mock_get_org.return_value = mock_org # Mock find_customer_id_by_org_id to return None (force creation path) mock_find_customer.return_value = None