diff --git a/enterprise/migrations/versions/093_add_pending_free_credits.py b/enterprise/migrations/versions/093_add_pending_free_credits.py new file mode 100644 index 0000000000..10bd8fb078 --- /dev/null +++ b/enterprise/migrations/versions/093_add_pending_free_credits.py @@ -0,0 +1,37 @@ +"""Add pending_free_credits flag to org table. + +Revision ID: 093 +Revises: 092 +Create Date: 2025-02-17 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '093' +down_revision: Union[str, None] = '092' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add pending_free_credits column to org table with default false. + # New orgs will have this set to TRUE at creation time. + # Existing orgs default to FALSE (not eligible - they already got $10 at signup). + op.add_column( + 'org', + sa.Column( + 'pending_free_credits', + sa.Boolean, + nullable=False, + server_default=sa.text('false'), + ), + ) + + +def downgrade() -> None: + op.drop_column('org', 'pending_free_credits') diff --git a/enterprise/server/constants.py b/enterprise/server/constants.py index 341429924e..d89593a09d 100644 --- a/enterprise/server/constants.py +++ b/enterprise/server/constants.py @@ -61,7 +61,8 @@ SUBSCRIPTION_PRICE_DATA = { }, } -DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10')) +FREE_CREDIT_THRESHOLD = float(os.environ.get('FREE_CREDIT_THRESHOLD', '10')) +FREE_CREDIT_AMOUNT = float(os.environ.get('FREE_CREDIT_AMOUNT', '10')) STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None) REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true') diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 438887571d..61231a5e01 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -10,6 +10,8 @@ from fastapi.responses import RedirectResponse from integrations import stripe_service from pydantic import BaseModel from server.constants import ( + FREE_CREDIT_AMOUNT, + FREE_CREDIT_THRESHOLD, STRIPE_API_KEY, ) from server.logger import logger @@ -254,15 +256,33 @@ async def success_callback(session_id: str, request: Request): max_budget = (user_team_info.get('litellm_budget_table') or {}).get( 'max_budget', 0 ) + + org = session.query(Org).filter(Org.id == user.current_org_id).first() new_max_budget = max_budget + add_credits + # Grant free credits if: + # 1. The org has pending free credits (new org, eligible) + # 2. The budget after this purchase meets the threshold + should_grant_free_credits = ( + org and org.pending_free_credits and new_max_budget >= FREE_CREDIT_THRESHOLD + ) + if should_grant_free_credits: + new_max_budget += FREE_CREDIT_AMOUNT + org.pending_free_credits = False + logger.info( + 'free_credits_granted', + extra={ + 'user_id': billing_session.user_id, + 'org_id': str(user.current_org_id), + 'free_credit_amount': FREE_CREDIT_AMOUNT, + }, + ) + await LiteLlmManager.update_team_and_users_budget( str(user.current_org_id), new_max_budget ) # Enable BYOR export for the org now that they've purchased credits - # Update within the same session to avoid nested session issues - org = session.query(Org).filter(Org.id == user.current_org_id).first() if org: org.byor_export_enabled = True @@ -279,6 +299,7 @@ async def success_callback(session_id: str, request: Request): 'org_id': str(user.current_org_id), 'checkout_session_id': billing_session.id, 'stripe_customer_id': stripe_session.customer, + 'free_credits_granted': should_grant_free_credits, }, ) session.commit() diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index e4380e29f3..e636e39759 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -10,7 +10,6 @@ import httpx from pydantic import SecretStr from server.auth.token_manager import TokenManager from server.constants import ( - DEFAULT_INITIAL_BUDGET, LITE_LLM_API_KEY, LITE_LLM_API_URL, LITE_LLM_TEAM_ID, @@ -72,9 +71,8 @@ class LiteLlmManager: 'x-goog-api-key': LITE_LLM_API_KEY, } ) as client: - await LiteLlmManager._create_team( - client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET - ) + # New users start with $0 budget - they must purchase credits + await LiteLlmManager._create_team(client, keycloak_user_id, org_id, 0) if create_user: await LiteLlmManager._create_user( @@ -82,7 +80,7 @@ class LiteLlmManager: ) await LiteLlmManager._add_user_to_team( - client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET + client, keycloak_user_id, org_id, 0 ) key = await LiteLlmManager._generate_key( diff --git a/enterprise/storage/org.py b/enterprise/storage/org.py index 45d64a0b83..6e9884d655 100644 --- a/enterprise/storage/org.py +++ b/enterprise/storage/org.py @@ -47,6 +47,7 @@ class Org(Base): # type: ignore conversation_expiration = Column(Integer, nullable=True) condenser_max_size = Column(Integer, nullable=True) byor_export_enabled = Column(Boolean, nullable=False, default=False) + pending_free_credits = Column(Boolean, nullable=False, default=False) # Relationships org_members = relationship('OrgMember', back_populates='org') diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 144d636a83..b358c78d12 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -112,6 +112,7 @@ class OrgService: contact_email=contact_email, org_version=ORG_SETTINGS_VERSION, default_llm_model=get_default_litellm_model(), + pending_free_credits=True, ) @staticmethod diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 651e98176d..ff49c17da8 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -59,6 +59,7 @@ class UserStore: or user_info.get('preferred_username', ''), contact_email=user_info['email'], v1_enabled=True, + pending_free_credits=True, ) session.add(org) @@ -195,6 +196,7 @@ class UserStore: or user_info.get('username', ''), contact_email=user_info['email'], byor_export_enabled=has_completed_billing, + pending_free_credits=not has_completed_billing, ) session.add(org) diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index 46ca96f083..902af81e15 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -291,7 +291,7 @@ async def test_success_callback_stripe_incomplete(): @pytest.mark.asyncio async def test_success_callback_success(): - """Test successful payment completion and credit update.""" + """Test successful payment completion and credit update (bonus already granted).""" mock_request = Request(scope={'type': 'http'}) mock_request._base_url = URL('http://test.com/') @@ -300,6 +300,7 @@ async def test_success_callback_success(): mock_billing_session.user_id = 'mock_user' mock_org = MagicMock() + mock_org.pending_free_credits = False # Not eligible (old org or already granted) with ( patch('server.routes.billing.session_maker') as mock_session_maker, @@ -346,10 +347,10 @@ async def test_success_callback_success(): == 'https://test.com/settings/billing?checkout=success' ) - # Verify LiteLLM API calls + # Verify LiteLLM API calls - no bonus since not eligible mock_update_budget.assert_called_once_with( 'mock_org_id', - 125.0, # 100 + (25.00 from Stripe) + 125.0, # 100 + 25.00 (no bonus) ) # Verify BYOR export is enabled for the org (updated in same session) @@ -362,6 +363,92 @@ async def test_success_callback_success(): mock_db_session.commit.assert_called_once() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'initial_budget,purchase_cents,pending_credits,expected_final_budget,expected_pending_after', + [ + # New user buys $10 -> gets free credits, pending becomes False + (0, 1000, True, 20.0, False), + # New user buys $5 -> below threshold, no free credits yet, pending stays True + (0, 500, True, 5.0, True), + # User with $5 buys $5 more -> reaches threshold, gets free credits + (5.0, 500, True, 20.0, False), + # User with $5 buys $3 -> below threshold, no free credits yet + (5.0, 300, True, 8.0, True), + # Old user (not pending) buys $25 -> no free credits, stays False + (20.0, 2500, False, 45.0, False), + ], + ids=[ + 'new_user_buys_10_gets_free_credits', + 'new_user_buys_5_below_threshold', + 'user_with_5_buys_5_reaches_threshold', + 'user_with_5_buys_3_below_threshold', + 'old_user_not_eligible', + ], +) +async def test_success_callback_free_credits( + initial_budget, + purchase_cents, + pending_credits, + expected_final_budget, + expected_pending_after, +): + """Test free credits are granted only when pending and threshold is met.""" + mock_request = Request(scope={'type': 'http'}) + mock_request._base_url = URL('http://test.com/') + + mock_billing_session = MagicMock() + mock_billing_session.status = 'in_progress' + mock_billing_session.user_id = 'mock_user' + + mock_org = MagicMock() + mock_org.pending_free_credits = pending_credits + + with ( + patch('server.routes.billing.session_maker') as mock_session_maker, + patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, + patch( + 'storage.user_store.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + return_value=MagicMock(current_org_id='mock_org_id'), + ), + patch( + 'storage.lite_llm_manager.LiteLlmManager.get_user_team_info', + return_value={ + 'spend': 0, + 'litellm_budget_table': {'max_budget': initial_budget}, + }, + ), + patch( + 'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget' + ) as mock_update_budget, + patch('server.routes.billing.FREE_CREDIT_THRESHOLD', 10.0), + patch('server.routes.billing.FREE_CREDIT_AMOUNT', 10.0), + ): + mock_db_session = MagicMock() + mock_query_chain_billing = MagicMock() + mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session + mock_query_chain_org = MagicMock() + mock_query_chain_org.filter.return_value.first.return_value = mock_org + mock_db_session.query.side_effect = [ + mock_query_chain_billing, + mock_query_chain_org, + ] + mock_session_maker.return_value.__enter__.return_value = mock_db_session + + mock_stripe_retrieve.return_value = MagicMock( + status='complete', + amount_subtotal=purchase_cents, + customer='mock_customer_id', + ) + + response = await success_callback('test_session_id', mock_request) + + assert response.status_code == 302 + mock_update_budget.assert_called_once_with('mock_org_id', expected_final_budget) + assert mock_org.pending_free_credits is expected_pending_after + + @pytest.mark.asyncio async def test_success_callback_lite_llm_error(): """Test handling of LiteLLM API errors during success callback.""" @@ -402,6 +489,73 @@ async def test_success_callback_lite_llm_error(): mock_db_session.commit.assert_not_called() +@pytest.mark.asyncio +async def test_success_callback_lite_llm_update_budget_error_rollback(): + """Test that pending_free_credits change is not committed when update_team_and_users_budget fails. + + This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception + after pending_free_credits has been set to False, the database transaction rolls back and + pending_free_credits remains True. + """ + mock_request = Request(scope={'type': 'http'}) + mock_request._base_url = URL('http://test.com/') + + mock_billing_session = MagicMock() + mock_billing_session.status = 'in_progress' + mock_billing_session.user_id = 'mock_user' + + mock_org = MagicMock() + mock_org.pending_free_credits = True + + with ( + patch('server.routes.billing.session_maker') as mock_session_maker, + patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve, + patch( + 'storage.user_store.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + return_value=MagicMock(current_org_id='mock_org_id'), + ), + patch( + 'storage.lite_llm_manager.LiteLlmManager.get_user_team_info', + return_value={ + 'spend': 0, + 'litellm_budget_table': {'max_budget': 0}, + }, + ), + patch( + 'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget', + side_effect=Exception('LiteLLM API Error'), + ), + patch('server.routes.billing.FREE_CREDIT_THRESHOLD', 10.0), + patch('server.routes.billing.FREE_CREDIT_AMOUNT', 10.0), + ): + mock_db_session = MagicMock() + mock_query_chain_billing = MagicMock() + mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session + mock_query_chain_org = MagicMock() + mock_query_chain_org.filter.return_value.first.return_value = mock_org + mock_db_session.query.side_effect = [ + mock_query_chain_billing, + mock_query_chain_org, + ] + mock_session_maker.return_value.__enter__.return_value = mock_db_session + + # Purchase $10 to reach threshold + mock_stripe_retrieve.return_value = MagicMock( + status='complete', + amount_subtotal=1000, # $10 + customer='mock_customer_id', + ) + + with pytest.raises(Exception, match='LiteLLM API Error'): + await success_callback('test_session_id', mock_request) + + # Verify no database commit occurred - the transaction should roll back + assert mock_billing_session.status == 'in_progress' + mock_db_session.merge.assert_not_called() + mock_db_session.commit.assert_not_called() + + @pytest.mark.asyncio async def test_cancel_callback_session_not_found(): """Test cancel callback when billing session is not found."""