mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
refactor(enterprise): make OrgStore fully async (#13154)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: OpenHands Bot <contact@all-hands.dev>
This commit is contained in:
@@ -72,7 +72,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
@@ -81,14 +80,11 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
|
||||
@@ -9,8 +9,6 @@ from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
@@ -38,9 +36,7 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
@@ -50,9 +46,7 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
@@ -21,7 +21,6 @@ from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
@@ -122,9 +121,7 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
@@ -105,7 +105,7 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
orgs, next_page_id = await OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
|
||||
@@ -73,7 +73,7 @@ class OrgInvitationService:
|
||||
)
|
||||
|
||||
# Step 1: Validate organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
@@ -187,7 +187,7 @@ class OrgInvitationService:
|
||||
)
|
||||
|
||||
# Step 1: Validate permissions upfront (shared for all emails)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class OrgService:
|
||||
"""Service for handling organization-related operations."""
|
||||
|
||||
@staticmethod
|
||||
def validate_name_uniqueness(name: str) -> None:
|
||||
async def validate_name_uniqueness(name: str) -> None:
|
||||
"""
|
||||
Validate that organization name is unique.
|
||||
|
||||
@@ -41,7 +41,7 @@ class OrgService:
|
||||
Raises:
|
||||
OrgNameExistsError: If organization name already exists
|
||||
"""
|
||||
existing_org = OrgStore.get_org_by_name(name)
|
||||
existing_org = await OrgStore.get_org_by_name(name)
|
||||
if existing_org is not None:
|
||||
raise OrgNameExistsError(name)
|
||||
|
||||
@@ -214,7 +214,7 @@ class OrgService:
|
||||
)
|
||||
|
||||
# Step 1: Validate name uniqueness (fails early, no cleanup needed)
|
||||
OrgService.validate_name_uniqueness(name)
|
||||
await OrgService.validate_name_uniqueness(name)
|
||||
|
||||
# Step 2: Generate organization ID
|
||||
org_id = uuid4()
|
||||
@@ -304,7 +304,7 @@ class OrgService:
|
||||
OrgDatabaseError: If database operations fail
|
||||
"""
|
||||
try:
|
||||
persisted_org = OrgStore.persist_org_with_owner(org, org_member)
|
||||
persisted_org = await OrgStore.persist_org_with_owner(org, org_member)
|
||||
return persisted_org
|
||||
|
||||
except Exception as e:
|
||||
@@ -535,7 +535,7 @@ class OrgService:
|
||||
)
|
||||
|
||||
# Validate organization exists
|
||||
existing_org = OrgStore.get_org_by_id(org_id)
|
||||
existing_org = await OrgStore.get_org_by_id(org_id)
|
||||
if not existing_org:
|
||||
raise ValueError(f'Organization with ID {org_id} not found')
|
||||
|
||||
@@ -555,7 +555,7 @@ class OrgService:
|
||||
# Check if name is being updated and validate uniqueness
|
||||
if update_data.name is not None:
|
||||
# Check if new name conflicts with another org
|
||||
existing_org_with_name = OrgStore.get_org_by_name(update_data.name)
|
||||
existing_org_with_name = await OrgStore.get_org_by_name(update_data.name)
|
||||
if (
|
||||
existing_org_with_name is not None
|
||||
and existing_org_with_name.id != org_id
|
||||
@@ -608,7 +608,7 @@ class OrgService:
|
||||
|
||||
# Perform the update
|
||||
try:
|
||||
updated_org = OrgStore.update_org(org_id, update_dict)
|
||||
updated_org = await OrgStore.update_org(org_id, update_dict)
|
||||
if not updated_org:
|
||||
raise OrgDatabaseError('Failed to update organization in database')
|
||||
|
||||
@@ -683,7 +683,7 @@ class OrgService:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
async def get_user_orgs_paginated(
|
||||
user_id: str, page_id: str | None = None, limit: int = 100
|
||||
):
|
||||
"""
|
||||
@@ -706,7 +706,7 @@ class OrgService:
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Fetch organizations from store
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_uuid, page_id=page_id, limit=limit
|
||||
)
|
||||
|
||||
@@ -754,7 +754,7 @@ class OrgService:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Retrieve organization
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
'Organization not found despite valid membership',
|
||||
@@ -774,7 +774,7 @@ class OrgService:
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
|
||||
async def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
|
||||
"""
|
||||
Verify that the user is the owner of the organization.
|
||||
|
||||
@@ -787,7 +787,7 @@ class OrgService:
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
"""
|
||||
# Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
@@ -835,7 +835,7 @@ class OrgService:
|
||||
)
|
||||
|
||||
# Step 1: Verify user authorization
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
await OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
|
||||
try:
|
||||
@@ -879,7 +879,7 @@ class OrgService:
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
org = await OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
@@ -913,7 +913,7 @@ class OrgService:
|
||||
)
|
||||
|
||||
# Step 1: Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from server.constants import (
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate, OrphanedUserError
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -28,61 +28,64 @@ class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
async def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
async def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
return await OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
async def get_current_org_from_keycloak_user_id(
|
||||
keycloak_user_id: str,
|
||||
) -> Org | None:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return OrgStore._validate_org_version(org)
|
||||
return await OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
async def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.name == name).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.name == name))
|
||||
org = result.scalars().first()
|
||||
return await OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def _validate_org_version(org: Org) -> Org | None:
|
||||
async def _validate_org_version(org: Org | None) -> Org | None:
|
||||
"""Check if we need to update org version."""
|
||||
if org and org.org_version < ORG_SETTINGS_VERSION:
|
||||
org = OrgStore.update_org(
|
||||
org = await OrgStore.update_org(
|
||||
org.id,
|
||||
{
|
||||
'org_version': ORG_SETTINGS_VERSION,
|
||||
@@ -93,14 +96,15 @@ class OrgStore:
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
async def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org))
|
||||
orgs = result.scalars().all()
|
||||
return list(orgs)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
async def get_user_orgs_paginated(
|
||||
user_id: UUID, page_id: str | None = None, limit: int = 100
|
||||
) -> tuple[list[Org], str | None]:
|
||||
"""
|
||||
@@ -114,10 +118,10 @@ class OrgStore:
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Build query joining OrgMember with Org
|
||||
query = (
|
||||
session.query(Org)
|
||||
select(Org)
|
||||
.join(OrgMember, Org.id == OrgMember.org_id)
|
||||
.filter(OrgMember.user_id == user_id)
|
||||
.order_by(Org.name)
|
||||
@@ -136,7 +140,8 @@ class OrgStore:
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
orgs = query.all()
|
||||
result = await session.execute(query)
|
||||
orgs = list(result.scalars().all())
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(orgs) > limit
|
||||
@@ -149,21 +154,24 @@ class OrgStore:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Validate org versions
|
||||
validated_orgs = [
|
||||
OrgStore._validate_org_version(org) for org in orgs if org
|
||||
]
|
||||
validated_orgs = [org for org in validated_orgs if org is not None]
|
||||
validated_orgs = []
|
||||
for org in orgs:
|
||||
if org:
|
||||
validated = await OrgStore._validate_org_version(org)
|
||||
if validated is not None:
|
||||
validated_orgs.append(validated)
|
||||
|
||||
return validated_orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
@@ -173,8 +181,8 @@ class OrgStore:
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
@@ -223,7 +231,7 @@ class OrgStore:
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def persist_org_with_owner(
|
||||
async def persist_org_with_owner(
|
||||
org: Org,
|
||||
org_member: OrgMember,
|
||||
) -> Org:
|
||||
@@ -240,11 +248,11 @@ class OrgStore:
|
||||
Raises:
|
||||
Exception: If database operations fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(org)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
@@ -261,15 +269,16 @@ class OrgStore:
|
||||
Raises:
|
||||
Exception: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# First get the organization to return it
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. Delete conversation data for organization conversations
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text("""
|
||||
DELETE FROM conversation_metadata
|
||||
WHERE conversation_id IN (
|
||||
@@ -279,10 +288,10 @@ class OrgStore:
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text("""
|
||||
DELETE FROM app_conversation_start_task
|
||||
WHERE app_conversation_id::text IN (
|
||||
WHERE app_conversation_id IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
@@ -290,40 +299,40 @@ class OrgStore:
|
||||
)
|
||||
|
||||
# 2. Delete organization-owned data tables (direct org_id foreign keys)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text(
|
||||
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM api_keys WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM slack_users WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Handle users with this as current_org_id BEFORE deleting memberships
|
||||
# Single query to find orphaned users (those with no alternative org)
|
||||
orphaned_users = session.execute(
|
||||
orphaned_result = await session.execute(
|
||||
text("""
|
||||
SELECT u.id
|
||||
FROM "user" u
|
||||
@@ -334,27 +343,28 @@ class OrgStore:
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
).fetchall()
|
||||
)
|
||||
orphaned_users = orphaned_result.fetchall()
|
||||
|
||||
if orphaned_users:
|
||||
raise OrphanedUserError([str(row[0]) for row in orphaned_users])
|
||||
|
||||
# Batch update: reassign current_org_id to an alternative org for all affected users
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text("""
|
||||
UPDATE "user" u
|
||||
UPDATE user
|
||||
SET current_org_id = (
|
||||
SELECT om.org_id FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
WHERE om.user_id = user.id AND om.org_id != :org_id
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE u.current_org_id = :org_id
|
||||
WHERE user.current_org_id = :org_id
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Delete organization memberships (now safe)
|
||||
session.execute(
|
||||
await session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
@@ -370,7 +380,7 @@ class OrgStore:
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
# 7. Commit all changes only if everything succeeded
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'Successfully deleted organization and all associated data including LiteLLM team',
|
||||
@@ -380,7 +390,7 @@ class OrgStore:
|
||||
return org
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
logger.error(
|
||||
'Failed to delete organization - transaction rolled back',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
@@ -389,11 +399,12 @@ class OrgStore:
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id_async(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID (async version)."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
return OrgStore._validate_org_version(org) if org else None
|
||||
"""Get organization by ID (async version).
|
||||
|
||||
Note: This method is kept for backwards compatibility but simply
|
||||
delegates to get_org_by_id which is now async.
|
||||
"""
|
||||
return await OrgStore.get_org_by_id(org_id)
|
||||
|
||||
@staticmethod
|
||||
async def update_org_llm_settings_async(
|
||||
|
||||
@@ -519,7 +519,7 @@ async def test_list_user_orgs_success(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([mock_org], None),
|
||||
AsyncMock(return_value=([mock_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -573,7 +573,7 @@ async def test_list_user_orgs_returns_current_org_id(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([current_org, other_org], None),
|
||||
AsyncMock(return_value=([current_org, other_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -618,7 +618,7 @@ async def test_list_user_orgs_with_pagination(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([org1, org2], '2'),
|
||||
AsyncMock(return_value=([org1, org2], '2')),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -653,7 +653,7 @@ async def test_list_user_orgs_empty(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([], None),
|
||||
AsyncMock(return_value=([], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -720,7 +720,7 @@ async def test_list_user_orgs_service_error(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
side_effect=Exception('Database error'),
|
||||
AsyncMock(side_effect=Exception('Database error')),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -786,7 +786,7 @@ async def test_list_user_orgs_personal_org_identified(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([personal_org], None),
|
||||
AsyncMock(return_value=([personal_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -825,7 +825,7 @@ async def test_list_user_orgs_team_org_identified(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([team_org], None),
|
||||
AsyncMock(return_value=([team_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -874,7 +874,7 @@ async def test_list_user_orgs_mixed_personal_and_team(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([personal_org, team_org], None),
|
||||
AsyncMock(return_value=([personal_org, team_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
@@ -946,7 +946,7 @@ async def test_list_user_orgs_all_fields_present(mock_app_list):
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_user_orgs_paginated',
|
||||
return_value=([mock_org], None),
|
||||
AsyncMock(return_value=([mock_org], None)),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_list)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import (
|
||||
@@ -22,12 +22,12 @@ def mock_org():
|
||||
def mock_dependencies(mock_org):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
with patch(
|
||||
'integrations.utils.call_sync_from_async',
|
||||
'integrations.utils.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store:
|
||||
) as mock_get_org:
|
||||
yield {
|
||||
'call_sync': mock_call_sync,
|
||||
'org_store': mock_org_store,
|
||||
'get_org': mock_get_org,
|
||||
'org': mock_org,
|
||||
}
|
||||
|
||||
@@ -102,10 +102,7 @@ class TestGetUserV1EnabledSetting:
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['org_store'].get_current_org_from_keycloak_user_id,
|
||||
'test_user_123',
|
||||
)
|
||||
mock_dependencies['get_org'].assert_called_once_with('test_user_123')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_true(self, mock_dependencies):
|
||||
@@ -124,8 +121,8 @@ class TestGetUserV1EnabledSetting:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_org_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when no org is found."""
|
||||
# Mock call_sync_from_async to return None (no org found)
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
# Mock get_current_org_from_keycloak_user_id to return None (no org found)
|
||||
mock_dependencies['get_org'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@@ -63,7 +63,8 @@ def owner_role(session_maker):
|
||||
return role
|
||||
|
||||
|
||||
def test_validate_name_uniqueness_with_unique_name(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_name_uniqueness_with_unique_name(async_session_maker):
|
||||
"""
|
||||
GIVEN: A unique organization name
|
||||
WHEN: validate_name_uniqueness is called
|
||||
@@ -74,14 +75,15 @@ def test_validate_name_uniqueness_with_unique_name(session_maker):
|
||||
|
||||
# Act & Assert - should not raise
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker'),
|
||||
patch('storage.role_store.session_maker'),
|
||||
):
|
||||
OrgService.validate_name_uniqueness(unique_name)
|
||||
await OrgService.validate_name_uniqueness(unique_name)
|
||||
|
||||
|
||||
def test_validate_name_uniqueness_with_duplicate_name(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_name_uniqueness_with_duplicate_name():
|
||||
"""
|
||||
GIVEN: An organization name that already exists
|
||||
WHEN: validate_name_uniqueness is called
|
||||
@@ -94,18 +96,19 @@ def test_validate_name_uniqueness_with_duplicate_name(session_maker):
|
||||
# Mock OrgStore.get_org_by_name to return the existing org
|
||||
with patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
new_callable=AsyncMock,
|
||||
return_value=existing_org,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNameExistsError) as exc_info:
|
||||
OrgService.validate_name_uniqueness(existing_name)
|
||||
await OrgService.validate_name_uniqueness(existing_name)
|
||||
|
||||
assert existing_name in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_success(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
session_maker, async_session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization data and user ID
|
||||
@@ -128,7 +131,7 @@ async def test_create_org_with_owner_success(
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': str(user_id)}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
@@ -178,7 +181,7 @@ async def test_create_org_with_owner_success(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_duplicate_name(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
session_maker, async_session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: An organization name that already exists
|
||||
@@ -196,7 +199,7 @@ async def test_create_org_with_owner_duplicate_name(
|
||||
|
||||
# Act & Assert
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
@@ -217,7 +220,7 @@ async def test_create_org_with_owner_duplicate_name(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_litellm_failure(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
session_maker, async_session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: LiteLLM integration fails
|
||||
@@ -229,7 +232,7 @@ async def test_create_org_with_owner_litellm_failure(
|
||||
|
||||
# Mock LiteLLM failure
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=None),
|
||||
@@ -252,7 +255,7 @@ async def test_create_org_with_owner_litellm_failure(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_database_failure_triggers_cleanup(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
session_maker, async_session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Database persistence fails after LiteLLM integration succeeds
|
||||
@@ -272,7 +275,7 @@ async def test_create_org_with_owner_database_failure_triggers_cleanup(
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': user_id}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
@@ -311,7 +314,7 @@ async def test_create_org_with_owner_database_failure_triggers_cleanup(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
session_maker, async_session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Entity creation fails after LiteLLM integration succeeds
|
||||
@@ -325,7 +328,7 @@ async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup(
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': user_id}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=mock_settings),
|
||||
@@ -589,7 +592,10 @@ async def test_get_org_by_id_success(session_maker, owner_role):
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgMemberStore.get_org_member') as mock_get_member,
|
||||
patch('storage.org_service.OrgStore.get_org_by_id') as mock_get_org,
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
):
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_org.return_value = mock_org
|
||||
@@ -652,7 +658,11 @@ async def test_get_org_by_id_org_not_found():
|
||||
'storage.org_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
),
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=None),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
@@ -661,7 +671,10 @@ async def test_get_org_by_id_org_not_found():
|
||||
assert str(org_id) in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_success(
|
||||
session_maker, async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: User has organizations in database
|
||||
WHEN: get_user_orgs_paginated is called with valid user_id
|
||||
@@ -685,8 +698,8 @@ def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api):
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgService.get_user_orgs_paginated(
|
||||
user_id=str(user_id), page_id=None, limit=10
|
||||
)
|
||||
|
||||
@@ -696,7 +709,10 @@ def test_get_user_orgs_paginated_success(session_maker, mock_litellm_api):
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_with_pagination(
|
||||
session_maker, async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: User has multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called with page_id and limit
|
||||
@@ -730,8 +746,8 @@ def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgService.get_user_orgs_paginated(
|
||||
user_id=str(user_id), page_id='0', limit=2
|
||||
)
|
||||
|
||||
@@ -742,7 +758,8 @@ def test_get_user_orgs_paginated_with_pagination(session_maker, mock_litellm_api
|
||||
assert next_page_id == '2'
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_empty_results(async_session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -752,8 +769,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
@@ -762,7 +779,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_user_id_format():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_invalid_user_id_format():
|
||||
"""
|
||||
GIVEN: Invalid user_id format (not a valid UUID string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -773,12 +791,13 @@ def test_get_user_orgs_paginated_invalid_user_id_format():
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
OrgService.get_user_orgs_paginated(
|
||||
await OrgService.get_user_orgs_paginated(
|
||||
user_id=invalid_user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
|
||||
def test_verify_owner_authorization_success(session_maker, owner_role):
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_owner_authorization_success(session_maker, owner_role):
|
||||
"""
|
||||
GIVEN: User is owner of the organization
|
||||
WHEN: verify_owner_authorization is called
|
||||
@@ -808,7 +827,11 @@ def test_verify_owner_authorization_success(session_maker, owner_role):
|
||||
mock_owner_role.id = 1
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
@@ -818,10 +841,11 @@ def test_verify_owner_authorization_success(session_maker, owner_role):
|
||||
),
|
||||
):
|
||||
# Act & Assert - should not raise
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
await OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
|
||||
def test_verify_owner_authorization_org_not_found():
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_owner_authorization_org_not_found():
|
||||
"""
|
||||
GIVEN: Organization does not exist
|
||||
WHEN: verify_owner_authorization is called
|
||||
@@ -831,15 +855,20 @@ def test_verify_owner_authorization_org_not_found():
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
|
||||
with patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
await OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
assert str(org_id) in str(exc_info.value)
|
||||
|
||||
|
||||
def test_verify_owner_authorization_user_not_member(session_maker, owner_role):
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_owner_authorization_user_not_member(session_maker, owner_role):
|
||||
"""
|
||||
GIVEN: User is not a member of the organization
|
||||
WHEN: verify_owner_authorization is called
|
||||
@@ -857,17 +886,22 @@ def test_verify_owner_authorization_user_not_member(session_maker, owner_role):
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgMemberStore.get_org_member', return_value=None),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgAuthorizationError) as exc_info:
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
await OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
assert 'not a member' in str(exc_info.value)
|
||||
|
||||
|
||||
def test_verify_owner_authorization_user_not_owner(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_owner_authorization_user_not_owner(session_maker):
|
||||
"""
|
||||
GIVEN: User is member but not owner (admin role)
|
||||
WHEN: verify_owner_authorization is called
|
||||
@@ -893,7 +927,11 @@ def test_verify_owner_authorization_user_not_owner(session_maker):
|
||||
admin_role = Role(id=2, name='admin', rank=20)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
@@ -902,7 +940,7 @@ def test_verify_owner_authorization_user_not_owner(session_maker):
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgAuthorizationError) as exc_info:
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
await OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
assert 'Only organization owners' in str(exc_info.value)
|
||||
|
||||
@@ -1034,7 +1072,9 @@ async def test_delete_org_with_cleanup_unexpected_none_result(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_success_non_llm_fields(session_maker):
|
||||
async def test_update_org_with_permissions_success_non_llm_fields(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization update with non-LLM fields and user is a member
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1077,7 +1117,7 @@ async def test_update_org_with_permissions_success_non_llm_fields(session_maker)
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1096,7 +1136,9 @@ async def test_update_org_with_permissions_success_non_llm_fields(session_maker)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_success_llm_fields_admin(session_maker):
|
||||
async def test_update_org_with_permissions_success_llm_fields_admin(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization update with LLM fields and user has admin role
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1138,7 +1180,7 @@ async def test_update_org_with_permissions_success_llm_fields_admin(session_make
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1156,7 +1198,9 @@ async def test_update_org_with_permissions_success_llm_fields_admin(session_make
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_success_llm_fields_owner(session_maker):
|
||||
async def test_update_org_with_permissions_success_llm_fields_owner(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization update with LLM fields and user has owner role
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1198,7 +1242,7 @@ async def test_update_org_with_permissions_success_llm_fields_owner(session_make
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1216,7 +1260,9 @@ async def test_update_org_with_permissions_success_llm_fields_owner(session_make
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_success_mixed_fields_admin(session_maker):
|
||||
async def test_update_org_with_permissions_success_mixed_fields_admin(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization update with both LLM and non-LLM fields and user has admin role
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1259,7 +1305,7 @@ async def test_update_org_with_permissions_success_mixed_fields_admin(session_ma
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1278,7 +1324,9 @@ async def test_update_org_with_permissions_success_mixed_fields_admin(session_ma
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_empty_update(session_maker):
|
||||
async def test_update_org_with_permissions_empty_update(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Update request with no fields (all None)
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1317,7 +1365,7 @@ async def test_update_org_with_permissions_empty_update(session_maker):
|
||||
update_data = OrgUpdate() # All fields None
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1335,7 +1383,9 @@ async def test_update_org_with_permissions_empty_update(session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_org_not_found(session_maker):
|
||||
async def test_update_org_with_permissions_org_not_found(
|
||||
session_maker, async_session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization ID does not exist
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1350,7 +1400,7 @@ async def test_update_org_with_permissions_org_not_found(session_maker):
|
||||
update_data = OrgUpdate(contact_name='Jane Doe')
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1366,7 +1416,9 @@ async def test_update_org_with_permissions_org_not_found(session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_non_member(session_maker):
|
||||
async def test_update_org_with_permissions_non_member(
|
||||
session_maker, async_session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: User is not a member of the organization
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1396,7 +1448,7 @@ async def test_update_org_with_permissions_non_member(session_maker):
|
||||
update_data = OrgUpdate(contact_name='Jane Doe')
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1413,6 +1465,7 @@ async def test_update_org_with_permissions_non_member(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_llm_fields_insufficient_permission(
|
||||
async_session_maker,
|
||||
session_maker,
|
||||
):
|
||||
"""
|
||||
@@ -1453,7 +1506,7 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission(
|
||||
update_data = OrgUpdate(default_llm_model='claude-opus-4-5-20251101')
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1472,7 +1525,9 @@ async def test_update_org_with_permissions_llm_fields_insufficient_permission(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_database_error(session_maker):
|
||||
async def test_update_org_with_permissions_database_error(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Database update operation fails
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1511,11 +1566,12 @@ async def test_update_org_with_permissions_database_error(session_maker):
|
||||
update_data = OrgUpdate(contact_name='Jane Doe')
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.update_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # Simulate database failure
|
||||
),
|
||||
):
|
||||
@@ -1532,6 +1588,7 @@ async def test_update_org_with_permissions_database_error(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists_error(
|
||||
async_session_maker,
|
||||
session_maker,
|
||||
):
|
||||
"""
|
||||
@@ -1564,16 +1621,18 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists
|
||||
update_data = OrgUpdate(name=duplicate_name)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_current_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_with_name,
|
||||
),
|
||||
):
|
||||
@@ -1589,7 +1648,9 @@ async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_same_name_allowed(session_maker):
|
||||
async def test_update_org_with_permissions_same_name_allowed(
|
||||
session_maker, async_session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: User updates org with name unchanged (same as current org name)
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1613,20 +1674,23 @@ async def test_update_org_with_permissions_same_name_allowed(session_maker):
|
||||
update_data = OrgUpdate(name=current_name)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.update_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
@@ -1643,7 +1707,9 @@ async def test_update_org_with_permissions_same_name_allowed(session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_only_llm_fields(session_maker):
|
||||
async def test_update_org_with_permissions_only_llm_fields(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Update request contains only LLM fields and user has admin role
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1686,7 +1752,7 @@ async def test_update_org_with_permissions_only_llm_fields(session_maker):
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1705,7 +1771,9 @@ async def test_update_org_with_permissions_only_llm_fields(session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
|
||||
async def test_update_org_with_permissions_only_non_llm_fields(
|
||||
async_session_maker, session_maker
|
||||
):
|
||||
"""
|
||||
GIVEN: Update request contains only non-LLM fields and user is a member
|
||||
WHEN: update_org_with_permissions is called
|
||||
@@ -1748,7 +1816,7 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
):
|
||||
@@ -1790,6 +1858,7 @@ async def test_check_byor_export_enabled_returns_true_when_enabled():
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
@@ -1824,6 +1893,7 @@ async def test_check_byor_export_enabled_returns_false_when_disabled():
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
@@ -1900,6 +1970,7 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found():
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
@@ -1929,7 +2000,11 @@ async def test_switch_org_success():
|
||||
mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.UserStore.update_current_org',
|
||||
@@ -1956,7 +2031,11 @@ async def test_switch_org_org_not_found():
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
|
||||
with patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
@@ -1982,7 +2061,11 @@ async def test_switch_org_user_not_member():
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=False),
|
||||
):
|
||||
# Act & Assert
|
||||
@@ -2010,7 +2093,11 @@ async def test_switch_org_user_not_found():
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch('storage.org_service.UserStore.update_current_org', return_value=None),
|
||||
):
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
@@ -40,67 +41,73 @@ def mock_litellm_api():
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_get_org_by_id(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_by_id(async_session_maker, mock_litellm_api):
|
||||
# Test getting org by ID
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_id(org_id)
|
||||
retrieved_org = await OrgStore.get_org_by_id(org_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.id == org_id
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_org_by_id_not_found(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_by_id_not_found(async_session_maker):
|
||||
# Test getting org by ID when it doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
non_existent_id = uuid.uuid4()
|
||||
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
|
||||
retrieved_org = await OrgStore.get_org_by_id(non_existent_id)
|
||||
assert retrieved_org is None
|
||||
|
||||
|
||||
def test_list_orgs(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_orgs(async_session_maker, mock_litellm_api):
|
||||
# Test listing all orgs
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create test orgs
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Test listing
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
orgs = OrgStore.list_orgs()
|
||||
orgs = await OrgStore.list_orgs()
|
||||
assert len(orgs) >= 2
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'test-org-1' in org_names
|
||||
assert 'test-org-2' in org_names
|
||||
|
||||
|
||||
def test_update_org(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org(async_session_maker, mock_litellm_api):
|
||||
# Test updating org details
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org', agent='CodeActAgent')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
org_id = org.id
|
||||
|
||||
# Test update
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
updated_org = OrgStore.update_org(
|
||||
updated_org = await OrgStore.update_org(
|
||||
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
|
||||
)
|
||||
|
||||
@@ -109,23 +116,27 @@ def test_update_org(session_maker, mock_litellm_api):
|
||||
assert updated_org.agent == 'PlannerAgent'
|
||||
|
||||
|
||||
def test_update_org_not_found(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_not_found(async_session_maker):
|
||||
# Test updating org that doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
from uuid import uuid4
|
||||
|
||||
updated_org = OrgStore.update_org(
|
||||
updated_org = await OrgStore.update_org(
|
||||
org_id=uuid4(), kwargs={'name': 'updated-org'}
|
||||
)
|
||||
assert updated_org is None
|
||||
|
||||
|
||||
def test_create_org(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org(async_session_maker, mock_litellm_api):
|
||||
# Test creating a new org
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
|
||||
org = await OrgStore.create_org(
|
||||
kwargs={'name': 'new-org', 'agent': 'CodeActAgent'}
|
||||
)
|
||||
|
||||
assert org is not None
|
||||
assert org.name == 'new-org'
|
||||
@@ -133,43 +144,48 @@ def test_create_org(session_maker, mock_litellm_api):
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
def test_get_org_by_name(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_by_name(async_session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org-by-name')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
|
||||
retrieved_org = await OrgStore.get_org_by_name('test-org-by-name')
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org-by-name'
|
||||
|
||||
|
||||
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_from_keycloak_user_id(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
# Test getting current org from user ID
|
||||
test_user_id = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
from storage.user import User
|
||||
|
||||
user = User(id=test_user_id, current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
|
||||
retrieved_org = await OrgStore.get_current_org_from_keycloak_user_id(
|
||||
str(test_user_id)
|
||||
)
|
||||
assert retrieved_org is not None
|
||||
@@ -200,7 +216,8 @@ def test_get_kwargs_from_settings():
|
||||
assert 'enable_sound_notifications' not in kwargs
|
||||
|
||||
|
||||
def test_persist_org_with_owner_success(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_org_with_owner_success(async_session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid org and org_member entities
|
||||
WHEN: persist_org_with_owner is called
|
||||
@@ -211,12 +228,12 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api):
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Create user and role first
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
@@ -234,8 +251,8 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api):
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
result = await OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
@@ -243,20 +260,24 @@ def test_persist_org_with_owner_success(session_maker, mock_litellm_api):
|
||||
assert result.name == 'Test Organization'
|
||||
|
||||
# Verify both entities were persisted
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
async with async_session_maker() as session:
|
||||
persisted_org = await session.get(Org, org_id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == 'Test Organization'
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter_by(org_id=org_id, user_id=user_id)
|
||||
)
|
||||
persisted_member = result.scalars().first()
|
||||
assert persisted_member is not None
|
||||
assert persisted_member.status == 'active'
|
||||
assert persisted_member.role_id == 1
|
||||
|
||||
|
||||
def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_org_with_owner_returns_refreshed_org(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid org and org_member entities
|
||||
WHEN: persist_org_with_owner is called
|
||||
@@ -266,12 +287,12 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
@@ -290,8 +311,8 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
result = await OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert - verify the returned object has database-generated fields
|
||||
assert result.id == org_id
|
||||
@@ -301,7 +322,10 @@ def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litell
|
||||
assert hasattr(result, 'org_version')
|
||||
|
||||
|
||||
def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_org_with_owner_transaction_atomicity(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid org but invalid org_member (missing required field)
|
||||
WHEN: persist_org_with_owner is called
|
||||
@@ -311,12 +335,12 @@ def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litell
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
@@ -335,22 +359,26 @@ def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litell
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
with pytest.raises(IntegrityError): # NOT NULL constraint violation
|
||||
OrgStore.persist_org_with_owner(org, org_member)
|
||||
await OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Verify neither entity was persisted (transaction rolled back)
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
async with async_session_maker() as session:
|
||||
persisted_org = await session.get(Org, org_id)
|
||||
assert persisted_org is None
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter_by(org_id=org_id, user_id=user_id)
|
||||
)
|
||||
persisted_member = result.scalars().first()
|
||||
assert persisted_member is None
|
||||
|
||||
|
||||
def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_org_with_owner_with_multiple_fields(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Org with multiple optional fields populated
|
||||
WHEN: persist_org_with_owner is called
|
||||
@@ -360,12 +388,12 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
@@ -389,8 +417,8 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
result = await OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert
|
||||
assert result.name == 'Complex Org'
|
||||
@@ -400,22 +428,23 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
assert result.billing_margin == 0.15
|
||||
|
||||
# Verify persistence
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
async with async_session_maker() as session:
|
||||
persisted_org = await session.get(Org, org_id)
|
||||
assert persisted_org.agent == 'CodeActAgent'
|
||||
assert persisted_org.default_max_iterations == 50
|
||||
assert persisted_org.confirmation_mode is True
|
||||
assert persisted_org.billing_margin == 0.15
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
result_query = await session.execute(
|
||||
select(OrgMember).filter_by(org_id=org_id, user_id=user_id)
|
||||
)
|
||||
persisted_member = result_query.scalars().first()
|
||||
assert persisted_member.max_iterations == 100
|
||||
assert persisted_member.llm_model == 'gpt-4'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
|
||||
async def test_delete_org_cascade_success(async_session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid organization with associated data
|
||||
WHEN: delete_org_cascade is called
|
||||
@@ -431,18 +460,11 @@ async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
session.add(expected_org)
|
||||
await session.commit()
|
||||
|
||||
# Mock delete_org_cascade to avoid database schema constraints
|
||||
async def mock_delete_org_cascade(org_id_param):
|
||||
# Verify the method was called with correct parameter
|
||||
assert org_id_param == org_id
|
||||
|
||||
# Return the organization object (simulating successful deletion)
|
||||
return expected_org
|
||||
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade
|
||||
):
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(org_id)
|
||||
|
||||
@@ -455,7 +477,7 @@ async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_not_found(session_maker):
|
||||
async def test_delete_org_cascade_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: Organization ID that doesn't exist
|
||||
WHEN: delete_org_cascade is called
|
||||
@@ -464,7 +486,7 @@ async def test_delete_org_cascade_not_found(session_maker):
|
||||
# Arrange
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(non_existent_id)
|
||||
|
||||
@@ -474,7 +496,7 @@ async def test_delete_org_cascade_not_found(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
session_maker, mock_litellm_api
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization exists but LiteLLM cleanup fails
|
||||
@@ -485,7 +507,7 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
org = Org(
|
||||
@@ -502,15 +524,15 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add_all([role, user, org, org_member])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Mock delete_org_cascade to simulate LiteLLM failure
|
||||
litellm_error = Exception('LiteLLM API unavailable')
|
||||
|
||||
async def mock_delete_org_cascade_with_failure(org_id_param):
|
||||
# Verify org exists but then fail with LiteLLM error
|
||||
with session_maker() as session:
|
||||
org = session.get(Org, org_id_param)
|
||||
async with async_session_maker() as session:
|
||||
org = await session.get(Org, org_id_param)
|
||||
if not org:
|
||||
return None
|
||||
# Simulate the failure during LiteLLM cleanup
|
||||
@@ -527,17 +549,21 @@ async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
assert 'LiteLLM API unavailable' in str(exc_info.value)
|
||||
|
||||
# Verify transaction was rolled back - organization should still exist
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
async with async_session_maker() as session:
|
||||
persisted_org = await session.get(Org, org_id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == 'Test Organization'
|
||||
|
||||
# Org member should still exist
|
||||
persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first()
|
||||
result = await session.execute(select(OrgMember).filter_by(org_id=org_id))
|
||||
persisted_member = result.scalars().first()
|
||||
assert persisted_member is not None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_first_page(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: User is member of multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called without page_id
|
||||
@@ -547,7 +573,7 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
user_id = uuid.uuid4()
|
||||
other_user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create orgs for the user
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
@@ -555,14 +581,14 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
# Create org for another user (should not be included)
|
||||
org4 = Org(name='Other Org')
|
||||
session.add_all([org1, org2, org3, org4])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
# Create user and role
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
other_user = User(id=other_user_id, current_org_id=org4.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, other_user, role])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
# Create memberships
|
||||
member1 = OrgMember(
|
||||
@@ -578,11 +604,11 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
|
||||
)
|
||||
session.add_all([member1, member2, member3, other_member])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=2
|
||||
)
|
||||
|
||||
@@ -596,7 +622,10 @@ def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
assert 'Other Org' not in org_names
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_with_page_id(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: User has multiple organizations and page_id is provided
|
||||
WHEN: get_user_orgs_paginated is called with page_id
|
||||
@@ -605,17 +634,17 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
session.add_all([org1, org2, org3])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
@@ -627,11 +656,11 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='1', limit=1
|
||||
)
|
||||
|
||||
@@ -641,7 +670,10 @@ def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
assert next_page_id == '2' # Has more results
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_no_more_results(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: User has organizations but fewer than limit
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -650,16 +682,16 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
@@ -668,11 +700,11 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
session.add_all([member1, member2])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
@@ -681,7 +713,10 @@ def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_invalid_page_id(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Invalid page_id (non-numeric string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -690,25 +725,25 @@ def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
session.add(org1)
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
session.add(member1)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='invalid', limit=10
|
||||
)
|
||||
|
||||
@@ -718,7 +753,8 @@ def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_empty_results(async_session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -728,8 +764,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, next_page_id = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
@@ -738,7 +774,8 @@ def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_orgs_paginated_ordering(async_session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations with different names
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
@@ -747,18 +784,18 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
async with async_session_maker() as session:
|
||||
# Create orgs in non-alphabetical order
|
||||
org3 = Org(name='Zebra Org')
|
||||
org1 = Org(name='Apple Org')
|
||||
org2 = Org(name='Banana Org')
|
||||
session.add_all([org3, org1, org2])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
await session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
@@ -770,11 +807,11 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, _ = OrgStore.get_user_orgs_paginated(
|
||||
with patch('storage.org_store.a_session_maker', async_session_maker):
|
||||
orgs, _ = await OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
|
||||
@@ -84,10 +84,13 @@ async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
with (
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
patch(
|
||||
'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
# Mock the async method to return the org
|
||||
mock_get_org.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
@@ -95,8 +98,8 @@ async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that call_sync_from_async was called with the correct function
|
||||
mock_call_sync.assert_called_once()
|
||||
# Verify that OrgStore.get_current_org_from_keycloak_user_id was called
|
||||
mock_get_org.assert_called_once_with(str(test_user_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -122,10 +125,13 @@ async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
patch(
|
||||
'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
# Mock the async method to return the org
|
||||
mock_get_org.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
@@ -165,14 +171,17 @@ async def test_create_customer_stores_id_in_db(
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('stripe.Customer.create_async', mock_create_async),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
patch(
|
||||
'integrations.stripe_service.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
patch(
|
||||
'integrations.stripe_service.find_customer_id_by_org_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find_customer,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
# Mock the async method to return the org
|
||||
mock_get_org.return_value = mock_org
|
||||
# Mock find_customer_id_by_org_id to return None (force creation path)
|
||||
mock_find_customer.return_value = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user