# Mirror of https://github.com/OpenHands/OpenHands.git
# Synced 2026-03-22 13:47:19 +08:00 (523 lines, 19 KiB, Python)
import uuid
|
|
from decimal import Decimal
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import stripe
|
|
from fastapi import HTTPException, Request, status
|
|
from httpx import Response
|
|
from server.routes import billing
|
|
from server.routes.billing import (
|
|
CreateBillingSessionResponse,
|
|
CreateCheckoutSessionRequest,
|
|
GetCreditsResponse,
|
|
cancel_callback,
|
|
create_checkout_session,
|
|
create_customer_setup_session,
|
|
get_credits,
|
|
has_payment_method,
|
|
success_callback,
|
|
)
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from starlette.datastructures import URL
|
|
from storage.stripe_customer import Base as StripeCustomerBase
|
|
|
|
|
|
@pytest.fixture
def engine():
    """Provide an in-memory SQLite engine with the Stripe customer schema.

    The tables declared on ``StripeCustomerBase`` are created up front. The
    engine is disposed once the test finishes so its connection pool does not
    outlive the test that created it.
    """
    engine = create_engine('sqlite:///:memory:')
    StripeCustomerBase.metadata.create_all(engine)
    yield engine
    # Teardown: release the engine's pooled connections.
    engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def session_maker(engine):
    """Return a SQLAlchemy session factory bound to the ``engine`` fixture."""
    factory = sessionmaker(bind=engine)
    return factory
|
|
|
|
|
|
@pytest.fixture
def mock_request():
    """Create a mock request object with proper URL structure for testing.

    Starlette reads ``scope['headers']`` when it builds ``request.url`` and
    ``request.headers``. An empty header list is therefore included so those
    accessors do not raise ``KeyError``.
    """
    return Request(
        scope={
            'type': 'http',
            'path': '/api/billing/test',
            'server': ('test.com', 80),
            'headers': [],
        }
    )
|
|
|
|
|
|
@pytest.fixture
def mock_checkout_request():
    """Create a mock request object for checkout session tests.

    ``_base_url`` is pre-set, which Starlette caches for ``request.base_url``,
    so the route sees ``http://test.com/`` as its base. ``headers`` is included
    because Starlette reads ``scope['headers']`` when constructing URLs and
    header views.
    """
    request = Request(
        scope={
            'type': 'http',
            'path': '/api/billing/create-checkout-session',
            'server': ('test.com', 80),
            'headers': [],
        }
    )
    request._base_url = URL('http://test.com/')
    return request
