mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix: disable pro subscription upgrade on LLM page for self-hosted installs (#11479)
This commit is contained in:
parent
523b40dbfc
commit
134c122026
@ -31,6 +31,37 @@ stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
class BillingSessionType(Enum):
|
||||
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
|
||||
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
|
||||
@ -196,6 +227,8 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
@ -214,6 +247,8 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
@ -268,6 +303,8 @@ async def create_subscription_checkout_session(
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
@ -343,6 +380,8 @@ async def create_subscription_checkout_session_via_get(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
|
||||
@ -36,6 +36,46 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object with proper URL structure for testing."""
|
||||
return Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/test',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_checkout_request():
|
||||
"""Create a mock request object for checkout session tests."""
|
||||
request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-checkout-session',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subscription_request():
|
||||
"""Create a mock request object for subscription checkout session tests."""
|
||||
request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/subscription-checkout-session',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
@ -90,14 +130,10 @@ async def test_get_credits_success():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_stripe_error(session_maker):
|
||||
async def test_create_checkout_session_stripe_error(
|
||||
session_maker, mock_checkout_request
|
||||
):
|
||||
"""Test handling of Stripe API errors."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
@ -118,17 +154,16 @@ async def test_create_checkout_session_stripe_error(session_maker):
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_success(session_maker):
|
||||
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
|
||||
"""Test successful creation of checkout session."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
@ -152,12 +187,13 @@ async def test_create_checkout_session_success(session_maker):
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
result = await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
@ -590,7 +626,9 @@ async def test_cancel_subscription_stripe_error():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
@ -609,11 +647,9 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
@ -623,7 +659,7 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
@ -634,10 +670,10 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
@ -657,6 +693,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
@ -665,7 +702,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
@ -673,10 +710,10 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing():
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
@ -696,6 +733,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
@ -704,7 +742,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
|
||||
@ -25,6 +25,12 @@ vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => mockUseIsAuthed(),
|
||||
}));
|
||||
|
||||
// Mock useIsAllHandsSaaSEnvironment hook
|
||||
const mockUseIsAllHandsSaaSEnvironment = vi.fn();
|
||||
vi.mock("#/hooks/use-is-all-hands-saas-environment", () => ({
|
||||
useIsAllHandsSaaSEnvironment: () => mockUseIsAllHandsSaaSEnvironment(),
|
||||
}));
|
||||
|
||||
const renderLlmSettingsScreen = () =>
|
||||
render(<LlmSettingsScreen />, {
|
||||
wrapper: ({ children }) => (
|
||||
@ -48,6 +54,9 @@ beforeEach(() => {
|
||||
|
||||
// Default mock for useIsAuthed - returns authenticated by default
|
||||
mockUseIsAuthed.mockReturnValue({ data: true, isLoading: false });
|
||||
|
||||
// Default mock for useIsAllHandsSaaSEnvironment - returns true for SaaS environment
|
||||
mockUseIsAllHandsSaaSEnvironment.mockReturnValue(true);
|
||||
});
|
||||
|
||||
describe("Content", () => {
|
||||
@ -104,7 +113,6 @@ describe("Content", () => {
|
||||
expect(screen.getByTestId("set-indicator")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
describe("Advanced form", () => {
|
||||
|
||||
13
frontend/src/hooks/use-is-all-hands-saas-environment.ts
Normal file
13
frontend/src/hooks/use-is-all-hands-saas-environment.ts
Normal file
@ -0,0 +1,13 @@
|
||||
import { useMemo } from "react";
|
||||
|
||||
/**
|
||||
* Hook to check if the current domain is an All Hands SaaS environment
|
||||
* @returns True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
*/
|
||||
export const useIsAllHandsSaaSEnvironment = (): boolean =>
|
||||
useMemo(() => {
|
||||
const { hostname } = window.location;
|
||||
return (
|
||||
hostname.endsWith("all-hands.dev") || hostname.endsWith("openhands.dev")
|
||||
);
|
||||
}, []);
|
||||
@ -33,6 +33,7 @@ import { UpgradeBannerWithBackdrop } from "#/components/features/settings/upgrad
|
||||
import { useCreateSubscriptionCheckoutSession } from "#/hooks/mutation/stripe/use-create-subscription-checkout-session";
|
||||
import { useIsAuthed } from "#/hooks/query/use-is-authed";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { useIsAllHandsSaaSEnvironment } from "#/hooks/use-is-all-hands-saas-environment";
|
||||
|
||||
interface OpenHandsApiKeyHelpProps {
|
||||
testId: string;
|
||||
@ -78,6 +79,7 @@ function LlmSettingsScreen() {
|
||||
const { data: isAuthed } = useIsAuthed();
|
||||
const { mutate: createSubscriptionCheckoutSession } =
|
||||
useCreateSubscriptionCheckoutSession();
|
||||
const isAllHandsSaaSEnvironment = useIsAllHandsSaaSEnvironment();
|
||||
|
||||
const [view, setView] = React.useState<"basic" | "advanced">("basic");
|
||||
|
||||
@ -441,8 +443,11 @@ function LlmSettingsScreen() {
|
||||
if (!settings || isFetching) return <LlmSettingsInputsSkeleton />;
|
||||
|
||||
// Show upgrade banner and disable form in SaaS mode when user doesn't have an active subscription
|
||||
// Exclude self-hosted enterprise customers (those not on all-hands.dev domains)
|
||||
const shouldShowUpgradeBanner =
|
||||
config?.APP_MODE === "saas" && !subscriptionAccess;
|
||||
config?.APP_MODE === "saas" &&
|
||||
!subscriptionAccess &&
|
||||
isAllHandsSaaSEnvironment;
|
||||
|
||||
const formAction = (formData: FormData) => {
|
||||
// Prevent form submission for unsubscribed SaaS users
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user