diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 0ec621f0d5..19e9dfa17c 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -13,46 +13,33 @@ from server.constants import ( STRIPE_API_KEY, ) from server.logger import logger +from starlette.datastructures import URL from storage.billing_session import BillingSession from storage.database import session_maker from storage.lite_llm_manager import LiteLlmManager from storage.subscription_access import SubscriptionAccess from storage.user_store import UserStore +from openhands.app_server.config import get_global_config from openhands.server.user_auth import get_user_id stripe.api_key = STRIPE_API_KEY billing_router = APIRouter(prefix='/api/billing') -# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this -# and members should comment out the "validate_saas_environment" function if they are developing and testing locally. -def is_all_hands_saas_environment(request: Request) -> bool: - """Check if the current domain is an All Hands SaaS environment. - - Args: - request: FastAPI Request object - - Returns: - True if the current domain contains "all-hands.dev" or "openhands.dev" postfix +async def validate_billing_enabled() -> None: """ - hostname = request.url.hostname or '' - return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev') - - -def validate_saas_environment(request: Request) -> None: - """Validate that the request is coming from an All Hands SaaS environment. - - Args: - request: FastAPI Request object - - Raises: - HTTPException: If the request is not from an All Hands SaaS environment + Validate that the billing feature flag is enabled """ - if not is_all_hands_saas_environment(request): + config = get_global_config() + web_client_config = await config.web_client.get_web_client_config() + if not web_client_config.feature_flags.enable_billing: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail='Checkout sessions are only available for All Hands SaaS environments', + detail=( + 'Billing is disabled in this environment. ' + 'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.' + ), ) @@ -154,14 +141,15 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool: async def create_customer_setup_session( request: Request, user_id: str = Depends(get_user_id) ) -> CreateBillingSessionResponse: - validate_saas_environment(request) + await validate_billing_enabled() customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id) + base_url = _get_base_url(request) checkout_session = await stripe.checkout.Session.create_async( customer=customer_info['customer_id'], mode='setup', payment_method_types=['card'], - success_url=f'{request.base_url}?free_credits=success', - cancel_url=f'{request.base_url}', + success_url=f'{base_url}?free_credits=success', + cancel_url=f'{base_url}', ) return CreateBillingSessionResponse(redirect_url=checkout_session.url) @@ -173,8 +161,8 @@ async def create_checkout_session( request: Request, user_id: str = Depends(get_user_id), ) -> CreateBillingSessionResponse: - validate_saas_environment(request) - + await validate_billing_enabled() + base_url = _get_base_url(request) customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id) checkout_session = await stripe.checkout.Session.create_async( customer=customer_info['customer_id'], @@ -197,8 +185,8 @@ async def create_checkout_session( saved_payment_method_options={ 'payment_method_save': 'enabled', }, - success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}', - cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}', + success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}', + cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}', ) logger.info( 'created_stripe_checkout_session', @@ -289,7 +277,7 @@ async def success_callback(session_id: str, request: Request): session.commit() return RedirectResponse( - f'{request.base_url}settings/billing?checkout=success', status_code=302 + f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302 ) @@ -317,5 +305,13 @@ async def cancel_callback(session_id: str, request: Request): session.commit() return RedirectResponse( - f'{request.base_url}settings/billing?checkout=cancel', status_code=302 + f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302 ) + + +def _get_base_url(request: Request) -> URL: + # Never send any part of the credit card process over a non secure connection + base_url = request.base_url + if base_url.hostname != 'localhost': + base_url = base_url.replace(scheme='https') + return base_url diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index ca30612285..6f2adb459a 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -14,9 +14,9 @@ from server.constants import ( get_default_litellm_model, ) from server.logger import logger -from sqlalchemy import text +from sqlalchemy import select, text from sqlalchemy.orm import joinedload -from storage.database import session_maker +from storage.database import a_session_maker, session_maker from storage.encrypt_utils import decrypt_legacy_model from storage.org import Org from storage.org_member import OrgMember @@ -372,13 +372,13 @@ class UserStore: This is the preferred method when calling from an async context as it avoids event loop conflicts that can occur with the sync version. """ - with session_maker() as session: - user = ( - session.query(User) + async with a_session_maker() as session: + result = await session.execute( + select(User) .options(joinedload(User.org_members)) .filter(User.id == uuid.UUID(user_id)) - .first() ) + user = result.scalars().first() if user: return user @@ -392,16 +392,16 @@ class UserStore: await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS) # Check for user again as migration could have happened while trying to get the lock. - user = ( - session.query(User) + result = await session.execute( + select(User) .options(joinedload(User.org_members)) .filter(User.id == uuid.UUID(user_id)) - .first() ) + user = result.scalars().first() if user: return user - user_settings = ( + user_settings = await ( session.query(UserSettings) .filter( UserSettings.keycloak_user_id == user_id, diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index 58ef67d1f0..c7259e9a05 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error( 'server.auth.token_manager.TokenManager.get_user_info_from_user_id', AsyncMock(return_value={'email': 'testy@tester.com'}), ), - patch('server.routes.billing.validate_saas_environment'), + patch('server.routes.billing.validate_billing_enabled'), ): await create_checkout_session( CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user' @@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ 'server.auth.token_manager.TokenManager.get_user_info_from_user_id', AsyncMock(return_value={'email': 'testy@tester.com'}), ), - patch('server.routes.billing.validate_saas_environment'), + patch('server.routes.billing.validate_billing_enabled'), ): mock_db_session = MagicMock() mock_session_maker.return_value.__enter__.return_value = mock_db_session @@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ mode='payment', payment_method_types=['card'], saved_payment_method_options={'payment_method_save': 'enabled'}, - success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}', - cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}', + success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}', + cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}', ) # Verify database session creation @@ -331,7 +331,7 @@ async def test_success_callback_success(): assert response.status_code == 302 assert ( response.headers['location'] - == 'http://test.com/settings/billing?checkout=success' + == 'https://test.com/settings/billing?checkout=success' ) # Verify LiteLLM API calls @@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found(): assert response.status_code == 302 assert ( response.headers['location'] - == 'http://test.com/settings/billing?checkout=cancel' + == 'https://test.com/settings/billing?checkout=cancel' ) # Verify no database updates occurred @@ -429,7 +429,7 @@ async def test_cancel_callback_success(): assert response.status_code == 302 assert ( response.headers['location'] - == 'http://test.com/settings/billing?checkout=cancel' + == 'https://test.com/settings/billing?checkout=cancel' ) # Verify database updates @@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success(): AsyncMock(return_value=mock_customer_info), ), patch('stripe.checkout.Session.create_async', mock_create), - patch('server.routes.billing.validate_saas_environment'), + patch('server.routes.billing.validate_billing_enabled'), ): result = await create_customer_setup_session(mock_request, 'mock_user') @@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success(): customer='mock-customer-id', mode='setup', payment_method_types=['card'], - success_url='http://test.com/?free_credits=success', - cancel_url='http://test.com/', + success_url='https://test.com/?free_credits=success', + cancel_url='https://test.com/', )