From db40eb1e94485535e61ff762957fe6ae7dc5bfad Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 10 Mar 2026 13:11:33 -0600 Subject: [PATCH] Using the web_url where it is configured rather than the request.url (#13319) Co-authored-by: openhands --- enterprise/server/middleware.py | 11 +- enterprise/server/routes/auth.py | 85 +--- enterprise/server/routes/billing.py | 24 +- enterprise/server/routes/email.py | 13 +- enterprise/server/routes/oauth_device.py | 3 +- enterprise/server/utils/url_utils.py | 38 ++ ..._saas_sql_app_conversation_info_service.py | 6 +- enterprise/tests/unit/test_auth_routes.py | 83 +--- enterprise/tests/unit/test_billing.py | 24 +- enterprise/tests/unit/utils/__init__.py | 1 + enterprise/tests/unit/utils/test_url_utils.py | 425 ++++++++++++++++++ 11 files changed, 529 insertions(+), 184 deletions(-) create mode 100644 enterprise/server/utils/url_utils.py create mode 100644 enterprise/tests/unit/utils/__init__.py create mode 100644 enterprise/tests/unit/utils/test_url_utils.py diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 561b106418..659a66046a 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -12,11 +12,8 @@ from server.auth.auth_error import ( ) from server.auth.gitlab_sync import schedule_gitlab_repo_sync from server.auth.saas_user_auth import SaasUserAuth, token_manager -from server.routes.auth import ( - get_cookie_domain, - get_cookie_samesite, - set_response_cookie, -) +from server.routes.auth import set_response_cookie +from server.utils.url_utils import get_cookie_domain, get_cookie_samesite from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth.user_auth import AuthType, UserAuth, get_user_auth @@ -93,8 +90,8 @@ class SetAuthCookieMiddleware: if keycloak_auth_cookie: response.delete_cookie( key='keycloak_auth', - domain=get_cookie_domain(request), - samesite=get_cookie_samesite(request), + domain=get_cookie_domain(), + samesite=get_cookie_samesite(), ) return response diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index 5bd3b755d9..4118969274 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -3,7 +3,7 @@ import json import uuid import warnings from datetime import datetime, timezone -from typing import Annotated, Literal, Optional, cast +from typing import Annotated, Optional, cast from urllib.parse import quote, urlencode from uuid import UUID as parse_uuid @@ -27,7 +27,7 @@ from server.auth.user.user_authorizer import ( depends_user_authorizer, ) from server.config import sign_token -from server.constants import IS_FEATURE_ENV +from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV from server.routes.event_webhook import _get_session_api_key, _get_user_id from server.services.org_invitation_service import ( EmailMismatchError, @@ -37,12 +37,12 @@ from server.services.org_invitation_service import ( UserAlreadyMemberError, ) from server.utils.rate_limit_utils import check_rate_limit_by_user_id +from server.utils.url_utils import get_cookie_domain, get_cookie_samesite, get_web_url from sqlalchemy import select from storage.database import a_session_maker from storage.user import User from storage.user_store import UserStore -from openhands.app_server.config import get_global_config from openhands.core.logger import openhands_logger as logger from openhands.integrations.provider import ProviderHandler from openhands.integrations.service_types import ProviderType, TokenResponse @@ -77,7 +77,7 @@ def set_response_cookie( signed_token = sign_token(cookie_data, config.jwt_secret.get_secret_value()) # type: ignore # Set secure cookie with signed token - domain = get_cookie_domain(request) + domain = get_cookie_domain() if domain: response.set_cookie( key='keycloak_auth', @@ -85,7 +85,7 @@ def set_response_cookie( domain=domain, httponly=True, secure=secure, - samesite=get_cookie_samesite(request), + samesite=get_cookie_samesite(), ) else: response.set_cookie( @@ -93,30 +93,10 @@ def set_response_cookie( value=signed_token, httponly=True, secure=secure, - samesite=get_cookie_samesite(request), + samesite=get_cookie_samesite(), ) -def get_cookie_domain(request: Request) -> str | None: - # for now just use the full hostname except for staging stacks. - return ( - None - if not request.url.hostname - or request.url.hostname.endswith('staging.all-hands.dev') - else request.url.hostname - ) - - -def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']: - # for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict' - return ( - 'lax' - if request.url.hostname == 'localhost' - or (request.url.hostname or '').endswith('staging.all-hands.dev') - else 'strict' - ) - - def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]: """Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state. @@ -140,19 +120,6 @@ def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None return state, None, None -# Keep alias for backward compatibility -def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]: - """Extract redirect URL and reCAPTCHA token from OAuth state. - - Deprecated: Use _extract_oauth_state instead. - - Returns: - Tuple of (redirect_url, recaptcha_token). Token may be None. - """ - redirect_url, recaptcha_token, _ = _extract_oauth_state(state) - return redirect_url, recaptcha_token - - @oauth_router.get('/keycloak/callback') async def keycloak_callback( request: Request, @@ -183,10 +150,7 @@ async def keycloak_callback( detail='Missing code in request params', ) - web_url = get_global_config().web_url - if not web_url: - scheme = 'http' if request.url.hostname == 'localhost' else 'https' - web_url = f'{scheme}://{request.url.netloc}' + web_url = get_web_url(request) redirect_uri = web_url + request.url.path ( @@ -313,7 +277,9 @@ async def keycloak_callback( else: raise - verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}' + verification_redirect_url = ( + f'{web_url}/login?email_verification_required=true&user_id={user_id}' + ) if rate_limited: verification_redirect_url = f'{verification_redirect_url}&rate_limited=true' @@ -474,9 +440,7 @@ async def keycloak_callback( # If the user hasn't accepted the TOS, redirect to the TOS page if not has_accepted_tos: encoded_redirect_url = quote(redirect_url, safe='') - tos_redirect_url = ( - f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}' - ) + tos_redirect_url = f'{web_url}/accept-tos?redirect_url={encoded_redirect_url}' if invitation_token: tos_redirect_url = f'{tos_redirect_url}&invitation_success=true' response = RedirectResponse(tos_redirect_url, status_code=302) @@ -508,10 +472,9 @@ async def keycloak_offline_callback(code: str, state: str, request: Request): status_code=status.HTTP_400_BAD_REQUEST, content={'error': 'Missing code in request params'}, ) - scheme = 'https' - if request.url.hostname == 'localhost': - scheme = 'http' - redirect_uri = f'{scheme}://{request.url.netloc}{request.url.path}' + + web_url = get_web_url(request) + redirect_uri = web_url + request.url.path logger.debug(f'code: {code}, redirect_uri: {redirect_uri}') ( @@ -533,15 +496,14 @@ async def keycloak_offline_callback(code: str, state: str, request: Request): ) redirect_url, _, _ = _extract_oauth_state(state) - return RedirectResponse( - redirect_url if redirect_url else request.base_url, status_code=302 - ) + return RedirectResponse(redirect_url if redirect_url else web_url, status_code=302) @oauth_router.get('/github/callback') async def github_dummy_callback(request: Request): """Callback for GitHub that just forwards the user to the app base URL.""" - return RedirectResponse(request.base_url, status_code=302) + web_url = get_web_url(request) + return RedirectResponse(web_url, status_code=302) @api_router.post('/authenticate') @@ -563,8 +525,8 @@ async def authenticate(request: Request): if keycloak_auth_cookie: response.delete_cookie( key='keycloak_auth', - domain=get_cookie_domain(request), - samesite=get_cookie_samesite(request), + domain=get_cookie_domain(), + samesite=get_cookie_samesite(), ) return response @@ -588,7 +550,8 @@ async def accept_tos(request: Request): # Get redirect URL from request body body = await request.json() - redirect_url = body.get('redirect_url', str(request.base_url)) + web_url = get_web_url(request) + redirect_url = body.get('redirect_url', str(web_url)) # Update user settings with TOS acceptance accepted_tos: datetime = datetime.now(timezone.utc).replace(tzinfo=None) @@ -618,7 +581,7 @@ async def accept_tos(request: Request): response=response, keycloak_access_token=access_token.get_secret_value(), keycloak_refresh_token=refresh_token.get_secret_value(), - secure=False if request.url.hostname == 'localhost' else True, + secure=not IS_LOCAL_ENV, accepted_tos=True, ) return response @@ -635,8 +598,8 @@ async def logout(request: Request): # Always delete the cookie regardless of what happens response.delete_cookie( key='keycloak_auth', - domain=get_cookie_domain(request), - samesite=get_cookie_samesite(request), + domain=get_cookie_domain(), + samesite=get_cookie_samesite(), ) # Try to properly logout from Keycloak, but don't fail if it doesn't work diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index cf8b72b689..71241df457 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -11,8 +11,8 @@ from integrations import stripe_service from pydantic import BaseModel from server.constants import STRIPE_API_KEY from server.logger import logger +from server.utils.url_utils import get_web_url from sqlalchemy import select -from starlette.datastructures import URL from storage.billing_session import BillingSession from storage.database import a_session_maker from storage.lite_llm_manager import LiteLlmManager @@ -151,7 +151,7 @@ async def create_customer_setup_session( status_code=status.HTTP_400_BAD_REQUEST, detail='Could not find or create customer for user', ) - base_url = _get_base_url(request) + base_url = get_web_url(request) checkout_session = await stripe.checkout.Session.create_async( customer=customer_info['customer_id'], mode='setup', @@ -170,7 +170,7 @@ async def create_checkout_session( user_id: str = Depends(get_user_id), ) -> CreateBillingSessionResponse: await validate_billing_enabled() - base_url = _get_base_url(request) + base_url = get_web_url(request) customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id) if not customer_info: raise HTTPException( @@ -198,8 +198,8 @@ async def create_checkout_session( saved_payment_method_options={ 'payment_method_save': 'enabled', }, - success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}', - cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}', + success_url=f'{base_url}/api/billing/success?session_id={{CHECKOUT_SESSION_ID}}', + cancel_url=f'{base_url}/api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}', ) logger.info( 'created_stripe_checkout_session', @@ -300,7 +300,7 @@ async def success_callback(session_id: str, request: Request): await session.commit() return RedirectResponse( - f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302 + f'{get_web_url(request)}/settings/billing?checkout=success', status_code=302 ) @@ -325,17 +325,9 @@ async def cancel_callback(session_id: str, request: Request): ) billing_session.status = 'cancelled' billing_session.updated_at = datetime.now(UTC) - session.merge(billing_session) + await session.merge(billing_session) await session.commit() return RedirectResponse( - f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302 + f'{get_web_url(request)}/settings/billing?checkout=cancel', status_code=302 ) - - -def _get_base_url(request: Request) -> URL: - # Never send any part of the credit card process over a non secure connection - base_url = request.base_url - if base_url.hostname != 'localhost': - base_url = base_url.replace(scheme='https') - return base_url diff --git a/enterprise/server/routes/email.py b/enterprise/server/routes/email.py index 7571b619b2..5b910fcff4 100644 --- a/enterprise/server/routes/email.py +++ b/enterprise/server/routes/email.py @@ -7,8 +7,10 @@ from pydantic import BaseModel, field_validator from server.auth.constants import KEYCLOAK_CLIENT_ID from server.auth.keycloak_manager import get_keycloak_admin from server.auth.saas_user_auth import SaasUserAuth +from server.constants import IS_LOCAL_ENV from server.routes.auth import set_response_cookie from server.utils.rate_limit_utils import check_rate_limit_by_user_id +from server.utils.url_utils import get_web_url from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger @@ -87,7 +89,7 @@ async def update_email( response=response, keycloak_access_token=user_auth.access_token.get_secret_value(), keycloak_refresh_token=user_auth.refresh_token.get_secret_value(), - secure=False if request.url.hostname == 'localhost' else True, + secure=not IS_LOCAL_ENV, accepted_tos=user_auth.accepted_tos or False, ) @@ -156,8 +158,8 @@ async def verified_email(request: Request): await user_auth.refresh() # refresh so access token has updated email user_auth.email_verified = True await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True) - scheme = 'http' if request.url.hostname == 'localhost' else 'https' - redirect_uri = f'{scheme}://{request.url.netloc}/settings/user' + + redirect_uri = f'{get_web_url(request)}/settings/user' response = RedirectResponse(redirect_uri, status_code=302) # need to set auth cookie to the new tokens @@ -180,11 +182,10 @@ async def verified_email(request: Request): async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False): keycloak_admin = get_keycloak_admin() - scheme = 'http' if request.url.hostname == 'localhost' else 'https' if is_auth_flow: - redirect_uri = f'{scheme}://{request.url.netloc}/login?email_verified=true' + redirect_uri = f'{get_web_url(request)}/login?email_verified=true' else: - redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified' + redirect_uri = f'{get_web_url(request)}/api/email/verified' logger.info(f'Redirect URI: {redirect_uri}') await keycloak_admin.a_send_verify_email( user_id=user_id, diff --git a/enterprise/server/routes/oauth_device.py b/enterprise/server/routes/oauth_device.py index 52ee16b570..d0c110df0e 100644 --- a/enterprise/server/routes/oauth_device.py +++ b/enterprise/server/routes/oauth_device.py @@ -6,6 +6,7 @@ from typing import Optional from fastapi import APIRouter, Depends, Form, HTTPException, Request, status from fastapi.responses import JSONResponse from pydantic import BaseModel +from server.utils.url_utils import get_web_url from storage.api_key_store import ApiKeyStore from storage.device_code_store import DeviceCodeStore @@ -93,7 +94,7 @@ async def device_authorization( expires_in=DEVICE_CODE_EXPIRES_IN, ) - base_url = str(http_request.base_url).rstrip('/') + base_url = get_web_url(http_request) verification_uri = f'{base_url}/oauth/device/verify' verification_uri_complete = ( f'{verification_uri}?user_code={device_code_entry.user_code}' diff --git a/enterprise/server/utils/url_utils.py b/enterprise/server/utils/url_utils.py new file mode 100644 index 0000000000..5c939dbe8f --- /dev/null +++ b/enterprise/server/utils/url_utils.py @@ -0,0 +1,38 @@ +from typing import Literal + +from fastapi import Request +from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV, IS_STAGING_ENV +from starlette.datastructures import URL + +from openhands.app_server.config import get_global_config + + +def get_web_url(request: Request): + web_url = get_global_config().web_url + if not web_url: + scheme = 'http' if request.url.hostname == 'localhost' else 'https' + web_url = f'{scheme}://{request.url.netloc}' + else: + web_url = web_url.rstrip('/') + return web_url + + +def get_cookie_domain() -> str | None: + config = get_global_config() + web_url = config.web_url + # for now just use the full hostname except for staging stacks. + return ( + URL(web_url).hostname + if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV) + else None + ) + + +def get_cookie_samesite() -> Literal['lax', 'strict']: + # for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict' + web_url = get_global_config().web_url + return ( + 'strict' + if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV) + else 'lax' + ) diff --git a/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py b/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py index ac7f7b3db6..84a29be701 100644 --- a/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py +++ b/enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py @@ -10,6 +10,9 @@ from unittest.mock import AsyncMock, MagicMock from uuid import UUID, uuid4 import pytest +from server.utils.saas_app_conversation_info_injector import ( + SaasSQLAppConversationInfoService, +) from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool @@ -17,9 +20,6 @@ from storage.base import Base from storage.org import Org from storage.user import User -from enterprise.server.utils.saas_app_conversation_info_injector import ( - SaasSQLAppConversationInfoService, -) from openhands.app_server.app_conversation.app_conversation_models import ( AppConversationInfo, ) diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index 88b112595d..e672449fa9 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -11,7 +11,6 @@ from server.auth.auth_error import AuthError from server.auth.saas_user_auth import SaasUserAuth from server.auth.user.user_authorizer import UserAuthorizationResponse, UserAuthorizer from server.routes.auth import ( - _extract_recaptcha_state, accept_tos, authenticate, keycloak_callback, @@ -55,11 +54,12 @@ def mock_response(): def test_set_response_cookie(mock_response, mock_request): """Test setting the auth cookie on a response.""" - with patch('server.routes.auth.config') as mock_config: + with ( + patch('server.routes.auth.config') as mock_config, + patch('server.utils.url_utils.get_global_config') as get_global_config, + ): mock_config.jwt_secret.get_secret_value.return_value = 'test_secret' - - # Configure mock_request.url.hostname - mock_request.url.hostname = 'example.com' + get_global_config.return_value = MagicMock(web_url='https://example.com') set_response_cookie( request=mock_request, @@ -1036,79 +1036,6 @@ async def test_keycloak_callback_no_email_in_user_info( mock_token_manager.check_duplicate_base_email.assert_not_called() -class TestExtractRecaptchaState: - """Tests for _extract_recaptcha_state() helper function.""" - - def test_should_extract_redirect_url_and_token_from_new_json_format(self): - """Test extraction from new base64-encoded JSON format.""" - # Arrange - state_data = { - 'redirect_url': 'https://example.com', - 'recaptcha_token': 'test-token', - } - encoded_state = base64.urlsafe_b64encode( - json.dumps(state_data).encode() - ).decode() - - # Act - redirect_url, token = _extract_recaptcha_state(encoded_state) - - # Assert - assert redirect_url == 'https://example.com' - assert token == 'test-token' - - def test_should_handle_old_format_plain_redirect_url(self): - """Test handling of old format (plain redirect URL string).""" - # Arrange - state = 'https://example.com' - - # Act - redirect_url, token = _extract_recaptcha_state(state) - - # Assert - assert redirect_url == 'https://example.com' - assert token is None - - def test_should_handle_none_state(self): - """Test handling of None state.""" - # Arrange - state = None - - # Act - redirect_url, token = _extract_recaptcha_state(state) - - # Assert - assert redirect_url == '' - assert token is None - - def test_should_handle_invalid_base64_gracefully(self): - """Test handling of invalid base64/JSON (fallback to old format).""" - # Arrange - state = 'not-valid-base64!!!' - - # Act - redirect_url, token = _extract_recaptcha_state(state) - - # Assert - assert redirect_url == state - assert token is None - - def test_should_handle_missing_redirect_url_in_json(self): - """Test handling when redirect_url is missing in JSON.""" - # Arrange - state_data = {'recaptcha_token': 'test-token'} - encoded_state = base64.urlsafe_b64encode( - json.dumps(state_data).encode() - ).decode() - - # Act - redirect_url, token = _extract_recaptcha_state(encoded_state) - - # Assert - assert redirect_url == '' - assert token == 'test-token' - - class TestKeycloakCallbackRecaptcha: """Tests for reCAPTCHA integration in keycloak_callback().""" diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index 995e277acd..18262790fa 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -48,7 +48,7 @@ def mock_checkout_request(): 'server': ('test.com', 80), } ) - request._base_url = URL('http://test.com/') + request._url = URL('http://test.com/') return request @@ -62,7 +62,7 @@ def mock_subscription_request(): 'server': ('test.com', 80), } ) - request._base_url = URL('http://test.com/') + request._url = URL('http://test.com/') return request @@ -264,7 +264,7 @@ async def test_create_checkout_session_success( async def test_success_callback_session_not_found(async_session_maker): """Test success callback when billing session is not found.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') with ( patch('server.routes.billing.a_session_maker', async_session_maker), @@ -281,7 +281,7 @@ async def test_success_callback_stripe_incomplete( ): """Test success callback when Stripe session is not complete.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') session_id = 'test_incomplete_session' async with async_session_maker() as session: @@ -319,7 +319,7 @@ async def test_success_callback_stripe_incomplete( async def test_success_callback_success(async_session_maker, test_org, test_user): """Test successful payment completion and credit update.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') session_id = 'test_success_session' async with async_session_maker() as session: @@ -391,7 +391,7 @@ async def test_success_callback_lite_llm_error( ): """Test handling of LiteLLM API errors during success callback.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') session_id = 'test_litellm_error_session' async with async_session_maker() as session: @@ -445,7 +445,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback( the database transaction rolls back. """ mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') session_id = 'test_budget_rollback_session' async with async_session_maker() as session: @@ -502,7 +502,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback( async def test_cancel_callback_session_not_found(async_session_maker): """Test cancel callback when billing session is not found.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') with patch('server.routes.billing.a_session_maker', async_session_maker): response = await cancel_callback('nonexistent_session_id', mock_request) @@ -517,7 +517,7 @@ async def test_cancel_callback_session_not_found(async_session_maker): async def test_cancel_callback_success(async_session_maker, test_org, test_user): """Test successful cancellation of billing session.""" mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') session_id = 'test_cancel_session' async with async_session_maker() as session: @@ -588,7 +588,7 @@ async def test_create_customer_setup_session_success(): 'headers': [], } ) - mock_request._base_url = URL('http://test.com/') + mock_request._url = URL('http://test.com/') mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'} mock_session = MagicMock() @@ -613,6 +613,6 @@ async def test_create_customer_setup_session_success(): customer='mock-customer-id', mode='setup', payment_method_types=['card'], - success_url='https://test.com/?setup=success', - cancel_url='https://test.com/', + success_url='https://test.com?setup=success', + cancel_url='https://test.com', ) diff --git a/enterprise/tests/unit/utils/__init__.py b/enterprise/tests/unit/utils/__init__.py new file mode 100644 index 0000000000..15cd39f69b --- /dev/null +++ b/enterprise/tests/unit/utils/__init__.py @@ -0,0 +1 @@ +# Tests for enterprise server utils diff --git a/enterprise/tests/unit/utils/test_url_utils.py b/enterprise/tests/unit/utils/test_url_utils.py new file mode 100644 index 0000000000..d42b807b45 --- /dev/null +++ b/enterprise/tests/unit/utils/test_url_utils.py @@ -0,0 +1,425 @@ +"""Tests for URL utility functions that prevent URL hijacking attacks.""" + +from unittest.mock import MagicMock, patch + +import pytest + + +class TestGetWebUrl: + """Tests for get_web_url function.""" + + @pytest.fixture + def mock_request(self): + """Create a mock FastAPI request object.""" + request = MagicMock() + request.url = MagicMock() + return request + + def test_configured_web_url_is_used(self, mock_request): + """When web_url is configured, it should be used instead of request URL.""" + from server.utils.url_utils import get_web_url + + mock_request.url.hostname = 'evil-attacker.com' + mock_request.url.netloc = 'evil-attacker.com:443' + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + assert result == 'https://app.all-hands.dev' + # Should not use any info from the potentially poisoned request + assert 'evil-attacker.com' not in result + + def test_configured_web_url_trailing_slash_stripped(self, mock_request): + """Configured web_url should have trailing slashes stripped.""" + from server.utils.url_utils import get_web_url + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev/' + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + assert result == 'https://app.all-hands.dev' + assert not result.endswith('/') + + def test_unconfigured_web_url_localhost_uses_http(self, mock_request): + """When web_url is not configured and hostname is localhost, use http.""" + from server.utils.url_utils import get_web_url + + mock_request.url.hostname = 'localhost' + mock_request.url.netloc = 'localhost:3000' + + mock_config = MagicMock() + mock_config.web_url = None + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + assert result == 'http://localhost:3000' + + def test_unconfigured_web_url_non_localhost_uses_https(self, mock_request): + """When web_url is not configured and hostname is not localhost, use https.""" + from server.utils.url_utils import get_web_url + + mock_request.url.hostname = 'example.com' + mock_request.url.netloc = 'example.com:443' + + mock_config = MagicMock() + mock_config.web_url = None + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + assert result == 'https://example.com:443' + + def test_unconfigured_web_url_empty_string_fallback(self, mock_request): + """Empty string web_url should trigger fallback.""" + from server.utils.url_utils import get_web_url + + mock_request.url.hostname = 'localhost' + mock_request.url.netloc = 'localhost:3000' + + mock_config = MagicMock() + mock_config.web_url = '' + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + assert result == 'http://localhost:3000' + + +class TestGetCookieDomain: + """Tests for get_cookie_domain function.""" + + def test_production_with_configured_web_url(self): + """In production with web_url configured, should return hostname.""" + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_domain() + + assert result == 'app.all-hands.dev' + + def test_production_without_web_url_returns_none(self): + """In production without web_url configured, should return None.""" + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = None + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_domain() + + assert result is None + + def test_local_env_returns_none(self): + """In local environment, should return None for cookie domain.""" + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', True), + ): + result = get_cookie_domain() + + assert result is None + + def test_staging_env_returns_none(self): + """In staging environment, should return None for cookie domain.""" + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = 'https://staging.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', True), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_domain() + + assert result is None + + def test_feature_env_returns_none(self): + """In feature environment, should return None for cookie domain.""" + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = 'https://feature-123.staging.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', True), + patch('server.utils.url_utils.IS_STAGING_ENV', True), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_domain() + + assert result is None + + +class TestGetCookieSamesite: + """Tests for get_cookie_samesite function.""" + + def test_production_with_configured_web_url_returns_strict(self): + """In production with web_url configured, should return 'strict'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_samesite() + + assert result == 'strict' + + def test_production_without_web_url_returns_lax(self): + """In production without web_url configured, should return 'lax'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = None + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_samesite() + + assert result == 'lax' + + def test_local_env_returns_lax(self): + """In local environment, should return 'lax'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = 'http://localhost:3000' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', True), + ): + result = get_cookie_samesite() + + assert result == 'lax' + + def test_staging_env_returns_lax(self): + """In staging environment, should return 'lax'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = 'https://staging.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', True), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_samesite() + + assert result == 'lax' + + def test_feature_env_returns_lax(self): + """In feature environment, should return 'lax'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = 'https://feature-xyz.staging.all-hands.dev' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', True), + patch('server.utils.url_utils.IS_STAGING_ENV', True), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_samesite() + + assert result == 'lax' + + def test_empty_web_url_returns_lax(self): + """Empty web_url should be treated as unconfigured and return 'lax'.""" + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = '' + + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + result = get_cookie_samesite() + + assert result == 'lax' + + +class TestSecurityScenarios: + """Tests for security-critical scenarios.""" + + @pytest.fixture + def mock_request(self): + """Create a mock FastAPI request object.""" + request = MagicMock() + request.url = MagicMock() + return request + + def test_header_poisoning_attack_blocked_when_configured(self, mock_request): + """ + When web_url is configured, X-Forwarded-* header poisoning should not affect + the returned URL. + """ + from server.utils.url_utils import get_web_url + + # Simulate a poisoned request where attacker controls headers + mock_request.url.hostname = 'evil.com' + mock_request.url.netloc = 'evil.com:443' + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + with patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ): + result = get_web_url(mock_request) + + # Should use configured web_url, not the poisoned request data + assert result == 'https://app.all-hands.dev' + assert 'evil' not in result + + def test_cookie_domain_not_set_in_dev_environments(self): + """ + Cookie domain should not be set in development environments to prevent + cookies from leaking to other subdomains. + """ + from server.utils.url_utils import get_cookie_domain + + mock_config = MagicMock() + mock_config.web_url = 'https://my-feature.staging.all-hands.dev' + + # Test each dev environment + for env_name, env_config in [ + ( + 'local', + { + 'IS_LOCAL_ENV': True, + 'IS_STAGING_ENV': False, + 'IS_FEATURE_ENV': False, + }, + ), + ( + 'staging', + { + 'IS_LOCAL_ENV': False, + 'IS_STAGING_ENV': True, + 'IS_FEATURE_ENV': False, + }, + ), + ( + 'feature', + {'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True}, + ), + ]: + with ( + patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ), + patch( + 'server.utils.url_utils.IS_FEATURE_ENV', + env_config['IS_FEATURE_ENV'], + ), + patch( + 'server.utils.url_utils.IS_STAGING_ENV', + env_config['IS_STAGING_ENV'], + ), + patch( + 'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV'] + ), + ): + result = get_cookie_domain() + assert result is None, f'Expected None for {env_name} environment' + + def test_strict_samesite_only_in_production(self): + """ + SameSite=strict should only be set in production to ensure proper + security without breaking OAuth flows in development. + """ + from server.utils.url_utils import get_cookie_samesite + + mock_config = MagicMock() + mock_config.web_url = 'https://app.all-hands.dev' + + # Production should be strict + with ( + patch('server.utils.url_utils.get_global_config', return_value=mock_config), + patch('server.utils.url_utils.IS_FEATURE_ENV', False), + patch('server.utils.url_utils.IS_STAGING_ENV', False), + patch('server.utils.url_utils.IS_LOCAL_ENV', False), + ): + assert get_cookie_samesite() == 'strict' + + # Dev environments should be lax + for env_config in [ + {'IS_LOCAL_ENV': True, 'IS_STAGING_ENV': False, 'IS_FEATURE_ENV': False}, + {'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': False}, + {'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True}, + ]: + with ( + patch( + 'server.utils.url_utils.get_global_config', return_value=mock_config + ), + patch( + 'server.utils.url_utils.IS_FEATURE_ENV', + env_config['IS_FEATURE_ENV'], + ), + patch( + 'server.utils.url_utils.IS_STAGING_ENV', + env_config['IS_STAGING_ENV'], + ), + patch( + 'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV'] + ), + ): + assert get_cookie_samesite() == 'lax'