fix: disable pro subscription upgrade on LLM page for self-hosted installs (#11479)

This commit is contained in:
Hiep Le 2025-10-23 01:11:04 +07:00 committed by GitHub
parent 523b40dbfc
commit 134c122026
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 130 additions and 27 deletions

View File

@ -31,6 +31,37 @@ stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
def is_all_hands_saas_environment(request: Request) -> bool:
"""Check if the current domain is an All Hands SaaS environment.
Args:
request: FastAPI Request object
Returns:
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
"""
hostname = request.url.hostname or ''
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
def validate_saas_environment(request: Request) -> None:
"""Validate that the request is coming from an All Hands SaaS environment.
Args:
request: FastAPI Request object
Raises:
HTTPException: If the request is not from an All Hands SaaS environment
"""
if not is_all_hands_saas_environment(request):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Checkout sessions are only available for All Hands SaaS environments',
)
class BillingSessionType(Enum):
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
@ -196,6 +227,8 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon
async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
@ -214,6 +247,8 @@ async def create_checkout_session(
request: Request,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
@ -268,6 +303,8 @@ async def create_subscription_checkout_session(
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
# Prevent duplicate subscriptions for the same user
with session_maker() as session:
now = datetime.now(UTC)
@ -343,6 +380,8 @@ async def create_subscription_checkout_session_via_get(
user_id: str = Depends(get_user_id),
) -> RedirectResponse:
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
validate_saas_environment(request)
response = await create_subscription_checkout_session(
request, billing_session_type, user_id
)

View File

@ -36,6 +36,46 @@ def session_maker(engine):
return sessionmaker(bind=engine)
@pytest.fixture
def mock_request():
"""Create a mock request object with proper URL structure for testing."""
return Request(
scope={
'type': 'http',
'path': '/api/billing/test',
'server': ('test.com', 80),
}
)
@pytest.fixture
def mock_checkout_request():
"""Create a mock request object for checkout session tests."""
request = Request(
scope={
'type': 'http',
'path': '/api/billing/create-checkout-session',
'server': ('test.com', 80),
}
)
request._base_url = URL('http://test.com/')
return request
@pytest.fixture
def mock_subscription_request():
"""Create a mock request object for subscription checkout session tests."""
request = Request(
scope={
'type': 'http',
'path': '/api/billing/subscription-checkout-session',
'server': ('test.com', 80),
}
)
request._base_url = URL('http://test.com/')
return request
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
@ -90,14 +130,10 @@ async def test_get_credits_success():
@pytest.mark.asyncio
async def test_create_checkout_session_stripe_error(session_maker):
async def test_create_checkout_session_stripe_error(
session_maker, mock_checkout_request
):
"""Test handling of Stripe API errors."""
mock_request = Request(
scope={
'type': 'http',
}
)
mock_request._base_url = URL('http://test.com/')
mock_customer = stripe.Customer(
id='mock-customer', metadata={'user_id': 'mock-user'}
@ -118,17 +154,16 @@ async def test_create_checkout_session_stripe_error(session_maker):
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
):
await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
)
@pytest.mark.asyncio
async def test_create_checkout_session_success(session_maker):
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
"""Test successful creation of checkout session."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_session = MagicMock()
mock_session.url = 'https://checkout.stripe.com/test-session'
@ -152,12 +187,13 @@ async def test_create_checkout_session_success(session_maker):
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
result = await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
)
assert isinstance(result, CreateBillingSessionResponse)
@ -590,7 +626,9 @@ async def test_cancel_subscription_stripe_error():
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_duplicate_prevention():
async def test_create_subscription_checkout_session_duplicate_prevention(
mock_subscription_request,
):
"""Test that creating a subscription when user already has active subscription raises error."""
from datetime import UTC, datetime
@ -609,11 +647,9 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
cancelled_at=None,
)
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return existing active subscription
mock_session = MagicMock()
@ -623,7 +659,7 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await create_subscription_checkout_session(
mock_request, user_id='test_user'
mock_subscription_request, user_id='test_user'
)
assert exc_info.value.status_code == 400
@ -634,10 +670,10 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_allows_after_cancellation():
async def test_create_subscription_checkout_session_allows_after_cancellation(
mock_subscription_request,
):
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
@ -657,6 +693,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
mock_session = MagicMock()
@ -665,7 +702,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
# Should succeed
result = await create_subscription_checkout_session(
mock_request, user_id='test_user'
mock_subscription_request, user_id='test_user'
)
assert isinstance(result, CreateBillingSessionResponse)
@ -673,10 +710,10 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_success_no_existing():
async def test_create_subscription_checkout_session_success_no_existing(
mock_subscription_request,
):
"""Test successful subscription creation when no existing subscription."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
@ -696,6 +733,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return no existing subscription
mock_session = MagicMock()
@ -704,7 +742,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
# Should succeed
result = await create_subscription_checkout_session(
mock_request, user_id='test_user'
mock_subscription_request, user_id='test_user'
)
assert isinstance(result, CreateBillingSessionResponse)

View File

@ -25,6 +25,12 @@ vi.mock("#/hooks/query/use-is-authed", () => ({
useIsAuthed: () => mockUseIsAuthed(),
}));
// Mock useIsAllHandsSaaSEnvironment hook
const mockUseIsAllHandsSaaSEnvironment = vi.fn();
vi.mock("#/hooks/use-is-all-hands-saas-environment", () => ({
useIsAllHandsSaaSEnvironment: () => mockUseIsAllHandsSaaSEnvironment(),
}));
const renderLlmSettingsScreen = () =>
render(<LlmSettingsScreen />, {
wrapper: ({ children }) => (
@ -48,6 +54,9 @@ beforeEach(() => {
// Default mock for useIsAuthed - returns authenticated by default
mockUseIsAuthed.mockReturnValue({ data: true, isLoading: false });
// Default mock for useIsAllHandsSaaSEnvironment - returns true for SaaS environment
mockUseIsAllHandsSaaSEnvironment.mockReturnValue(true);
});
describe("Content", () => {
@ -104,7 +113,6 @@ describe("Content", () => {
expect(screen.getByTestId("set-indicator")).toBeInTheDocument();
});
});
});
describe("Advanced form", () => {

View File

@ -0,0 +1,13 @@
import { useMemo } from "react";
/**
* Hook to check if the current domain is an All Hands SaaS environment
* @returns True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
*/
export const useIsAllHandsSaaSEnvironment = (): boolean =>
useMemo(() => {
const { hostname } = window.location;
return (
hostname.endsWith("all-hands.dev") || hostname.endsWith("openhands.dev")
);
}, []);

View File

@ -33,6 +33,7 @@ import { UpgradeBannerWithBackdrop } from "#/components/features/settings/upgrad
import { useCreateSubscriptionCheckoutSession } from "#/hooks/mutation/stripe/use-create-subscription-checkout-session";
import { useIsAuthed } from "#/hooks/query/use-is-authed";
import { cn } from "#/utils/utils";
import { useIsAllHandsSaaSEnvironment } from "#/hooks/use-is-all-hands-saas-environment";
interface OpenHandsApiKeyHelpProps {
testId: string;
@ -78,6 +79,7 @@ function LlmSettingsScreen() {
const { data: isAuthed } = useIsAuthed();
const { mutate: createSubscriptionCheckoutSession } =
useCreateSubscriptionCheckoutSession();
const isAllHandsSaaSEnvironment = useIsAllHandsSaaSEnvironment();
const [view, setView] = React.useState<"basic" | "advanced">("basic");
@ -441,8 +443,11 @@ function LlmSettingsScreen() {
if (!settings || isFetching) return <LlmSettingsInputsSkeleton />;
// Show upgrade banner and disable form in SaaS mode when user doesn't have an active subscription
// Exclude self-hosted enterprise customers (those not on all-hands.dev domains)
const shouldShowUpgradeBanner =
config?.APP_MODE === "saas" && !subscriptionAccess;
config?.APP_MODE === "saas" &&
!subscriptionAccess &&
isAllHandsSaaSEnvironment;
const formAction = (formData: FormData) => {
// Prevent form submission for unsubscribed SaaS users