mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Refactor: Migrate remaining enterprise modules to async database sessions (#13124)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -24,7 +24,6 @@ from jinja2 import Environment
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
@@ -153,9 +152,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
@@ -6,7 +6,6 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -78,9 +77,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
@@ -22,7 +22,8 @@ from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -63,12 +64,11 @@ class SlackManager(Manager):
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
|
||||
)
|
||||
slack_user = result.scalar_one_or_none()
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from uuid import UUID
|
||||
import stripe
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
@@ -15,12 +15,10 @@ stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
|
||||
result = await session.execute(stmt)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
@@ -74,7 +72,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
@@ -82,7 +80,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
@@ -108,26 +106,27 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -14,7 +14,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -111,9 +111,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -132,9 +132,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
||||
@@ -34,7 +34,8 @@ from server.services.org_invitation_service import (
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -610,17 +611,20 @@ async def accept_tos(request: Request):
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
user.accepted_tos = accepted_tos
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
|
||||
@@ -11,9 +11,10 @@ from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
@@ -106,16 +107,17 @@ async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(SubscriptionAccess).where(
|
||||
SubscriptionAccess.status == 'ACTIVE',
|
||||
SubscriptionAccess.user_id == user_id,
|
||||
SubscriptionAccess.start_at <= now,
|
||||
SubscriptionAccess.end_at >= now,
|
||||
)
|
||||
)
|
||||
subscription_access = result.scalar_one_or_none()
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
@@ -197,7 +199,7 @@ async def create_checkout_session(
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
@@ -206,7 +208,7 @@ async def create_checkout_session(
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
@@ -215,13 +217,14 @@ async def create_checkout_session(
|
||||
@billing_router.get('/success')
|
||||
async def success_callback(session_id: str, request: Request):
|
||||
# We can't use the auth cookie because of SameSite=strict
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
|
||||
if billing_session is None:
|
||||
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
|
||||
@@ -253,7 +256,8 @@ async def success_callback(session_id: str, request: Request):
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
|
||||
org = result.scalar_one_or_none()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
@@ -279,7 +283,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
@@ -289,13 +293,14 @@ async def success_callback(session_id: str, request: Request):
|
||||
# Callback endpoint for cancelled Stripe payments - updates billing session status
|
||||
@billing_router.get('/cancel')
|
||||
async def cancel_callback(session_id: str, request: Request):
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
if billing_session:
|
||||
logger.info(
|
||||
'stripe_checkout_cancel',
|
||||
@@ -307,7 +312,7 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
billing_session.status = 'cancelled'
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from sqlalchemy.sql import text
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -9,11 +9,11 @@ readiness_router = APIRouter()
|
||||
|
||||
|
||||
@readiness_router.get('/ready')
|
||||
def is_ready():
|
||||
async def is_ready():
|
||||
# Check database connection
|
||||
try:
|
||||
with session_maker() as session:
|
||||
session.execute(text('SELECT 1'))
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(text('SELECT 1'))
|
||||
except Exception as e:
|
||||
logger.error(f'Database check failed: {str(e)}')
|
||||
raise HTTPException(
|
||||
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
@@ -24,7 +25,7 @@ class JiraDcIntegrationStore:
|
||||
) -> JiraDcWorkspace:
|
||||
"""Create a new Jira DC workspace with encrypted sensitive data."""
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
workspace = JiraDcWorkspace(
|
||||
name=name.lower(),
|
||||
admin_user_id=admin_user_id,
|
||||
@@ -34,8 +35,8 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
session.add(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
logger.info(f'[Jira DC] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
|
||||
@@ -48,11 +49,12 @@ class JiraDcIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> JiraDcWorkspace:
|
||||
"""Update an existing Jira DC workspace with encrypted sensitive data."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace).filter(JiraDcWorkspace.id == id).first()
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -69,8 +71,8 @@ class JiraDcIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira DC] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -91,10 +93,10 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_dc_user)
|
||||
session.commit()
|
||||
session.refresh(jira_dc_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_dc_user)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created user {jira_dc_user.id} for workspace {jira_dc_workspace_id}'
|
||||
@@ -103,94 +105,91 @@ class JiraDcIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(
|
||||
JiraDcWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Retrieve user by Keycloak user ID."""
|
||||
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, jira_dc_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.jira_dc_user_id == jira_dc_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraDcUser:
|
||||
"""Update the status of a Jira DC user mapping."""
|
||||
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(JiraDcUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(
|
||||
@@ -198,37 +197,35 @@ class JiraDcIntegrationStore:
|
||||
)
|
||||
|
||||
user.status = status
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
logger.info(f'[Jira DC] Updated user {keycloak_user_id} status to {status}')
|
||||
return user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.jira_dc_workspace_id == workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Deactivated all user links for workspace {workspace_id}'
|
||||
@@ -238,23 +235,22 @@ class JiraDcIntegrationStore:
|
||||
self, jira_dc_conversation: JiraDcConversation
|
||||
) -> None:
|
||||
"""Create a new Jira DC conversation record."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_dc_conversation)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, jira_dc_user_id: int
|
||||
) -> JiraDcConversation | None:
|
||||
"""Get a Jira DC conversation by issue ID and jira dc user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcConversation)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcConversation).where(
|
||||
JiraDcConversation.issue_id == issue_id,
|
||||
JiraDcConversation.jira_dc_user_id == jira_dc_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> JiraDcIntegrationStore:
|
||||
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
@@ -35,10 +36,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -53,11 +54,12 @@ class LinearIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> LinearWorkspace:
|
||||
"""Update an existing Linear workspace with encrypted sensitive data."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
workspace = (
|
||||
session.query(LinearWorkspace).filter(LinearWorkspace.id == id).first()
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -77,8 +79,8 @@ class LinearIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -98,10 +100,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(linear_user)
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Created user {linear_user.id} for workspace {linear_workspace_id}'
|
||||
@@ -110,77 +112,75 @@ class LinearIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(
|
||||
LinearWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> LinearUser | None:
|
||||
"""Get Linear user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, linear_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.linear_user_id == linear_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> LinearUser:
|
||||
"""Update Linear user integration status."""
|
||||
with session_maker() as session:
|
||||
linear_user = (
|
||||
session.query(LinearUser)
|
||||
.filter(LinearUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
)
|
||||
linear_user = result.scalar_one_or_none()
|
||||
|
||||
if not linear_user:
|
||||
raise ValueError(
|
||||
@@ -188,38 +188,36 @@ class LinearIntegrationStore:
|
||||
)
|
||||
|
||||
linear_user.status = status
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
|
||||
logger.info(f'[Linear] Updated user {keycloak_user_id} status to {status}')
|
||||
return linear_user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.linear_workspace_id == workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
workspace = (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
|
||||
|
||||
@@ -227,23 +225,22 @@ class LinearIntegrationStore:
|
||||
self, linear_conversation: LinearConversation
|
||||
) -> None:
|
||||
"""Create a new Linear conversation record."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(linear_conversation)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, linear_user_id: int
|
||||
) -> LinearConversation | None:
|
||||
"""Get a Linear conversation by issue ID and linear user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearConversation)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearConversation).where(
|
||||
LinearConversation.issue_id == issue_id,
|
||||
LinearConversation.linear_user_id == linear_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> LinearIntegrationStore:
|
||||
|
||||
@@ -2,38 +2,35 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_conversation import SlackConversation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackConversationStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
async def get_slack_conversation(
|
||||
self, channel_id: str, parent_id: str
|
||||
) -> SlackConversation | None:
|
||||
"""Get a slack conversation by channel_id and message_ts.
|
||||
Both parameters are required to match for a conversation to be returned.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
conversation = (
|
||||
session.query(SlackConversation)
|
||||
.filter(SlackConversation.channel_id == channel_id)
|
||||
.filter(SlackConversation.parent_id == parent_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackConversation).where(
|
||||
SlackConversation.channel_id == channel_id,
|
||||
SlackConversation.parent_id == parent_id,
|
||||
)
|
||||
)
|
||||
|
||||
return conversation
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_slack_conversation(
|
||||
self, slack_converstion: SlackConversation
|
||||
) -> None:
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(slack_converstion)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> SlackConversationStore:
|
||||
return SlackConversationStore(session_maker)
|
||||
return SlackConversationStore()
|
||||
|
||||
@@ -227,7 +227,7 @@ class UserStore:
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
await migrate_customer(user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
|
||||
@@ -114,9 +114,7 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_expired(
|
||||
api_key_store, session_maker, async_session_maker
|
||||
):
|
||||
async def test_validate_api_key_expired(api_key_store, async_session_maker):
|
||||
"""Test validating an expired API key."""
|
||||
# Setup - create an expired API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
@@ -144,7 +142,7 @@ async def test_validate_api_key_expired(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_expired_timezone_naive(
|
||||
api_key_store, session_maker, async_session_maker
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup - create an expired API key with timezone-naive datetime
|
||||
@@ -174,7 +172,7 @@ async def test_validate_api_key_expired_timezone_naive(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid_timezone_naive(
|
||||
api_key_store, session_maker, async_session_maker
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
||||
@@ -293,7 +291,7 @@ async def test_delete_api_key_by_id(api_key_store, async_session_maker):
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_list_api_keys(
|
||||
mock_get_user, api_key_store, session_maker, async_session_maker, mock_user
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
@@ -346,7 +344,7 @@ async def test_list_api_keys(
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key(
|
||||
mock_get_user, api_key_store, session_maker, async_session_maker, mock_user
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
@@ -385,7 +383,7 @@ async def test_retrieve_mcp_api_key(
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, session_maker, async_session_maker, mock_user
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
@@ -415,9 +413,7 @@ async def test_retrieve_mcp_api_key_not_found(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_api_key_by_name(
|
||||
api_key_store, session_maker, async_session_maker
|
||||
):
|
||||
async def test_retrieve_api_key_by_name(api_key_store, async_session_maker):
|
||||
"""Test retrieving an API key by name."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
@@ -457,9 +453,7 @@ async def test_retrieve_api_key_by_name_not_found(api_key_store, async_session_m
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_name(
|
||||
api_key_store, session_maker, async_session_maker
|
||||
):
|
||||
async def test_delete_api_key_by_name(api_key_store, async_session_maker):
|
||||
"""Test deleting an API key by name."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
@@ -621,7 +621,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -686,7 +686,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -749,7 +749,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -898,7 +898,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -959,7 +959,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -1022,7 +1022,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -1174,7 +1174,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1325,7 +1325,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1414,7 +1414,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1500,7 +1500,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1666,7 +1666,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', ''),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1734,7 +1734,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1808,7 +1808,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.routes import billing
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
@@ -18,22 +19,11 @@ from server.routes.billing import (
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
StripeCustomerBase.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -76,6 +66,38 @@ def mock_subscription_request():
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_org(async_session_maker):
|
||||
"""Create a test org in the database."""
|
||||
org_id = uuid.uuid4()
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name=f'test-org-{org_id}',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(async_session_maker, test_org):
|
||||
"""Create a test user in the database linked to test_org."""
|
||||
user_id = uuid.uuid4()
|
||||
async with async_session_maker() as session:
|
||||
user = User(
|
||||
id=user_id,
|
||||
current_org_id=test_org.id,
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
@@ -133,17 +155,14 @@ async def test_get_credits_success():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_stripe_error(
|
||||
session_maker, mock_checkout_request
|
||||
async_session_maker, mock_checkout_request, test_org
|
||||
):
|
||||
"""Test handling of Stripe API errors."""
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -154,10 +173,13 @@ async def test_create_checkout_session_stripe_error(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.database.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
return_value=test_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
@@ -171,44 +193,27 @@ async def test_create_checkout_session_stripe_error(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
|
||||
async def test_create_checkout_session_success(
|
||||
async_session_maker, mock_checkout_request, test_org
|
||||
):
|
||||
"""Test successful creation of checkout session."""
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session.id = 'test_session_id'
|
||||
mock_session.id = 'test_session_id_checkout'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
mock_create.return_value = mock_session
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
mock_customer_info = {'customer_id': 'mock-customer', 'org_id': test_org.id}
|
||||
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
'stripe.Customer.search_async', AsyncMock(return_value=MagicMock(data=[]))
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
result = await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
@@ -240,74 +245,102 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
)
|
||||
|
||||
# Verify database session creation
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
# Verify database record was created
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == 'test_session_id_checkout'
|
||||
)
|
||||
)
|
||||
billing_session = result_db.scalar_one_or_none()
|
||||
assert billing_session is not None
|
||||
assert billing_session.user_id == 'mock_user'
|
||||
assert billing_session.org_id == test_org.id
|
||||
assert billing_session.status == 'in_progress'
|
||||
assert float(billing_session.price) == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_session_not_found():
|
||||
async def test_success_callback_session_not_found(async_session_maker):
|
||||
"""Test success callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve'),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await success_callback('test_session_id', mock_request)
|
||||
await success_callback('nonexistent_session_id', mock_request)
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_stripe_incomplete():
|
||||
async def test_success_callback_stripe_incomplete(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
"""Test success callback when Stripe session is not complete."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
session_id = 'test_incomplete_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(status='pending')
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await success_callback('test_session_id', mock_request)
|
||||
await success_callback(session_id, mock_request)
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
# Verify no database update occurred
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_success():
|
||||
async def test_success_callback_success(async_session_maker, test_org, test_user):
|
||||
"""Test successful payment completion and credit update."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
session_id = 'test_success_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
@@ -320,25 +353,11 @@ async def test_success_callback_success():
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
# First query: BillingSession (query().filter().filter().first())
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
# Second query: Org (query().filter().first()) - use side_effect for different return chains
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
) # $25.00 in cents
|
||||
)
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
response = await success_callback(session_id, mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
@@ -346,64 +365,80 @@ async def test_success_callback_success():
|
||||
== 'https://test.com/settings/billing?checkout=success'
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
str(test_org.id),
|
||||
125.0, # 100 + 25.00
|
||||
)
|
||||
|
||||
# Verify BYOR export is enabled for the org (updated in same session)
|
||||
assert mock_org.byor_export_enabled is True
|
||||
# Verify database updates
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'completed'
|
||||
assert float(billing_session.price) == 25.0
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
# Verify org byor_export_enabled was set
|
||||
org_result = await session.execute(select(Org).where(Org.id == test_org.id))
|
||||
org = org_result.scalar_one_or_none()
|
||||
assert org.byor_export_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_error():
|
||||
async def test_success_callback_lite_llm_error(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
"""Test handling of LiteLLM API errors during success callback."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
session_id = 'test_litellm_error_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
await success_callback(session_id, mock_request)
|
||||
|
||||
# Verify no database updates occurred
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
# Verify no database updates occurred (transaction rolled back)
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
"""Test that database changes are not committed when update_team_and_users_budget fails.
|
||||
|
||||
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
|
||||
@@ -412,19 +447,26 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
session_id = 'test_budget_rollback_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=10,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
@@ -438,70 +480,60 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=1000, # $10
|
||||
amount_subtotal=1000,
|
||||
customer='mock_customer_id',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
await success_callback(session_id, mock_request)
|
||||
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_session_not_found():
|
||||
async def test_cancel_callback_session_not_found(async_session_maker):
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
with patch('server.routes.billing.a_session_maker', async_session_maker):
|
||||
response = await cancel_callback('nonexistent_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_success():
|
||||
async def test_cancel_callback_success(async_session_maker, test_org, test_user):
|
||||
"""Test successful cancellation of billing session."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
session_id = 'test_cancel_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
with patch('server.routes.billing.a_session_maker', async_session_maker):
|
||||
response = await cancel_callback(session_id, mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
@@ -509,16 +541,18 @@ async def test_cancel_callback_success():
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'cancelled'
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
# Verify database update
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'cancelled'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Tests for the GitlabCallbackProcessor.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.gitlab.gitlab_view import GitlabIssueComment
|
||||
@@ -111,20 +111,15 @@ class TestGitlabCallbackProcessor:
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.conversation_manager'
|
||||
)
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
|
||||
)
|
||||
async def test_call_with_send_summary_instruction(
|
||||
self,
|
||||
mock_session_maker,
|
||||
mock_conversation_manager,
|
||||
mock_get_summary_instruction,
|
||||
async_session_maker,
|
||||
gitlab_callback_processor,
|
||||
):
|
||||
"""Test the __call__ method when send_summary_instruction is True."""
|
||||
# Setup mocks
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock()
|
||||
mock_get_summary_instruction.return_value = (
|
||||
"I'm a man of few words. Any questions?"
|
||||
@@ -142,15 +137,17 @@ class TestGitlabCallbackProcessor:
|
||||
)
|
||||
|
||||
# Call the processor
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
with patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
|
||||
async_session_maker,
|
||||
):
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
|
||||
# Verify that send_event_to_conversation was called
|
||||
mock_conversation_manager.send_event_to_conversation.assert_called_once()
|
||||
|
||||
# Verify that the processor state was updated
|
||||
assert gitlab_callback_processor.send_summary_instruction is False
|
||||
mock_session.merge.assert_called_once_with(callback)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
@@ -162,21 +159,16 @@ class TestGitlabCallbackProcessor:
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.asyncio.create_task'
|
||||
)
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
|
||||
)
|
||||
async def test_call_with_extract_summary(
|
||||
self,
|
||||
mock_session_maker,
|
||||
mock_create_task,
|
||||
mock_extract_summary,
|
||||
mock_conversation_manager,
|
||||
async_session_maker,
|
||||
gitlab_callback_processor,
|
||||
):
|
||||
"""Test the __call__ method when send_summary_instruction is False."""
|
||||
# Setup mocks
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_extract_summary.return_value = 'Test summary'
|
||||
# Ensure we don't leak an un-awaited coroutine when create_task is mocked
|
||||
mock_create_task.side_effect = lambda coro: (coro.close(), None)[1]
|
||||
@@ -196,20 +188,22 @@ class TestGitlabCallbackProcessor:
|
||||
)
|
||||
|
||||
# Call the processor
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
with patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
|
||||
async_session_maker,
|
||||
):
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
|
||||
# Verify that extract_summary_from_conversation_manager was called
|
||||
mock_extract_summary.assert_called_once_with(
|
||||
mock_conversation_manager, 'conv123'
|
||||
)
|
||||
|
||||
# Verify that create_task was called to send the message
|
||||
mock_create_task.assert_called_once()
|
||||
# Verify that create_task was called at least once to send the message
|
||||
assert mock_create_task.call_count >= 1
|
||||
|
||||
# Verify that the callback status was updated
|
||||
assert callback.status == CallbackStatus.COMPLETED
|
||||
mock_session.merge.assert_called_once_with(callback)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_with_non_terminal_state(self, gitlab_callback_processor):
|
||||
|
||||
@@ -12,35 +12,20 @@ from integrations.stripe_service import (
|
||||
find_customer_id_by_user_id,
|
||||
find_or_create_customer_by_user_id,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
# Create all tables using the unified Base
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_org_and_user(session_maker):
|
||||
def add_test_org_and_user(session_maker):
|
||||
"""Create a test org and user for use in tests."""
|
||||
test_user_id = uuid.uuid4()
|
||||
test_org_id = uuid.uuid4()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
with session_maker() as session:
|
||||
# Create role first
|
||||
role = Role(name='test-role', rank=1)
|
||||
@@ -72,15 +57,17 @@ def test_org_and_user(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
session_maker, test_org_and_user
|
||||
async_session_maker, session_maker_with_minimal_fixtures
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id checks the database first"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
# Add test org and user to the db
|
||||
test_user_id, test_org_id = add_test_org_and_user(
|
||||
session_maker_with_minimal_fixtures
|
||||
)
|
||||
|
||||
# Set up the mock for the database query result
|
||||
with session_maker() as session:
|
||||
# Create stripe customer
|
||||
# Create stripe customer in the db
|
||||
async with async_session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=str(test_user_id),
|
||||
@@ -88,15 +75,15 @@ async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
stripe_customer_id='cus_test123',
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
@@ -114,11 +101,14 @@ async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
session_maker, test_org_and_user
|
||||
async_session_maker, session_maker_with_minimal_fixtures
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id falls back to Stripe if not found in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
# Add test org and user to the db
|
||||
test_user_id, test_org_id = add_test_org_and_user(
|
||||
session_maker_with_minimal_fixtures
|
||||
)
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
mock_customer = stripe.Customer(id='cus_test123')
|
||||
@@ -129,8 +119,8 @@ async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
@@ -151,10 +141,15 @@ async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user):
|
||||
async def test_create_customer_stores_id_in_db(
|
||||
async_session_maker, session_maker_with_minimal_fixtures
|
||||
):
|
||||
"""Test that create_customer stores the customer ID in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
# Add test org and user to the db
|
||||
test_user_id, test_org_id = add_test_org_and_user(
|
||||
session_maker_with_minimal_fixtures
|
||||
)
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async and create_async
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[]))
|
||||
@@ -166,14 +161,20 @@ async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user)
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('stripe.Customer.create_async', mock_create_async),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
patch(
|
||||
'integrations.stripe_service.find_customer_id_by_org_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find_customer,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
# Mock find_customer_id_by_org_id to return None (force creation path)
|
||||
mock_find_customer.return_value = None
|
||||
|
||||
# Call the function
|
||||
result = await find_or_create_customer_by_user_id(str(test_user_id))
|
||||
@@ -182,8 +183,15 @@ async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user)
|
||||
assert result == {'customer_id': 'cus_test123', 'org_id': str(test_org_id)}
|
||||
|
||||
# Verify that the stripe customer was stored in the db
|
||||
with session_maker() as session:
|
||||
customer = session.query(StripeCustomer).first()
|
||||
async with async_session_maker() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(StripeCustomer).where(
|
||||
StripeCustomer.keycloak_user_id == str(test_user_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
customer = result.scalar_one_or_none()
|
||||
assert customer is not None
|
||||
assert customer.id > 0
|
||||
assert customer.keycloak_user_id == str(test_user_id)
|
||||
assert customer.org_id == test_org_id
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user