Fix org billing (#12562)

This commit is contained in:
Tim O'Farrell
2026-01-23 08:42:50 -07:00
committed by GitHub
parent 9d0a19cf8f
commit 373ade8554
3 changed files with 49 additions and 53 deletions

View File

@@ -13,46 +13,33 @@ from server.constants import (
STRIPE_API_KEY,
)
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.server.user_auth import get_user_id
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
def is_all_hands_saas_environment(request: Request) -> bool:
"""Check if the current domain is an All Hands SaaS environment.
Args:
request: FastAPI Request object
Returns:
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
async def validate_billing_enabled() -> None:
"""
hostname = request.url.hostname or ''
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
def validate_saas_environment(request: Request) -> None:
"""Validate that the request is coming from an All Hands SaaS environment.
Args:
request: FastAPI Request object
Raises:
HTTPException: If the request is not from an All Hands SaaS environment
Validate that the billing feature flag is enabled
"""
if not is_all_hands_saas_environment(request):
config = get_global_config()
web_client_config = await config.web_client.get_web_client_config()
if not web_client_config.feature_flags.enable_billing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Checkout sessions are only available for All Hands SaaS environments',
detail=(
'Billing is disabled in this environment. '
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
),
)
@@ -154,14 +141,15 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
await validate_billing_enabled()
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
base_url = _get_base_url(request)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
success_url=f'{base_url}?free_credits=success',
cancel_url=f'{base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -173,8 +161,8 @@ async def create_checkout_session(
request: Request,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
await validate_billing_enabled()
base_url = _get_base_url(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
@@ -197,8 +185,8 @@ async def create_checkout_session(
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
)
logger.info(
'created_stripe_checkout_session',
@@ -289,7 +277,7 @@ async def success_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=success', status_code=302
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
)
@@ -317,5 +305,13 @@ async def cancel_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
)
def _get_base_url(request: Request) -> URL:
# Never send any part of the credit card process over a non secure connection
base_url = request.base_url
if base_url.hostname != 'localhost':
base_url = base_url.replace(scheme='https')
return base_url

View File

@@ -14,9 +14,9 @@ from server.constants import (
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy import text
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.org import Org
from storage.org_member import OrgMember
@@ -372,13 +372,13 @@ class UserStore:
This is the preferred method when calling from an async context as it
avoids event loop conflicts that can occur with the sync version.
"""
with session_maker() as session:
user = (
session.query(User)
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
@@ -392,16 +392,16 @@ class UserStore:
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
user_settings = (
user_settings = await (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,

View File

@@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
@@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
@@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
mode='payment',
payment_method_types=['card'],
saved_payment_method_options={'payment_method_save': 'enabled'},
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
)
# Verify database session creation
@@ -331,7 +331,7 @@ async def test_success_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=success'
== 'https://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
@@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify no database updates occurred
@@ -429,7 +429,7 @@ async def test_cancel_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify database updates
@@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success():
AsyncMock(return_value=mock_customer_info),
),
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
result = await create_customer_setup_session(mock_request, 'mock_user')
@@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='http://test.com/?free_credits=success',
cancel_url='http://test.com/',
success_url='https://test.com/?free_credits=success',
cancel_url='https://test.com/',
)