diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index f99d90f710..438887571d 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -17,7 +17,7 @@ from starlette.datastructures import URL from storage.billing_session import BillingSession from storage.database import session_maker from storage.lite_llm_manager import LiteLlmManager -from storage.org_store import OrgStore +from storage.org import Org from storage.subscription_access import SubscriptionAccess from storage.user_store import UserStore @@ -261,7 +261,10 @@ async def success_callback(session_id: str, request: Request): ) # Enable BYOR export for the org now that they've purchased credits - OrgStore.update_org(user.current_org_id, {'byor_export_enabled': True}) + # Update within the same session to avoid nested session issues + org = session.query(Org).filter(Org.id == user.current_org_id).first() + if org: + org.byor_export_enabled = True # Store transaction status billing_session.status = 'completed' diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index ca54cd788b..46ca96f083 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -299,6 +299,8 @@ async def test_success_callback_success(): mock_billing_session.status = 'in_progress' mock_billing_session.user_id = 'mock_user' + mock_org = MagicMock() + with ( patch('server.routes.billing.session_maker') as mock_session_maker, patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, @@ -317,10 +319,19 @@ async def test_success_callback_success(): patch( 'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget' ) as mock_update_budget, - patch('server.routes.billing.OrgStore.update_org') as mock_update_org, ): mock_db_session = MagicMock() + # First query: BillingSession (query().filter().filter().first()) mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session + # Second query: Org (query().filter().first()) - use side_effect for different return chains + mock_query_chain_billing = MagicMock() + mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session + mock_query_chain_org = MagicMock() + mock_query_chain_org.filter.return_value.first.return_value = mock_org + mock_db_session.query.side_effect = [ + mock_query_chain_billing, + mock_query_chain_org, + ] mock_session_maker.return_value.__enter__.return_value = mock_db_session mock_stripe_retrieve.return_value = MagicMock( @@ -341,11 +352,8 @@ async def test_success_callback_success(): 125.0, # 100 + (25.00 from Stripe) ) - # Verify BYOR export is enabled for the org - mock_update_org.assert_called_once_with( - 'mock_org_id', - {'byor_export_enabled': True}, - ) + # Verify BYOR export is enabled for the org (updated in same session) + assert mock_org.byor_export_enabled is True # Verify database updates assert mock_billing_session.status == 'completed'