|
|
|
|
|
|
@pytest.fixture
def mock_subscription_request():
    """Create a mock request object for subscription checkout session tests.

    ``_base_url`` is pre-set, which Starlette caches for ``request.base_url``,
    so the route sees ``http://test.com/`` as its base. ``headers`` is included
    because Starlette reads ``scope['headers']`` when constructing URLs and
    header views.
    """
    request = Request(
        scope={
            'type': 'http',
            'path': '/api/billing/subscription-checkout-session',
            'server': ('test.com', 80),
            'headers': [],
        }
    )
    request._base_url = URL('http://test.com/')
    return request
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
    """get_credits propagates an error raised while fetching LiteLLM team info."""
    stripe_key_patch = patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key')
    user_lookup_patch = patch(
        'storage.user_store.UserStore.get_user_by_id_async',
        new_callable=AsyncMock,
        return_value=MagicMock(current_org_id='mock_org_id'),
    )
    failing_team_info_patch = patch(
        'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
        side_effect=Exception('LiteLLM API Error'),
    )

    with stripe_key_patch, user_lookup_patch, failing_team_info_patch:
        with pytest.raises(Exception, match='LiteLLM API Error'):
            await get_credits('mock_user')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_credits_success():
    """Remaining credits equal the team's max budget minus its spend."""
    # Canned HTTP response served by the patched httpx.AsyncClient below.
    # NOTE(review): the team info itself comes from the patched
    # LiteLlmManager.get_user_team_info; the httpx stub looks like a safeguard
    # against real network calls. Confirm whether get_credits still uses httpx.
    stub_http_response = Response(
        status_code=200,
        json={
            'user_info': {
                'spend': 25.50,
                'litellm_budget_table': {'max_budget': 100.00},
            }
        },
        request=MagicMock(),
    )
    stub_http_client = AsyncMock()
    stub_http_client.__aenter__.return_value.get.return_value = stub_http_response

    team_info = {
        'spend': 25.50,
        'litellm_budget_table': {'max_budget': 100.00},
    }

    with (
        patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
        patch('httpx.AsyncClient', return_value=stub_http_client),
        patch(
            'storage.user_store.UserStore.get_user_by_id_async',
            new_callable=AsyncMock,
            return_value=MagicMock(current_org_id='mock_org_id'),
        ),
        patch(
            'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
            return_value=team_info,
        ),
    ):
        credits_response = await get_credits('mock_user')

    assert isinstance(credits_response, GetCreditsResponse)
    # 100.00 budget - 25.50 spent = 74.50 remaining
    assert credits_response.credits == Decimal('74.50')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_checkout_session_stripe_error(
    session_maker, mock_checkout_request
):
    """An error raised by Stripe's checkout API propagates to the caller."""
    customer = stripe.Customer(id='mock-customer', metadata={'user_id': 'mock-user'})
    create_customer = AsyncMock(return_value=customer)
    empty_customer_search = AsyncMock(return_value=MagicMock(data=[]))
    failing_checkout = AsyncMock(side_effect=Exception('Stripe API Error'))
    user_info_lookup = AsyncMock(return_value={'email': 'testy@tester.com'})

    org = MagicMock()
    org.id = uuid.uuid4()
    org.contact_email = 'testy@tester.com'

    with (
        patch('stripe.Customer.create_async', create_customer),
        patch('stripe.Customer.search_async', empty_customer_search),
        patch('stripe.checkout.Session.create_async', failing_checkout),
        patch('integrations.stripe_service.session_maker', session_maker),
        patch(
            'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
            return_value=org,
        ),
        patch(
            'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
            user_info_lookup,
        ),
        patch('server.routes.billing.validate_billing_enabled'),
        # Innermost, so only the route call below is expected to raise.
        pytest.raises(Exception, match='Stripe API Error'),
    ):
        await create_checkout_session(
            CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
    """Test successful creation of checkout session.

    Verifies two things:
    - The Stripe checkout call parameters; a $25 request becomes a 2500-cent
      line item.
    - A billing-session row is added and committed through the route's
      ``session_maker``.
    """
    mock_session = MagicMock()
    mock_session.url = 'https://checkout.stripe.com/test-session'
    mock_session.id = 'test_session_id'
    mock_create = AsyncMock(return_value=mock_session)

    mock_customer = stripe.Customer(
        id='mock-customer', metadata={'user_id': 'mock-user'}
    )
    mock_customer_create = AsyncMock(return_value=mock_customer)
    mock_org = MagicMock()
    mock_org.id = uuid.uuid4()
    mock_org.contact_email = 'testy@tester.com'

    with (
        patch('stripe.Customer.create_async', mock_customer_create),
        patch(
            'stripe.Customer.search_async', AsyncMock(return_value=MagicMock(data=[]))
        ),
        patch('stripe.checkout.Session.create_async', mock_create),
        patch('server.routes.billing.session_maker') as mock_session_maker,
        patch('integrations.stripe_service.session_maker', session_maker),
        patch(
            'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
            return_value=mock_org,
        ),
        patch(
            'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
            AsyncMock(return_value={'email': 'testy@tester.com'}),
        ),
        patch('server.routes.billing.validate_billing_enabled'),
    ):
        mock_db_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_db_session

        result = await create_checkout_session(
            CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
        )

    assert isinstance(result, CreateBillingSessionResponse)
    assert result.redirect_url == 'https://checkout.stripe.com/test-session'

    # Verify Stripe session creation parameters
    mock_create.assert_called_once_with(
        customer='mock-customer',
        line_items=[
            {
                'price_data': {
                    'unit_amount': 2500,
                    'currency': 'usd',
                    'product_data': {
                        'name': 'OpenHands Credits',
                        'tax_code': 'txcd_10000000',
                    },
                    'tax_behavior': 'exclusive',
                },
                'quantity': 1,
            }
        ],
        mode='payment',
        payment_method_types=['card'],
        saved_payment_method_options={'payment_method_save': 'enabled'},
        success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
        cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
    )

    # Verify database session creation
    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_success_callback_session_not_found():
    """An unknown billing session id yields HTTP 400 and no DB writes."""
    request = Request(scope={'type': 'http'})
    request._base_url = URL('http://test.com/')

    db_session = MagicMock()
    billing_lookup = db_session.query.return_value.filter.return_value.filter.return_value
    billing_lookup.first.return_value = None

    with patch('server.routes.billing.session_maker') as fake_session_maker:
        fake_session_maker.return_value.__enter__.return_value = db_session
        with pytest.raises(HTTPException) as exc_info:
            await success_callback('test_session_id', request)

    assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
    db_session.merge.assert_not_called()
    db_session.commit.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_success_callback_stripe_incomplete():
    """A Stripe session that is not complete yields HTTP 400 and no DB writes."""
    request = Request(scope={'type': 'http'})
    request._base_url = URL('http://test.com/')

    pending_billing = MagicMock()
    pending_billing.status = 'in_progress'
    pending_billing.user_id = 'mock_user'

    db_session = MagicMock()
    billing_lookup = db_session.query.return_value.filter.return_value.filter.return_value
    billing_lookup.first.return_value = pending_billing

    with (
        patch('server.routes.billing.session_maker') as fake_session_maker,
        patch('stripe.checkout.Session.retrieve') as fake_retrieve,
    ):
        fake_session_maker.return_value.__enter__.return_value = db_session
        fake_retrieve.return_value = MagicMock(status='pending')

        with pytest.raises(HTTPException) as exc_info:
            await success_callback('test_session_id', request)

    assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
    db_session.merge.assert_not_called()
    db_session.commit.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_success_callback_success():
    """Test successful payment completion and credit update.

    A completed Stripe session should:
    - raise the org's LiteLLM budget by the paid amount;
    - enable BYOR export on the org;
    - mark the billing session completed with its price;
    - redirect back to the billing settings page.
    """
    mock_request = Request(scope={'type': 'http'})
    mock_request._base_url = URL('http://test.com/')

    mock_billing_session = MagicMock()
    mock_billing_session.status = 'in_progress'
    mock_billing_session.user_id = 'mock_user'

    mock_org = MagicMock()

    with (
        patch('server.routes.billing.session_maker') as mock_session_maker,
        patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
        patch(
            'storage.user_store.UserStore.get_user_by_id_async',
            new_callable=AsyncMock,
            return_value=MagicMock(current_org_id='mock_org_id'),
        ),
        patch(
            'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
            return_value={
                'spend': 25.50,
                'litellm_budget_table': {'max_budget': 100.00},
            },
        ),
        patch(
            'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
        ) as mock_update_budget,
    ):
        mock_db_session = MagicMock()
        # The callback issues two queries against the same session:
        #   1. BillingSession: query().filter().filter().first()
        #   2. Org:            query().filter().first()
        # query.side_effect hands out one chain per call, in that order.
        # A plain query.return_value would be ignored once side_effect is a
        # list, so only the side_effect is configured.
        mock_query_chain_billing = MagicMock()
        mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
        mock_query_chain_org = MagicMock()
        mock_query_chain_org.filter.return_value.first.return_value = mock_org
        mock_db_session.query.side_effect = [
            mock_query_chain_billing,
            mock_query_chain_org,
        ]
        mock_session_maker.return_value.__enter__.return_value = mock_db_session

        mock_stripe_retrieve.return_value = MagicMock(
            status='complete', amount_subtotal=2500, customer='mock_customer_id'
        )  # $25.00 in cents

        response = await success_callback('test_session_id', mock_request)

    assert response.status_code == 302
    assert (
        response.headers['location']
        == 'https://test.com/settings/billing?checkout=success'
    )

    # Verify LiteLLM API calls
    mock_update_budget.assert_called_once_with(
        'mock_org_id',
        125.0,  # 100 (current max budget) + 25.00 paid via Stripe
    )

    # Verify BYOR export is enabled for the org (updated in same session)
    assert mock_org.byor_export_enabled is True

    # Verify database updates
    assert mock_billing_session.status == 'completed'
    assert mock_billing_session.price == 25.0
    mock_db_session.merge.assert_called_once()
    mock_db_session.commit.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_success_callback_lite_llm_error():
    """A LiteLLM failure aborts the success callback before any DB writes."""
    request = Request(scope={'type': 'http'})
    request._base_url = URL('http://test.com/')

    pending_billing = MagicMock()
    pending_billing.status = 'in_progress'
    pending_billing.user_id = 'mock_user'

    db_session = MagicMock()
    billing_lookup = db_session.query.return_value.filter.return_value.filter.return_value
    billing_lookup.first.return_value = pending_billing

    user_lookup_patch = patch(
        'storage.user_store.UserStore.get_user_by_id_async',
        new_callable=AsyncMock,
        return_value=MagicMock(current_org_id='mock_org_id'),
    )
    failing_team_info_patch = patch(
        'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
        side_effect=Exception('LiteLLM API Error'),
    )

    with (
        patch('server.routes.billing.session_maker') as fake_session_maker,
        patch('stripe.checkout.Session.retrieve') as fake_retrieve,
        user_lookup_patch,
        failing_team_info_patch,
    ):
        fake_session_maker.return_value.__enter__.return_value = db_session
        fake_retrieve.return_value = MagicMock(status='complete', amount_subtotal=2500)

        with pytest.raises(Exception, match='LiteLLM API Error'):
            await success_callback('test_session_id', request)

    # The billing session must be left exactly as it was.
    assert pending_billing.status == 'in_progress'
    db_session.merge.assert_not_called()
    db_session.commit.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancel_callback_session_not_found():
    """Cancelling an unknown session still redirects, without touching the DB."""
    request = Request(scope={'type': 'http'})
    request._base_url = URL('http://test.com/')

    db_session = MagicMock()
    billing_lookup = db_session.query.return_value.filter.return_value.filter.return_value
    billing_lookup.first.return_value = None

    with patch('server.routes.billing.session_maker') as fake_session_maker:
        fake_session_maker.return_value.__enter__.return_value = db_session
        redirect = await cancel_callback('test_session_id', request)

    expected_location = 'https://test.com/settings/billing?checkout=cancel'
    assert redirect.status_code == 302
    assert redirect.headers['location'] == expected_location

    # Nothing was found, so nothing may be written.
    db_session.merge.assert_not_called()
    db_session.commit.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancel_callback_success():
    """Cancelling an in-progress session marks it cancelled and redirects."""
    request = Request(scope={'type': 'http'})
    request._base_url = URL('http://test.com/')

    pending_billing = MagicMock()
    pending_billing.status = 'in_progress'

    db_session = MagicMock()
    billing_lookup = db_session.query.return_value.filter.return_value.filter.return_value
    billing_lookup.first.return_value = pending_billing

    with patch('server.routes.billing.session_maker') as fake_session_maker:
        fake_session_maker.return_value.__enter__.return_value = db_session
        redirect = await cancel_callback('test_session_id', request)

    expected_location = 'https://test.com/settings/billing?checkout=cancel'
    assert redirect.status_code == 302
    assert redirect.headers['location'] == expected_location

    # The session is flipped to cancelled and persisted exactly once.
    assert pending_billing.status == 'cancelled'
    db_session.merge.assert_called_once()
    db_session.commit.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_has_payment_method_with_payment_method():
    """has_payment_method returns True when the user has a payment method."""
    lookup = AsyncMock(return_value=True)
    lookup_target = 'server.routes.billing.stripe_service.has_payment_method_by_user_id'

    with patch(lookup_target, lookup):
        outcome = await has_payment_method('mock_user')

    assert outcome is True
    lookup.assert_called_once_with('mock_user')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_has_payment_method_without_payment_method():
    """Test has_payment_method returns False when user has no payment method.

    The stripe_service lookup is stubbed to report no payment method. The
    route must pass that result through and query by the given user id.
    """
    # return_value=False is set once here; re-assigning it inside the patch
    # context (as before) was redundant.
    mock_has_payment_method = AsyncMock(return_value=False)
    with patch(
        'server.routes.billing.stripe_service.has_payment_method_by_user_id',
        mock_has_payment_method,
    ):
        result = await has_payment_method('mock_user')

    assert result is False
    mock_has_payment_method.assert_called_once_with('mock_user')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_customer_setup_session_success():
    """A setup-mode Stripe checkout is created for the user's Stripe customer."""
    request = Request(
        scope={
            'type': 'http',
            'path': '/api/billing/create-customer-setup-session',
            'server': ('test.com', 80),
            'headers': [],
        }
    )
    request._base_url = URL('http://test.com/')

    customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
    find_customer = AsyncMock(return_value=customer_info)

    checkout_session = MagicMock()
    checkout_session.url = 'https://checkout.stripe.com/test-session'
    create_session = AsyncMock(return_value=checkout_session)

    with (
        patch(
            'integrations.stripe_service.find_or_create_customer_by_user_id',
            find_customer,
        ),
        patch('stripe.checkout.Session.create_async', create_session),
        patch('server.routes.billing.validate_billing_enabled'),
    ):
        outcome = await create_customer_setup_session(request, 'mock_user')

    assert isinstance(outcome, billing.CreateBillingSessionResponse)
    assert outcome.redirect_url == 'https://checkout.stripe.com/test-session'

    # The session must be created in setup mode for the resolved customer.
    create_session.assert_called_once_with(
        customer='mock-customer-id',
        mode='setup',
        payment_method_types=['card'],
        success_url='https://test.com/?free_credits=success',
        cancel_url='https://test.com/',
    )
|