diff --git a/enterprise/server/auth/email_validation.py b/enterprise/server/auth/email_validation.py new file mode 100644 index 0000000000..94c6f52c2a --- /dev/null +++ b/enterprise/server/auth/email_validation.py @@ -0,0 +1,109 @@ +"""Email validation utilities for preventing duplicate signups with + modifier.""" + +import re + + +def extract_base_email(email: str) -> str | None: + """Extract base email from an email address. + + For emails with + modifier, extracts the base email (local part before + and @, plus domain). + For emails without + modifier, returns the email as-is. + + Examples: + extract_base_email("joe+test@example.com") -> "joe@example.com" + extract_base_email("joe@example.com") -> "joe@example.com" + extract_base_email("joe+openhands+test@example.com") -> "joe@example.com" + + Args: + email: The email address to process + + Returns: + The base email address, or None if email format is invalid + """ + if not email or '@' not in email: + return None + + try: + local_part, domain = email.rsplit('@', 1) + # Extract the part before + if it exists + base_local = local_part.split('+', 1)[0] + return f'{base_local}@{domain}' + except (ValueError, AttributeError): + return None + + +def has_plus_modifier(email: str) -> bool: + """Check if an email address contains a + modifier. + + Args: + email: The email address to check + + Returns: + True if email contains + before @, False otherwise + """ + if not email or '@' not in email: + return False + + try: + local_part, _ = email.rsplit('@', 1) + return '+' in local_part + except (ValueError, AttributeError): + return False + + +def matches_base_email(email: str, base_email: str) -> bool: + """Check if an email matches a base email pattern. + + An email matches if: + - It is exactly the base email (e.g., joe@example.com) + - It has the same base local part and domain, with or without + modifier + (e.g., joe+test@example.com matches base joe@example.com) + + Args: + email: The email address to check + base_email: The base email to match against + + Returns: + True if email matches the base pattern, False otherwise + """ + if not email or not base_email: + return False + + # Extract base from both emails for comparison + email_base = extract_base_email(email) + base_email_normalized = extract_base_email(base_email) + + if not email_base or not base_email_normalized: + return False + + # Emails match if they have the same base + return email_base.lower() == base_email_normalized.lower() + + +def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None: + """Generate a regex pattern to match emails with the same base. + + For base_email "joe@example.com", the pattern will match: + - joe@example.com + - joe+anything@example.com + + Args: + base_email: The base email address + + Returns: + A compiled regex pattern, or None if base_email is invalid + """ + base = extract_base_email(base_email) + if not base: + return None + + try: + local_part, domain = base.rsplit('@', 1) + # Escape special regex characters in local part and domain + escaped_local = re.escape(local_part) + escaped_domain = re.escape(domain) + # Pattern: joe@example.com OR joe+anything@example.com + pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$' + return re.compile(pattern, re.IGNORECASE) + except (ValueError, AttributeError): + return None diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index b51d336997..73a7217fd2 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -154,8 +154,10 @@ class SaasUserAuth(UserAuth): try: # TODO: I think we can do this in a single request if we refactor with session_maker() as session: - tokens = session.query(AuthTokens).where( - AuthTokens.keycloak_user_id == self.user_id + tokens = ( + session.query(AuthTokens) + .where(AuthTokens.keycloak_user_id == self.user_id) + .all() ) for token in tokens: diff --git a/enterprise/server/auth/token_manager.py b/enterprise/server/auth/token_manager.py index 04bfae0767..6061518cb4 100644 --- a/enterprise/server/auth/token_manager.py +++ b/enterprise/server/auth/token_manager.py @@ -1,3 +1,4 @@ +import asyncio import base64 import hashlib import json @@ -25,6 +26,11 @@ from server.auth.constants import ( KEYCLOAK_SERVER_URL, KEYCLOAK_SERVER_URL_EXT, ) +from server.auth.email_validation import ( + extract_base_email, + get_base_email_regex_pattern, + matches_base_email, +) from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid from server.config import get_config from server.logger import logger @@ -509,6 +515,183 @@ class TokenManager: logger.info(f'Got user ID {keycloak_user_id} from email: {email}') return keycloak_user_id + async def _query_users_by_wildcard_pattern( + self, local_part: str, domain: str + ) -> dict[str, dict]: + """Query Keycloak for users matching a wildcard email pattern. + + Tries multiple query methods to find users with emails matching + the pattern {local_part}*@{domain}. This catches the base email + and all + modifier variants. + + Args: + local_part: The local part of the email (before @) + domain: The domain part of the email (after @) + + Returns: + Dictionary mapping user IDs to user objects + """ + keycloak_admin = get_keycloak_admin(self.external) + all_users = {} + + # Query for users with emails matching the base pattern using wildcard + # Pattern: {local_part}*@{domain} - catches base email and all + variants + # This may also catch unintended matches (e.g., joesmith@example.com), but + # they will be filtered out by the regex pattern check later + # Use 'search' parameter for Keycloak 26+ (better wildcard support) + wildcard_queries = [ + {'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first + {'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter + ] + + for query_params in wildcard_queries: + try: + users = await keycloak_admin.a_get_users(query_params) + for user in users: + all_users[user.get('id')] = user + break # Success, no need to try fallback + except Exception as e: + logger.debug( + f'Wildcard query failed with {list(query_params.keys())[0]}: {e}' + ) + continue # Try next query method + + return all_users + + def _find_duplicate_in_users( + self, users: dict[str, dict], base_email: str, current_user_id: str + ) -> bool: + """Check if any user in the provided list matches the base email pattern. + + Filters users to find duplicates that match the base email pattern, + excluding the current user. + + Args: + users: Dictionary mapping user IDs to user objects + base_email: The base email to match against + current_user_id: The user ID to exclude from the check + + Returns: + True if a duplicate is found, False otherwise + """ + regex_pattern = get_base_email_regex_pattern(base_email) + if not regex_pattern: + logger.warning( + f'Could not generate regex pattern for base email: {base_email}' + ) + # Fallback to simple matching + for user in users.values(): + user_email = user.get('email', '').lower() + if ( + user_email + and user.get('id') != current_user_id + and matches_base_email(user_email, base_email) + ): + logger.info( + f'Found duplicate email: {user_email} matches base {base_email}' + ) + return True + else: + for user in users.values(): + user_email = user.get('email', '') + if ( + user_email + and user.get('id') != current_user_id + and regex_pattern.match(user_email) + ): + logger.info( + f'Found duplicate email: {user_email} matches base {base_email}' + ) + return True + + return False + + @retry( + stop=stop_after_attempt(2), + retry=retry_if_exception_type(KeycloakConnectionError), + before_sleep=_before_sleep_callback, + ) + async def check_duplicate_base_email( + self, email: str, current_user_id: str + ) -> bool: + """Check if a user with the same base email already exists. + + This method checks for duplicate signups using email + modifier. + It checks if any user exists with the same base email, regardless of whether + the provided email has a + modifier or not. + + Examples: + - If email is "joe+test@example.com", it checks for existing users with + base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com") + - If email is "joe@example.com", it checks for existing users with + base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com") + + Args: + email: The email address to check (may or may not contain + modifier) + current_user_id: The user ID of the current user (to exclude from check) + + Returns: + True if a duplicate is found (excluding current user), False otherwise + """ + if not email: + return False + + base_email = extract_base_email(email) + if not base_email: + logger.warning(f'Could not extract base email from: {email}') + return False + + try: + local_part, domain = base_email.rsplit('@', 1) + users = await self._query_users_by_wildcard_pattern(local_part, domain) + return self._find_duplicate_in_users(users, base_email, current_user_id) + + except KeycloakConnectionError: + logger.exception('KeycloakConnectionError when checking duplicate email') + raise + except Exception as e: + logger.exception(f'Unexpected error checking duplicate email: {e}') + # On any error, allow signup to proceed (fail open) + return False + + @retry( + stop=stop_after_attempt(2), + retry=retry_if_exception_type(KeycloakConnectionError), + before_sleep=_before_sleep_callback, + ) + async def delete_keycloak_user(self, user_id: str) -> bool: + """Delete a user from Keycloak. + + This method is used to clean up user accounts that were created + but should not exist (e.g., duplicate email signups). + + Args: + user_id: The Keycloak user ID to delete + + Returns: + True if deletion was successful, False otherwise + """ + try: + keycloak_admin = get_keycloak_admin(self.external) + # Use the sync method (python-keycloak doesn't have async delete_user) + # Run it in a thread executor to avoid blocking the event loop + await asyncio.to_thread(keycloak_admin.delete_user, user_id) + logger.info(f'Successfully deleted Keycloak user {user_id}') + return True + except KeycloakConnectionError: + logger.exception(f'KeycloakConnectionError when deleting user {user_id}') + raise + except KeycloakError as e: + # User might not exist or already deleted + logger.warning( + f'KeycloakError when deleting user {user_id}: {e}', + extra={'user_id': user_id, 'error': str(e)}, + ) + return False + except Exception as e: + logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}') + return False + async def get_user_info_from_user_id(self, user_id: str) -> dict | None: keycloak_admin = get_keycloak_admin(self.external) user = await keycloak_admin.a_get_user(user_id) diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index 2ee50bbd2d..e911538da6 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -146,9 +146,11 @@ async def keycloak_callback( content={'error': 'Missing user ID or username in response'}, ) - # Check if email domain is blocked email = user_info.get('email') user_id = user_info['sub'] + + # Check if email domain is blocked + email = user_info.get('email') if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): logger.warning( f'Blocked authentication attempt for email: {email}, user_id: {user_id}' @@ -164,6 +166,54 @@ async def keycloak_callback( }, ) + # Check for duplicate email with + modifier + if email: + try: + has_duplicate = await token_manager.check_duplicate_base_email( + email, user_id + ) + if has_duplicate: + logger.warning( + f'Blocked signup attempt for email {email} - duplicate base email found', + extra={'user_id': user_id, 'email': email}, + ) + + # Delete the Keycloak user that was automatically created during OAuth + # This prevents orphaned accounts in Keycloak + # The delete_keycloak_user method already handles all errors internally + deletion_success = await token_manager.delete_keycloak_user(user_id) + if deletion_success: + logger.info( + f'Deleted Keycloak user {user_id} after detecting duplicate email {email}' + ) + else: + logger.warning( + f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. ' + f'User may need to be manually cleaned up.' + ) + + # Redirect to home page with query parameter indicating the issue + home_url = f'{request.base_url}?duplicated_email=true' + return RedirectResponse(home_url, status_code=302) + except Exception as e: + # Log error but allow signup to proceed (fail open) + logger.error( + f'Error checking duplicate email for {email}: {e}', + extra={'user_id': user_id, 'email': email}, + ) + + # Check email verification status + email_verified = user_info.get('email_verified', False) + if not email_verified: + # Send verification email + # Import locally to avoid circular import with email.py + from server.routes.email import verify_email + + await verify_email(request=request, user_id=user_id, is_auth_flow=True) + redirect_url = f'{request.base_url}?email_verification_required=true' + response = RedirectResponse(redirect_url, status_code=302) + return response + # default to github IDP for now. # TODO: remove default once Keycloak is updated universally with the new attribute. idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value) diff --git a/enterprise/server/routes/email.py b/enterprise/server/routes/email.py index b0d88afaa0..b58adf9a4f 100644 --- a/enterprise/server/routes/email.py +++ b/enterprise/server/routes/email.py @@ -74,7 +74,7 @@ async def update_email( accepted_tos=user_auth.accepted_tos, ) - await _verify_email(request=request, user_id=user_id) + await verify_email(request=request, user_id=user_id) logger.info(f'Updating email address for {user_id} to {email}') return response @@ -91,8 +91,10 @@ async def update_email( @api_router.put('/verify') -async def verify_email(request: Request, user_id: str = Depends(get_user_id)): - await _verify_email(request=request, user_id=user_id) +async def resend_email_verification( + request: Request, user_id: str = Depends(get_user_id) +): + await verify_email(request=request, user_id=user_id) logger.info(f'Resending verification email for {user_id}') return JSONResponse( @@ -124,10 +126,14 @@ async def verified_email(request: Request): return response -async def _verify_email(request: Request, user_id: str): +async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False): keycloak_admin = get_keycloak_admin() scheme = 'http' if request.url.hostname == 'localhost' else 'https' - redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified' + redirect_uri = ( + f'{scheme}://{request.url.netloc}?email_verified=true' + if is_auth_flow + else f'{scheme}://{request.url.netloc}/api/email/verified' + ) logger.info(f'Redirect URI: {redirect_uri}') await keycloak_admin.a_send_verify_email( user_id=user_id, diff --git a/enterprise/tests/unit/server/routes/test_email_routes.py b/enterprise/tests/unit/server/routes/test_email_routes.py new file mode 100644 index 0000000000..8f5ba12e87 --- /dev/null +++ b/enterprise/tests/unit/server/routes/test_email_routes.py @@ -0,0 +1,151 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import Request +from fastapi.responses import RedirectResponse +from pydantic import SecretStr +from server.auth.saas_user_auth import SaasUserAuth +from server.routes.email import verified_email, verify_email + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock(spec=Request) + request.url = MagicMock() + request.url.hostname = 'localhost' + request.url.netloc = 'localhost:8000' + request.url.path = '/api/email/verified' + request.base_url = 'http://localhost:8000/' + request.headers = {} + request.cookies = {} + request.query_params = MagicMock() + return request + + +@pytest.fixture +def mock_user_auth(): + """Create a mock SaasUserAuth object.""" + auth = MagicMock(spec=SaasUserAuth) + auth.access_token = SecretStr('test_access_token') + auth.refresh_token = SecretStr('test_refresh_token') + auth.email = 'test@example.com' + auth.email_verified = False + auth.accepted_tos = True + auth.refresh = AsyncMock() + return auth + + +@pytest.mark.asyncio +async def test_verify_email_default_behavior(mock_request): + """Test verify_email with default is_auth_flow=False.""" + # Arrange + user_id = 'test_user_id' + mock_keycloak_admin = AsyncMock() + mock_keycloak_admin.a_send_verify_email = AsyncMock() + + # Act + with patch( + 'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin + ): + await verify_email(request=mock_request, user_id=user_id) + + # Assert + mock_keycloak_admin.a_send_verify_email.assert_called_once() + call_args = mock_keycloak_admin.a_send_verify_email.call_args + assert call_args.kwargs['user_id'] == user_id + assert ( + call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified' + ) + assert 'client_id' in call_args.kwargs + + +@pytest.mark.asyncio +async def test_verify_email_with_auth_flow(mock_request): + """Test verify_email with is_auth_flow=True.""" + # Arrange + user_id = 'test_user_id' + mock_keycloak_admin = AsyncMock() + mock_keycloak_admin.a_send_verify_email = AsyncMock() + + # Act + with patch( + 'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin + ): + await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True) + + # Assert + mock_keycloak_admin.a_send_verify_email.assert_called_once() + call_args = mock_keycloak_admin.a_send_verify_email.call_args + assert call_args.kwargs['user_id'] == user_id + assert ( + call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true' + ) + assert 'client_id' in call_args.kwargs + + +@pytest.mark.asyncio +async def test_verify_email_https_scheme(mock_request): + """Test verify_email uses https scheme for non-localhost hosts.""" + # Arrange + user_id = 'test_user_id' + mock_request.url.hostname = 'example.com' + mock_request.url.netloc = 'example.com' + mock_keycloak_admin = AsyncMock() + mock_keycloak_admin.a_send_verify_email = AsyncMock() + + # Act + with patch( + 'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin + ): + await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True) + + # Assert + call_args = mock_keycloak_admin.a_send_verify_email.call_args + assert call_args.kwargs['redirect_uri'].startswith('https://') + + +@pytest.mark.asyncio +async def test_verified_email_default_redirect(mock_request, mock_user_auth): + """Test verified_email redirects to /settings/user by default.""" + # Arrange + mock_request.query_params.get.return_value = None + + # Act + with ( + patch('server.routes.email.get_user_auth', return_value=mock_user_auth), + patch('server.routes.email.set_response_cookie') as mock_set_cookie, + ): + result = await verified_email(mock_request) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert result.headers['location'] == 'http://localhost:8000/settings/user' + mock_user_auth.refresh.assert_called_once() + mock_set_cookie.assert_called_once() + assert mock_user_auth.email_verified is True + + +@pytest.mark.asyncio +async def test_verified_email_https_scheme(mock_request, mock_user_auth): + """Test verified_email uses https scheme for non-localhost hosts.""" + # Arrange + mock_request.url.hostname = 'example.com' + mock_request.url.netloc = 'example.com' + mock_request.query_params.get.return_value = None + + # Act + with ( + patch('server.routes.email.get_user_auth', return_value=mock_user_auth), + patch('server.routes.email.set_response_cookie') as mock_set_cookie, + ): + result = await verified_email(mock_request) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.headers['location'].startswith('https://') + mock_set_cookie.assert_called_once() + # Verify secure flag is True for https + call_kwargs = mock_set_cookie.call_args.kwargs + assert call_kwargs['secure'] is True diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index d3e8f47fbe..8490d92760 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request): 'sub': 'test_user_id', 'preferred_username': 'test_user', 'identity_provider': 'github', + 'email_verified': True, } ) mock_token_manager.store_idp_tokens = AsyncMock() @@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request): 'sub': 'test_user_id', 'preferred_username': 'test_user', 'identity_provider': 'github', + 'email_verified': True, } ) mock_token_manager.store_idp_tokens = AsyncMock() @@ -214,6 +216,82 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request): mock_posthog.set.assert_called_once() +@pytest.mark.asyncio +async def test_keycloak_callback_email_not_verified(mock_request): + """Test keycloak_callback when email is not verified.""" + # Arrange + mock_verify_email = AsyncMock() + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.email.verify_email', mock_verify_email), + ): + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'identity_provider': 'github', + 'email_verified': False, + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_verifier.is_active.return_value = False + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'email_verification_required=true' in result.headers['location'] + mock_verify_email.assert_called_once_with( + request=mock_request, user_id='test_user_id', is_auth_flow=True + ) + + +@pytest.mark.asyncio +async def test_keycloak_callback_email_not_verified_missing_field(mock_request): + """Test keycloak_callback when email_verified field is missing (defaults to False).""" + # Arrange + mock_verify_email = AsyncMock() + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.email.verify_email', mock_verify_email), + ): + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'identity_provider': 'github', + # email_verified field is missing + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_verifier.is_active.return_value = False + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'email_verification_required=true' in result.headers['location'] + mock_verify_email.assert_called_once_with( + request=mock_request, user_id='test_user_id', is_auth_flow=True + ) + + @pytest.mark.asyncio async def test_keycloak_callback_success_without_offline_token(mock_request): """Test successful keycloak_callback without valid offline token.""" @@ -248,6 +326,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request): 'sub': 'test_user_id', 'preferred_username': 'test_user', 'identity_provider': 'github', + 'email_verified': True, } ) mock_token_manager.store_idp_tokens = AsyncMock() @@ -513,6 +592,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request): 'preferred_username': 'test_user', 'email': 'user@example.com', 'identity_provider': 'github', + 'email_verified': True, } ) mock_token_manager.store_idp_tokens = AsyncMock() @@ -566,6 +646,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request): 'preferred_username': 'test_user', 'email': 'user@colsch.us', 'identity_provider': 'github', + 'email_verified': True, } ) mock_token_manager.store_idp_tokens = AsyncMock() @@ -615,6 +696,7 @@ async def test_keycloak_callback_missing_email(mock_request): 'sub': 'test_user_id', 'preferred_username': 'test_user', 'identity_provider': 'github', + 'email_verified': True, # No email field } ) @@ -635,3 +717,222 @@ async def test_keycloak_callback_missing_email(mock_request): assert isinstance(result, RedirectResponse) mock_domain_blocker.is_domain_blocked.assert_not_called() mock_token_manager.disable_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_email_detected(mock_request): + """Test keycloak_callback when duplicate email is detected.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + ): + # Arrange + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True) + mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True) + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'duplicated_email=true' in result.headers['location'] + mock_token_manager.check_duplicate_base_email.assert_called_once_with( + 'joe+test@example.com', 'test_user_id' + ) + mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id') + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request): + """Test keycloak_callback when duplicate is detected but deletion fails.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + ): + # Arrange + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True) + mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False) + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'duplicated_email=true' in result.headers['location'] + mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id') + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_check_exception(mock_request): + """Test keycloak_callback when duplicate check raises exception.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + 'email_verified': True, + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock( + side_effect=Exception('Check failed') + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + # Should proceed with normal flow despite exception (fail open) + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + +@pytest.mark.asyncio +async def test_keycloak_callback_no_duplicate_email(mock_request): + """Test keycloak_callback when no duplicate email is found.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + 'email_verified': True, + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + mock_token_manager.check_duplicate_base_email.assert_called_once_with( + 'joe+test@example.com', 'test_user_id' + ) + # Should not delete user when no duplicate found + mock_token_manager.delete_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_no_email_in_user_info(mock_request): + """Test keycloak_callback when email is not in user_info.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + # No email field + 'identity_provider': 'github', + 'email_verified': True, + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + # Should not check for duplicate when email is missing + mock_token_manager.check_duplicate_base_email.assert_not_called() diff --git a/enterprise/tests/unit/test_email_validation.py b/enterprise/tests/unit/test_email_validation.py new file mode 100644 index 0000000000..320c5d4699 --- /dev/null +++ b/enterprise/tests/unit/test_email_validation.py @@ -0,0 +1,294 @@ +"""Tests for email validation utilities.""" + +import re + +from server.auth.email_validation import ( + extract_base_email, + get_base_email_regex_pattern, + has_plus_modifier, + matches_base_email, +) + + +class TestExtractBaseEmail: + """Test cases for extract_base_email function.""" + + def test_extract_base_email_with_plus_modifier(self): + """Test extracting base email from email with + modifier.""" + # Arrange + email = 'joe+test@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_without_plus_modifier(self): + """Test that email without + modifier is returned as-is.""" + # Arrange + email = 'joe@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_multiple_plus_signs(self): + """Test extracting base email when multiple + signs exist.""" + # Arrange + email = 'joe+openhands+test@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_invalid_no_at_symbol(self): + """Test that invalid email without @ returns None.""" + # Arrange + email = 'invalid-email' + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + def test_extract_base_email_empty_string(self): + """Test that empty string returns None.""" + # Arrange + email = '' + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + def test_extract_base_email_none(self): + """Test that None input returns None.""" + # Arrange + email = None + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + +class TestHasPlusModifier: + """Test cases for has_plus_modifier function.""" + + def test_has_plus_modifier_true(self): + """Test detecting + modifier in email.""" + # Arrange + email = 'joe+test@example.com' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is True + + def test_has_plus_modifier_false(self): + """Test that email without + modifier returns False.""" + # Arrange + email = 'joe@example.com' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + def test_has_plus_modifier_invalid_no_at_symbol(self): + """Test that invalid email without @ returns False.""" + # Arrange + email = 'invalid-email' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + def test_has_plus_modifier_empty_string(self): + """Test that empty string returns False.""" + # Arrange + email = '' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + +class TestMatchesBaseEmail: + """Test cases for matches_base_email function.""" + + def test_matches_base_email_exact_match(self): + """Test that exact base email matches.""" + # Arrange + email = 'joe@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_with_plus_variant(self): + """Test that email with + variant matches base email.""" + # Arrange + email = 'joe+test@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_different_base(self): + """Test that different base emails do not match.""" + # Arrange + email = 'jane@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + def test_matches_base_email_different_domain(self): + """Test that same local part but different domain does not match.""" + # Arrange + email = 'joe@other.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + def test_matches_base_email_case_insensitive(self): + """Test that matching is case-insensitive.""" + # Arrange + email = 'JOE+TEST@EXAMPLE.COM' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_empty_strings(self): + """Test that empty strings return False.""" + # Arrange + email = '' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + +class TestGetBaseEmailRegexPattern: + """Test cases for get_base_email_regex_pattern function.""" + + def test_get_base_email_regex_pattern_valid(self): + """Test generating valid regex pattern for base email.""" + # Arrange + base_email = 'joe@example.com' + + # Act + pattern = get_base_email_regex_pattern(base_email) + + # Assert + assert pattern is not None + assert isinstance(pattern, re.Pattern) + assert pattern.match('joe@example.com') is not None + assert pattern.match('joe+test@example.com') is not None + assert pattern.match('joe+openhands@example.com') is not None + + def test_get_base_email_regex_pattern_matches_plus_variant(self): + """Test that regex pattern matches + variant.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('joe+test@example.com') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_rejects_different_base(self): + """Test that regex pattern rejects different base email.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('jane@example.com') + + # Assert + assert match is None + + def test_get_base_email_regex_pattern_rejects_different_domain(self): + """Test that regex pattern rejects different domain.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('joe@other.com') + + # Assert + assert match is None + + def test_get_base_email_regex_pattern_case_insensitive(self): + """Test that regex pattern is case-insensitive.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('JOE+TEST@EXAMPLE.COM') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_special_characters(self): + """Test that regex pattern handles special characters in email.""" + # Arrange + base_email = 'user.name+tag@example-site.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('user.name+test@example-site.com') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_invalid_base_email(self): + """Test that invalid base email returns None.""" + # Arrange + base_email = 'invalid-email' + + # Act + pattern = get_base_email_regex_pattern(base_email) + + # Assert + assert pattern is None diff --git a/enterprise/tests/unit/test_token_manager.py b/enterprise/tests/unit/test_token_manager.py index 413962d60c..0498ff1cb5 100644 --- a/enterprise/tests/unit/test_token_manager.py +++ b/enterprise/tests/unit/test_token_manager.py @@ -1,6 +1,8 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from keycloak.exceptions import KeycloakConnectionError, KeycloakError +from server.auth.token_manager import TokenManager from sqlalchemy.orm import Session from storage.offline_token_store import OfflineTokenStore from storage.stored_offline_token import StoredOfflineToken @@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config): return OfflineTokenStore('test_user_id', mock_session_maker, mock_config) +@pytest.fixture +def token_manager(): + with patch('server.config.get_config') as mock_get_config: + mock_config = mock_get_config.return_value + mock_config.jwt_secret.get_secret_value.return_value = 'test_secret' + return TokenManager(external=False) + + @pytest.mark.asyncio async def test_store_token_new_record(token_store, mock_session): # Setup @@ -109,3 +119,419 @@ async def test_get_instance(mock_config): assert isinstance(result, OfflineTokenStore) assert result.user_id == test_user_id assert result.config == mock_config + + +class TestCheckDuplicateBaseEmail: + """Test cases for check_duplicate_base_email method.""" + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager): + """Test that emails without + modifier are still checked for duplicates.""" + # Arrange + email = 'joe@example.com' + current_user_id = 'user123' + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = False + mock_query.return_value = {} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + mock_query.assert_called_once() + mock_find.assert_called_once() + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_empty_email(self, token_manager): + """Test that empty email returns False.""" + # Arrange + email = '' + current_user_id = 'user123' + + # Act + result = await token_manager.check_duplicate_base_email(email, current_user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_invalid_email(self, token_manager): + """Test that invalid email returns False.""" + # Arrange + email = 'invalid-email' + current_user_id = 'user123' + + # Act + result = await token_manager.check_duplicate_base_email(email, current_user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_duplicate_found(self, token_manager): + """Test that duplicate email is detected when found.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + existing_user = { + 'id': 'existing_user_id', + 'email': 'joe@example.com', + } + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = True + mock_query.return_value = {'existing_user_id': existing_user} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is True + mock_query.assert_called_once() + mock_find.assert_called_once() + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_no_duplicate(self, token_manager): + """Test that no duplicate is found when none exists.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = False + mock_query.return_value = {} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_keycloak_connection_error( + self, token_manager + ): + """Test that KeycloakConnectionError triggers retry and raises RetryError.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query: + mock_query.side_effect = KeycloakConnectionError('Connection failed') + + # Act & Assert + # KeycloakConnectionError is re-raised, which triggers retry decorator + # After retries exhaust (2 attempts), it raises RetryError + from tenacity import RetryError + + with pytest.raises(RetryError): + await token_manager.check_duplicate_base_email(email, current_user_id) + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_general_exception(self, token_manager): + """Test that general exceptions are handled gracefully.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query: + mock_query.side_effect = Exception('Unexpected error') + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + + +class TestQueryUsersByWildcardPattern: + """Test cases for _query_users_by_wildcard_pattern method.""" + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_success_with_search( + self, token_manager + ): + """Test successful query using search parameter.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + mock_users = [ + {'id': 'user1', 'email': 'joe@example.com'}, + {'id': 'user2', 'email': 'joe+test@example.com'}, + ] + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_users = AsyncMock(return_value=mock_users) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert len(result) == 2 + assert 'user1' in result + assert 'user2' in result + mock_admin.a_get_users.assert_called_once_with( + {'search': 'joe*@example.com'} + ) + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager): + """Test fallback to q parameter when search fails.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + mock_users = [{'id': 'user1', 'email': 'joe@example.com'}] + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + # First call fails, second succeeds + mock_admin.a_get_users = AsyncMock( + side_effect=[Exception('Search failed'), mock_users] + ) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert len(result) == 1 + assert 'user1' in result + assert mock_admin.a_get_users.call_count == 2 + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager): + """Test query returns empty dict when no users found.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_users = AsyncMock(return_value=[]) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert result == {} + + +class TestFindDuplicateInUsers: + """Test cases for _find_duplicate_in_users method.""" + + def test_find_duplicate_in_users_with_regex_match(self, token_manager): + """Test finding duplicate using regex pattern.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + 'user2': {'id': 'user2', 'email': 'joe+test@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user3' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is True + + def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager): + """Test fallback to simple matching when regex pattern is None.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + } + base_email = 'invalid-email' # Will cause regex pattern to be None + current_user_id = 'user2' + + with patch( + 'server.auth.token_manager.get_base_email_regex_pattern', return_value=None + ): + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + # Should use fallback matching, but invalid base_email won't match + assert result is False + + def test_find_duplicate_in_users_excludes_current_user(self, token_manager): + """Test that current user is excluded from duplicate check.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user1' # Same as user in users dict + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + def test_find_duplicate_in_users_no_match(self, token_manager): + """Test that no duplicate is found when emails don't match.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'jane@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user2' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + def test_find_duplicate_in_users_empty_dict(self, token_manager): + """Test that empty users dict returns False.""" + # Arrange + users: dict[str, dict] = {} + base_email = 'joe@example.com' + current_user_id = 'user1' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + +class TestDeleteKeycloakUser: + """Test cases for delete_keycloak_user method.""" + + @pytest.mark.asyncio + async def test_delete_keycloak_user_success(self, token_manager): + """Test successful deletion of Keycloak user.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.return_value = None + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is True + mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id) + + @pytest.mark.asyncio + async def test_delete_keycloak_user_connection_error(self, token_manager): + """Test handling of KeycloakConnectionError triggers retry and raises RetryError.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = KeycloakConnectionError('Connection failed') + + # Act & Assert + # KeycloakConnectionError triggers retry decorator + # After retries exhaust (2 attempts), it raises RetryError + from tenacity import RetryError + + with pytest.raises(RetryError): + await token_manager.delete_keycloak_user(user_id) + + @pytest.mark.asyncio + async def test_delete_keycloak_user_keycloak_error(self, token_manager): + """Test handling of KeycloakError (e.g., user not found).""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = KeycloakError('User not found') + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_delete_keycloak_user_general_exception(self, token_manager): + """Test handling of general exceptions.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = Exception('Unexpected error') + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is False diff --git a/frontend/__tests__/MSW.md b/frontend/__tests__/MSW.md new file mode 100644 index 0000000000..f240c5a8df --- /dev/null +++ b/frontend/__tests__/MSW.md @@ -0,0 +1,146 @@ +# Mock Service Worker (MSW) Guide + +## Overview + +[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code. + +We use MSW in this project for: +- **Testing**: Write reliable unit and integration tests without real network calls +- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs + +The same mock handlers work in both environments, so you write them once and reuse everywhere. + +## Relevant Files + +- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers +- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.) +- `src/mocks/browser.ts` - Browser setup for development mode +- `src/mocks/node.ts` - Node.js setup for tests +- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks + +## Development Workflow + +### Running with Mocked APIs + +```sh +# Run with API mocking enabled +npm run dev:mock + +# Run with API mocking + SaaS mode simulation +npm run dev:mock:saas +``` + +These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests. + +> [!NOTE] +> **OSS vs SaaS Mode** +> +> OpenHands runs in two modes: +> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually +> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration +> +> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows. + + +## Writing Tests + +### Service Layer Mocking (Recommended) + +For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear. + +```typescript +import { vi } from "vitest"; +import SettingsService from "#/api/settings-service/settings-service.api"; + +const getSettingsSpy = vi.spyOn(SettingsService, "getSettings"); +getSettingsSpy.mockResolvedValue({ + llm_model: "openai/gpt-4o", + llm_api_key_set: true, + // ... other settings +}); +``` + +Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios: + +```typescript +getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings")); +``` + +### Network Layer Mocking (Advanced) + +For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test. + +> [!IMPORTANT] +> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios. + +```typescript +import { http, HttpResponse } from "msw"; +import { server } from "#/mocks/node"; + +it("should handle server errors", async () => { + server.use( + http.get("/api/my-endpoint", () => { + return new HttpResponse(null, { status: 500 }); + }), + ); + // ... test code +}); +``` + +For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities. + +## Adding New API Mocks + +When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend: + +### 1. Add to `src/mocks/` (for development) + +Create or update a domain-specific handler file: + +```typescript +// src/mocks/my-feature-handlers.ts +import { http, HttpResponse } from "msw"; + +export const MY_FEATURE_HANDLERS = [ + http.get("/api/my-feature", () => { + return HttpResponse.json({ + data: "mock response", + }); + }), +]; +``` + +Register in `handlers.ts`: + +```typescript +import { MY_FEATURE_HANDLERS } from "./my-feature-handlers"; + +export const handlers = [ + // ... existing handlers + ...MY_FEATURE_HANDLERS, +]; +``` + +### 2. Mock in tests for specific scenarios + +In your test files, spy on the service method to control responses per test case: + +```typescript +import { vi } from "vitest"; +import MyFeatureService from "#/api/my-feature-service.api"; + +const spy = vi.spyOn(MyFeatureService, "getData"); +spy.mockResolvedValue({ data: "test-specific response" }); +``` + +See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking. + +> [!TIP] +> For guidance on creating service APIs, see `src/api/README.md`. + +## Best Practices + +- **Keep mocks close to real API contracts** - Update mocks when backend changes +- **Use service layer mocking for most tests** - It's simpler and more explicit +- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc. +- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`) diff --git a/frontend/__tests__/components/features/auth-modal.test.tsx b/frontend/__tests__/components/features/auth-modal.test.tsx index 32b682d506..30550f7106 100644 --- a/frontend/__tests__/components/features/auth-modal.test.tsx +++ b/frontend/__tests__/components/features/auth-modal.test.tsx @@ -1,6 +1,7 @@ import { render, screen } from "@testing-library/react"; import { it, describe, expect, vi, beforeEach, afterEach } from "vitest"; import userEvent from "@testing-library/user-event"; +import { MemoryRouter } from "react-router"; import { AuthModal } from "#/components/features/waitlist/auth-modal"; // Mock the useAuthUrl hook @@ -27,11 +28,13 @@ describe("AuthModal", () => { it("should render the GitHub and GitLab buttons", () => { render( - , + + + , ); const githubButton = screen.getByRole("button", { @@ -49,11 +52,13 @@ describe("AuthModal", () => { const user = userEvent.setup(); const mockUrl = "https://github.com/login/oauth/authorize"; render( - , + + + , ); const githubButton = screen.getByRole("button", { @@ -65,10 +70,14 @@ describe("AuthModal", () => { }); it("should render Terms of Service and Privacy Policy text with correct links", () => { - render(); + render( + + + , + ); // Find the terms of service section using data-testid - const termsSection = screen.getByTestId("auth-modal-terms-of-service"); + const termsSection = screen.getByTestId("terms-and-privacy-notice"); expect(termsSection).toBeInTheDocument(); // Check that all text content is present in the paragraph @@ -105,8 +114,44 @@ describe("AuthModal", () => { expect(termsSection).toContainElement(privacyLink); }); + it("should display email verified message when emailVerified prop is true", () => { + render( + + + , + ); + + expect( + screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"), + ).toBeInTheDocument(); + }); + + it("should not display email verified message when emailVerified prop is false", () => { + render( + + + , + ); + + expect( + screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"), + ).not.toBeInTheDocument(); + }); + it("should open Terms of Service link in new tab", () => { - render(); + render( + + + , + ); const tosLink = screen.getByRole("link", { name: "COMMON$TERMS_OF_SERVICE", @@ -115,11 +160,58 @@ describe("AuthModal", () => { }); it("should open Privacy Policy link in new tab", () => { - render(); + render( + + + , + ); const privacyLink = screen.getByRole("link", { name: "COMMON$PRIVACY_POLICY", }); expect(privacyLink).toHaveAttribute("target", "_blank"); }); + + describe("Duplicate email error message", () => { + const renderAuthModalWithRouter = (initialEntries: string[]) => { + const hasDuplicatedEmail = initialEntries.includes( + "/?duplicated_email=true", + ); + + return render( + + + , + ); + }; + + it("should display error message when duplicated_email query parameter is true", () => { + // Arrange + const initialEntries = ["/?duplicated_email=true"]; + + // Act + renderAuthModalWithRouter(initialEntries); + + // Assert + const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR"); + expect(errorMessage).toBeInTheDocument(); + }); + + it("should not display error message when duplicated_email query parameter is missing", () => { + // Arrange + const initialEntries = ["/"]; + + // Act + renderAuthModalWithRouter(initialEntries); + + // Assert + const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR"); + expect(errorMessage).not.toBeInTheDocument(); + }); + }); }); diff --git a/frontend/__tests__/components/features/waitlist/email-verification-modal.test.tsx b/frontend/__tests__/components/features/waitlist/email-verification-modal.test.tsx new file mode 100644 index 0000000000..e773461d84 --- /dev/null +++ b/frontend/__tests__/components/features/waitlist/email-verification-modal.test.tsx @@ -0,0 +1,28 @@ +import { render, screen } from "@testing-library/react"; +import { it, describe, expect, vi, beforeEach } from "vitest"; +import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal"; + +describe("EmailVerificationModal", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render the email verification message", () => { + // Arrange & Act + render(); + + // Assert + expect( + screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"), + ).toBeInTheDocument(); + }); + + it("should render the TermsAndPrivacyNotice component", () => { + // Arrange & Act + render(); + + // Assert + const termsSection = screen.getByTestId("terms-and-privacy-notice"); + expect(termsSection).toBeInTheDocument(); + }); +}); diff --git a/frontend/__tests__/components/shared/terms-and-privacy-notice.test.tsx b/frontend/__tests__/components/shared/terms-and-privacy-notice.test.tsx new file mode 100644 index 0000000000..559a7f0df6 --- /dev/null +++ b/frontend/__tests__/components/shared/terms-and-privacy-notice.test.tsx @@ -0,0 +1,48 @@ +import { render, screen } from "@testing-library/react"; +import { it, describe, expect } from "vitest"; +import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice"; + +describe("TermsAndPrivacyNotice", () => { + it("should render Terms of Service and Privacy Policy links", () => { + // Arrange & Act + render(); + + // Assert + const termsSection = screen.getByTestId("terms-and-privacy-notice"); + expect(termsSection).toBeInTheDocument(); + + const tosLink = screen.getByRole("link", { + name: "COMMON$TERMS_OF_SERVICE", + }); + const privacyLink = screen.getByRole("link", { + name: "COMMON$PRIVACY_POLICY", + }); + + expect(tosLink).toBeInTheDocument(); + expect(tosLink).toHaveAttribute("href", "https://www.all-hands.dev/tos"); + expect(tosLink).toHaveAttribute("target", "_blank"); + expect(tosLink).toHaveAttribute("rel", "noopener noreferrer"); + + expect(privacyLink).toBeInTheDocument(); + expect(privacyLink).toHaveAttribute( + "href", + "https://www.all-hands.dev/privacy", + ); + expect(privacyLink).toHaveAttribute("target", "_blank"); + expect(privacyLink).toHaveAttribute("rel", "noopener noreferrer"); + }); + + it("should render all required text content", () => { + // Arrange & Act + render(); + + // Assert + const termsSection = screen.getByTestId("terms-and-privacy-notice"); + expect(termsSection).toHaveTextContent( + "AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR", + ); + expect(termsSection).toHaveTextContent("COMMON$TERMS_OF_SERVICE"); + expect(termsSection).toHaveTextContent("COMMON$AND"); + expect(termsSection).toHaveTextContent("COMMON$PRIVACY_POLICY"); + }); +}); diff --git a/frontend/__tests__/router.md b/frontend/__tests__/router.md new file mode 100644 index 0000000000..b23b4364e7 --- /dev/null +++ b/frontend/__tests__/router.md @@ -0,0 +1,227 @@ +# Testing with React Router + +## Overview + +React Router components and hooks require a routing context to function. In tests, we need to provide this context while maintaining control over the routing state. + +This guide covers the two main approaches used in the OpenHands frontend: + +1. **`createRoutesStub`** - Creates a complete route structure for testing components with their actual route configuration, loaders, and nested routes. +2. **`MemoryRouter`** - Provides a minimal routing context for components that just need router hooks to work. + +Choose your approach based on what your component actually needs from the router. + +## When to Use Each Approach + +### `createRoutesStub` (Recommended) + +Use `createRoutesStub` when your component: +- Relies on route parameters (`useParams`) +- Uses loader data (`useLoaderData`) or `clientLoader` +- Has nested routes or uses `` +- Needs to test navigation between routes + +> [!NOTE] +> `createRoutesStub` is intended for unit testing **reusable components** that depend on router context. For testing full route/page components, consider E2E tests (Playwright, Cypress) instead. + +```typescript +import { createRoutesStub } from "react-router"; +import { render } from "@testing-library/react"; + +const RouterStub = createRoutesStub([ + { + Component: MyRouteComponent, + path: "/conversations/:conversationId", + }, +]); + +render(); +``` + +**With nested routes and loaders:** + +```typescript +const RouterStub = createRoutesStub([ + { + Component: SettingsScreen, + clientLoader, + path: "/settings", + children: [ + { + Component: () =>
, + path: "/settings", + }, + { + Component: () =>
, + path: "/settings/integrations", + }, + ], + }, +]); + +render(); +``` + +> [!TIP] +> When using `clientLoader` from a Route module, you may encounter type mismatches. Use `@ts-expect-error` as a workaround: + +```typescript +import { clientLoader } from "@/routes/settings"; + +const RouterStub = createRoutesStub([ + { + path: "/settings", + Component: SettingsScreen, + // @ts-expect-error: loader types won't align between test and app code + loader: clientLoader, + }, +]); +``` + +### `MemoryRouter` + +Use `MemoryRouter` when your component: +- Only needs basic routing context to render +- Uses `` components but you don't need to test navigation +- Doesn't depend on specific route parameters or loaders + +```typescript +import { MemoryRouter } from "react-router"; +import { render } from "@testing-library/react"; + +render( + + + +); +``` + +**With initial route:** + +```typescript +render( + + + +); +``` + +## Anti-patterns to Avoid + +### Using `BrowserRouter` in tests + +`BrowserRouter` interacts with the actual browser history API, which can cause issues in test environments: + +```typescript +// ❌ Avoid +render( + + + +); + +// ✅ Use MemoryRouter instead +render( + + + +); +``` + +### Mocking router hooks when `createRoutesStub` would work + +Mocking hooks like `useParams` directly can be brittle and doesn't test the actual routing behavior: + +```typescript +// ❌ Avoid when possible +vi.mock("react-router", async () => { + const actual = await vi.importActual("react-router"); + return { + ...actual, + useParams: () => ({ conversationId: "123" }), + }; +}); + +// ✅ Prefer createRoutesStub - tests real routing behavior +const RouterStub = createRoutesStub([ + { + Component: MyComponent, + path: "/conversations/:conversationId", + }, +]); + +render(); +``` + +## Common Patterns + +### Combining with `QueryClientProvider` + +Many components need both routing and TanStack Query context: + +```typescript +import { createRoutesStub } from "react-router"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; + +const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + }, +}); + +const RouterStub = createRoutesStub([ + { + Component: MyComponent, + path: "/", + }, +]); + +render(, { + wrapper: ({ children }) => ( + + {children} + + ), +}); +``` + +### Testing navigation behavior + +Verify that user interactions trigger the expected navigation: + +```typescript +import { createRoutesStub } from "react-router"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; + +const RouterStub = createRoutesStub([ + { + Component: HomeScreen, + path: "/", + }, + { + Component: () =>
, + path: "/settings", + }, +]); + +render(); + +const user = userEvent.setup(); +await user.click(screen.getByRole("link", { name: /settings/i })); + +expect(screen.getByTestId("settings-screen")).toBeInTheDocument(); +``` + +## See Also + +### Codebase Examples + +- [settings.test.tsx](__tests__/routes/settings.test.tsx) - `createRoutesStub` with nested routes and loaders +- [home-screen.test.tsx](__tests__/routes/home-screen.test.tsx) - `createRoutesStub` with navigation testing +- [chat-interface.test.tsx](__tests__/components/chat/chat-interface.test.tsx) - `MemoryRouter` usage + +### Official Documentation + +- [React Router Testing Guide](https://reactrouter.com/start/framework/testing) - Official guide on testing with `createRoutesStub` +- [MemoryRouter API](https://reactrouter.com/api/declarative-routers/MemoryRouter) - API reference for `MemoryRouter` diff --git a/frontend/__tests__/routes/root-layout.test.tsx b/frontend/__tests__/routes/root-layout.test.tsx new file mode 100644 index 0000000000..22de4ae616 --- /dev/null +++ b/frontend/__tests__/routes/root-layout.test.tsx @@ -0,0 +1,242 @@ +import { render, screen, waitFor } from "@testing-library/react"; +import { it, describe, expect, vi, beforeEach, afterEach } from "vitest"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { createRoutesStub } from "react-router"; +import MainApp from "#/routes/root-layout"; +import OptionService from "#/api/option-service/option-service.api"; +import AuthService from "#/api/auth-service/auth-service.api"; +import SettingsService from "#/api/settings-service/settings-service.api"; + +// Mock other hooks that are not the focus of these tests +vi.mock("#/hooks/use-github-auth-url", () => ({ + useGitHubAuthUrl: () => "https://github.com/oauth/authorize", +})); + +vi.mock("#/hooks/use-is-on-tos-page", () => ({ + useIsOnTosPage: () => false, +})); + +vi.mock("#/hooks/use-auto-login", () => ({ + useAutoLogin: () => {}, +})); + +vi.mock("#/hooks/use-auth-callback", () => ({ + useAuthCallback: () => {}, +})); + +vi.mock("#/hooks/use-migrate-user-consent", () => ({ + useMigrateUserConsent: () => ({ + migrateUserConsent: vi.fn(), + }), +})); + +vi.mock("#/hooks/use-reo-tracking", () => ({ + useReoTracking: () => {}, +})); + +vi.mock("#/hooks/use-sync-posthog-consent", () => ({ + useSyncPostHogConsent: () => {}, +})); + +vi.mock("#/utils/custom-toast-handlers", () => ({ + displaySuccessToast: vi.fn(), +})); + +const RouterStub = createRoutesStub([ + { + Component: MainApp, + path: "/", + children: [ + { + Component: () =>
Content
, + path: "/", + }, + ], + }, +]); + +const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }); + + return ({ children }: { children: React.ReactNode }) => ( + {children} + ); +}; + +describe("MainApp - Email Verification Flow", () => { + beforeEach(() => { + vi.clearAllMocks(); + + // Default mocks for services + vi.spyOn(OptionService, "getConfig").mockResolvedValue({ + APP_MODE: "saas", + GITHUB_CLIENT_ID: "test-client-id", + POSTHOG_CLIENT_KEY: "test-posthog-key", + PROVIDERS_CONFIGURED: ["github"], + AUTH_URL: "https://auth.example.com", + FEATURE_FLAGS: { + ENABLE_BILLING: false, + HIDE_LLM_SETTINGS: false, + ENABLE_JIRA: false, + ENABLE_JIRA_DC: false, + ENABLE_LINEAR: false, + }, + }); + + vi.spyOn(AuthService, "authenticate").mockResolvedValue(true); + + vi.spyOn(SettingsService, "getSettings").mockResolvedValue({ + language: "en", + user_consents_to_analytics: true, + llm_model: "", + llm_base_url: "", + agent: "", + llm_api_key: null, + llm_api_key_set: false, + search_api_key_set: false, + confirmation_mode: false, + security_analyzer: null, + remote_runtime_resource_factor: null, + provider_tokens_set: {}, + enable_default_condenser: false, + condenser_max_size: null, + enable_sound_notifications: false, + enable_proactive_conversation_starters: false, + enable_solvability_analysis: false, + max_budget_per_task: null, + }); + + // Mock localStorage + vi.stubGlobal("localStorage", { + getItem: vi.fn(() => null), + setItem: vi.fn(), + removeItem: vi.fn(), + clear: vi.fn(), + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.unstubAllGlobals(); + }); + + it("should display EmailVerificationModal when email_verification_required=true is in query params", async () => { + // Arrange & Act + render( + , + { wrapper: createWrapper() }, + ); + + // Assert + await waitFor(() => { + expect( + screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"), + ).toBeInTheDocument(); + }); + }); + + it("should set emailVerified state and pass to AuthModal when email_verified=true is in query params", async () => { + // Arrange + // Mock a 401 error to simulate unauthenticated user + const axiosError = { + response: { status: 401 }, + isAxiosError: true, + }; + vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError); + + // Act + render(, { + wrapper: createWrapper(), + }); + + // Assert - Wait for AuthModal to render (since user is not authenticated) + await waitFor(() => { + expect( + screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"), + ).toBeInTheDocument(); + }); + }); + + it("should handle both email_verification_required and email_verified params together", async () => { + // Arrange & Act + render( + , + { wrapper: createWrapper() }, + ); + + // Assert - EmailVerificationModal should take precedence + await waitFor(() => { + expect( + screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"), + ).toBeInTheDocument(); + }); + }); + + it("should remove query parameters from URL after processing", async () => { + // Arrange & Act + const { container } = render( + , + { wrapper: createWrapper() }, + ); + + // Assert - Wait for the modal to appear (which indicates processing happened) + await waitFor(() => { + expect( + screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"), + ).toBeInTheDocument(); + }); + + // Verify that the query parameter was processed by checking the modal appeared + // The hook removes the parameter from the URL, so we verify the behavior indirectly + expect(container).toBeInTheDocument(); + }); + + it("should not display EmailVerificationModal when email_verification_required is not in query params", async () => { + // Arrange - No query params set + + // Act + render(, { wrapper: createWrapper() }); + + // Assert + await waitFor(() => { + expect( + screen.queryByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"), + ).not.toBeInTheDocument(); + }); + }); + + it("should not display email verified message when email_verified is not in query params", async () => { + // Arrange + // Mock a 401 error to simulate unauthenticated user + const axiosError = { + response: { status: 401 }, + isAxiosError: true, + }; + vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError); + + // Act + render(, { wrapper: createWrapper() }); + + // Assert - AuthModal should render but without email verified message + await waitFor(() => { + const authModal = screen.queryByText( + "AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER", + ); + if (authModal) { + expect( + screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"), + ).not.toBeInTheDocument(); + } + }); + }); +}); diff --git a/frontend/src/api/README.md b/frontend/src/api/README.md new file mode 100644 index 0000000000..a44c7c49ca --- /dev/null +++ b/frontend/src/api/README.md @@ -0,0 +1,102 @@ +# API Services Guide + +## Overview + +Services are the abstraction layer between frontend components and backend APIs. They encapsulate HTTP requests using the shared `openHands` axios instance and provide typed methods for each endpoint. + +Each service is a plain object with async methods. + +## Structure + +Each service lives in its own directory: + +``` +src/api/ +├── billing-service/ +│ ├── billing-service.api.ts # Service methods +│ └── billing.types.ts # Types and interfaces +├── organization-service/ +│ ├── organization-service.api.ts +│ └── organization.types.ts +└── open-hands-axios.ts # Shared axios instance +``` + +## Creating a Service + +Use an object literal with named export. Use object destructuring for parameters to make calls self-documenting. + +```typescript +// feature-service/feature-service.api.ts +import { openHands } from "../open-hands-axios"; +import { Feature, CreateFeatureParams } from "./feature.types"; + +export const featureService = { + getFeature: async ({ id }: { id: string }) => { + const { data } = await openHands.get(`/api/features/${id}`); + return data; + }, + + createFeature: async ({ name, description }: CreateFeatureParams) => { + const { data } = await openHands.post("/api/features", { + name, + description, + }); + return data; + }, +}; +``` + +### Types + +Define types in a separate file within the same directory: + +```typescript +// feature-service/feature.types.ts +export interface Feature { + id: string; + name: string; + description: string; +} + +export interface CreateFeatureParams { + name: string; + description: string; +} +``` + +## Usage + +> [!IMPORTANT] +> **Don't call services directly in components.** Wrap them in TanStack Query hooks. +> +> Why? TanStack Query provides: +> - **Caching** - Avoid redundant network requests +> - **Deduplication** - Multiple components requesting the same data share one request +> - **Loading/error states** - Built-in `isLoading`, `isError`, `data` states +> - **Background refetching** - Data stays fresh automatically +> +> Hooks location: +> - `src/hooks/query/` for data fetching (`useQuery`) +> - `src/hooks/mutation/` for writes/updates (`useMutation`) + +```typescript +// src/hooks/query/use-feature.ts +import { useQuery } from "@tanstack/react-query"; +import { featureService } from "#/api/feature-service/feature-service.api"; + +export const useFeature = (id: string) => { + return useQuery({ + queryKey: ["feature", id], + queryFn: () => featureService.getFeature({ id }), + }); +}; +``` + +## Naming Conventions + +| Item | Convention | Example | +|------|------------|---------| +| Directory | `feature-service/` | `billing-service/` | +| Service file | `feature-service.api.ts` | `billing-service.api.ts` | +| Types file | `feature.types.ts` | `billing.types.ts` | +| Export name | `featureService` | `billingService` | diff --git a/frontend/src/components/features/waitlist/auth-modal.tsx b/frontend/src/components/features/waitlist/auth-modal.tsx index 2c431fbd95..6d92cb4dfc 100644 --- a/frontend/src/components/features/waitlist/auth-modal.tsx +++ b/frontend/src/components/features/waitlist/auth-modal.tsx @@ -13,12 +13,15 @@ import { useAuthUrl } from "#/hooks/use-auth-url"; import { GetConfigResponse } from "#/api/option-service/option.types"; import { Provider } from "#/types/settings"; import { useTracking } from "#/hooks/use-tracking"; +import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice"; interface AuthModalProps { githubAuthUrl: string | null; appMode?: GetConfigResponse["APP_MODE"] | null; authUrl?: GetConfigResponse["AUTH_URL"]; providersConfigured?: Provider[]; + emailVerified?: boolean; + hasDuplicatedEmail?: boolean; } export function AuthModal({ @@ -26,6 +29,8 @@ export function AuthModal({ appMode, authUrl, providersConfigured, + emailVerified = false, + hasDuplicatedEmail = false, }: AuthModalProps) { const { t } = useTranslation(); const { trackLoginButtonClick } = useTracking(); @@ -123,6 +128,18 @@ export function AuthModal({ + {emailVerified && ( +
+

+ {t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)} +

+
+ )} + {hasDuplicatedEmail && ( +
+ {t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)} +
+ )}

{t(I18nKey.AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER)} @@ -198,30 +215,7 @@ export function AuthModal({ )}

-

- {t(I18nKey.AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR)}{" "} - - {t(I18nKey.COMMON$TERMS_OF_SERVICE)} - {" "} - {t(I18nKey.COMMON$AND)}{" "} - - {t(I18nKey.COMMON$PRIVACY_POLICY)} - - . -

+
); diff --git a/frontend/src/components/features/waitlist/email-verification-modal.tsx b/frontend/src/components/features/waitlist/email-verification-modal.tsx new file mode 100644 index 0000000000..820dce3258 --- /dev/null +++ b/frontend/src/components/features/waitlist/email-verification-modal.tsx @@ -0,0 +1,31 @@ +import { useTranslation } from "react-i18next"; +import { I18nKey } from "#/i18n/declaration"; +import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react"; +import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; +import { ModalBody } from "#/components/shared/modals/modal-body"; +import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice"; + +interface EmailVerificationModalProps { + onClose: () => void; +} + +export function EmailVerificationModal({ + onClose, +}: EmailVerificationModalProps) { + const { t } = useTranslation(); + + return ( + + + +
+

+ {t(I18nKey.AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY)} +

+
+ + +
+
+ ); +} diff --git a/frontend/src/components/shared/terms-and-privacy-notice.tsx b/frontend/src/components/shared/terms-and-privacy-notice.tsx new file mode 100644 index 0000000000..8293d734da --- /dev/null +++ b/frontend/src/components/shared/terms-and-privacy-notice.tsx @@ -0,0 +1,37 @@ +import React from "react"; +import { useTranslation } from "react-i18next"; +import { I18nKey } from "#/i18n/declaration"; + +interface TermsAndPrivacyNoticeProps { + className?: string; +} + +export function TermsAndPrivacyNotice({ + className = "mt-4 text-xs text-center text-muted-foreground", +}: TermsAndPrivacyNoticeProps) { + const { t } = useTranslation(); + + return ( +

+ {t(I18nKey.AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR)}{" "} + + {t(I18nKey.COMMON$TERMS_OF_SERVICE)} + {" "} + {t(I18nKey.COMMON$AND)}{" "} + + {t(I18nKey.COMMON$PRIVACY_POLICY)} + + . +

+ ); +} diff --git a/frontend/src/hooks/use-email-verification.ts b/frontend/src/hooks/use-email-verification.ts new file mode 100644 index 0000000000..c0068395b5 --- /dev/null +++ b/frontend/src/hooks/use-email-verification.ts @@ -0,0 +1,63 @@ +import React from "react"; +import { useSearchParams } from "react-router"; + +/** + * Hook to handle email verification logic from URL query parameters. + * Manages the email verification modal state and email verified state + * based on query parameters in the URL. + * + * @returns An object containing: + * - emailVerificationModalOpen: boolean state for modal visibility + * - setEmailVerificationModalOpen: function to control modal visibility + * - emailVerified: boolean state for email verification status + * - setEmailVerified: function to control email verification status + * - hasDuplicatedEmail: boolean state for duplicate email error status + */ +export function useEmailVerification() { + const [searchParams, setSearchParams] = useSearchParams(); + const [emailVerificationModalOpen, setEmailVerificationModalOpen] = + React.useState(false); + const [emailVerified, setEmailVerified] = React.useState(false); + const [hasDuplicatedEmail, setHasDuplicatedEmail] = React.useState(false); + + // Check for email verification query parameters + React.useEffect(() => { + const emailVerificationRequired = searchParams.get( + "email_verification_required", + ); + const emailVerifiedParam = searchParams.get("email_verified"); + const duplicatedEmailParam = searchParams.get("duplicated_email"); + let shouldUpdate = false; + + if (emailVerificationRequired === "true") { + setEmailVerificationModalOpen(true); + searchParams.delete("email_verification_required"); + shouldUpdate = true; + } + + if (emailVerifiedParam === "true") { + setEmailVerified(true); + searchParams.delete("email_verified"); + shouldUpdate = true; + } + + if (duplicatedEmailParam === "true") { + setHasDuplicatedEmail(true); + searchParams.delete("duplicated_email"); + shouldUpdate = true; + } + + // Clean up the URL by removing parameters if any were found + if (shouldUpdate) { + setSearchParams(searchParams, { replace: true }); + } + }, [searchParams, setSearchParams]); + + return { + emailVerificationModalOpen, + setEmailVerificationModalOpen, + emailVerified, + setEmailVerified, + hasDuplicatedEmail, + }; +} diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index 1b330730d9..e3ed93db2f 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -730,6 +730,9 @@ export enum I18nKey { MICROAGENT_MANAGEMENT$USE_MICROAGENTS = "MICROAGENT_MANAGEMENT$USE_MICROAGENTS", AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR = "AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR", AUTH$NO_PROVIDERS_CONFIGURED = "AUTH$NO_PROVIDERS_CONFIGURED", + AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY", + AUTH$EMAIL_VERIFIED_PLEASE_LOGIN = "AUTH$EMAIL_VERIFIED_PLEASE_LOGIN", + AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR", COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE", COMMON$AND = "COMMON$AND", COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index a421de5ddf..81df3b6f7d 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -11679,6 +11679,54 @@ "de": "Mindestens ein Identitätsanbieter muss konfiguriert werden (z.B. GitHub)", "uk": "Принаймні один постачальник ідентифікації має бути налаштований (наприклад, GitHub)" }, + "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY": { + "en": "Please check your email to verify your account.", + "ja": "アカウントを確認するためにメールを確認してください。", + "zh-CN": "请检查您的电子邮件以验证您的账户。", + "zh-TW": "請檢查您的電子郵件以驗證您的帳戶。", + "ko-KR": "계정을 확인하려면 이메일을 확인하세요.", + "no": "Vennligst sjekk e-posten din for å bekrefte kontoen din.", + "it": "Controlla la tua email per verificare il tuo account.", + "pt": "Por favor, verifique seu e-mail para verificar sua conta.", + "es": "Por favor, verifica tu correo electrónico para verificar tu cuenta.", + "ar": "يرجى التحقق من بريدك الإلكتروني للتحقق من حسابك.", + "fr": "Veuillez vérifier votre e-mail pour vérifier votre compte.", + "tr": "Hesabınızı doğrulamak için lütfen e-postanızı kontrol edin.", + "de": "Bitte überprüfen Sie Ihre E-Mail, um Ihr Konto zu verifizieren.", + "uk": "Будь ласка, перевірте свою електронну пошту, щоб підтвердити свій обліковий запис." + }, + "AUTH$EMAIL_VERIFIED_PLEASE_LOGIN": { + "en": "Your email has been verified. Please login below.", + "ja": "メールアドレスが確認されました。下記からログインしてください。", + "zh-CN": "您的电子邮件已验证。请在下方登录。", + "zh-TW": "您的電子郵件已驗證。請在下方登錄。", + "ko-KR": "이메일이 확인되었습니다. 아래에서 로그인하세요.", + "no": "E-posten din er bekreftet. Vennligst logg inn nedenfor.", + "it": "La tua email è stata verificata. Effettua il login qui sotto.", + "pt": "Seu e-mail foi verificado. Por favor, faça login abaixo.", + "es": "Tu correo electrónico ha sido verificado. Por favor, inicia sesión a continuación.", + "ar": "تم التحقق من بريدك الإلكتروني. يرجى تسجيل الدخول أدناه.", + "fr": "Votre e-mail a été vérifié. Veuillez vous connecter ci-dessous.", + "tr": "E-postanız doğrulandı. Lütfen aşağıdan giriş yapın.", + "de": "Ihre E-Mail wurde verifiziert. Bitte melden Sie sich unten an.", + "uk": "Вашу електронну пошту підтверджено. Будь ласка, увійдіть нижче." + }, + "AUTH$DUPLICATE_EMAIL_ERROR": { + "en": "Your account is unable to be created. Please use a different login or try again.", + "ja": "アカウントを作成できません。別のログインを使用するか、もう一度お試しください。", + "zh-CN": "无法创建您的账户。请使用其他登录方式或重试。", + "zh-TW": "無法建立您的帳戶。請使用其他登入方式或重試。", + "ko-KR": "계정을 생성할 수 없습니다. 다른 로그인을 사용하거나 다시 시도해 주세요.", + "no": "Kontoen din kan ikke opprettes. Vennligst bruk en annen innlogging eller prøv igjen.", + "it": "Impossibile creare il tuo account. Utilizza un altro accesso o riprova.", + "pt": "Não foi possível criar sua conta. Use um login diferente ou tente novamente.", + "es": "No se puede crear su cuenta. Utilice un inicio de sesión diferente o inténtelo de nuevo.", + "ar": "لا يمكن إنشاء حسابك. يرجى استخدام تسجيل دخول مختلف أو المحاولة مرة أخرى.", + "fr": "Votre compte ne peut pas être créé. Veuillez utiliser une autre connexion ou réessayer.", + "tr": "Hesabınız oluşturulamadı. Lütfen farklı bir giriş kullanın veya tekrar deneyin.", + "de": "Ihr Konto kann nicht erstellt werden. Bitte verwenden Sie eine andere Anmeldung oder versuchen Sie es erneut.", + "uk": "Ваш обліковий запис не може бути створений. Будь ласка, використовуйте інший спосіб входу або спробуйте ще раз." + }, "COMMON$TERMS_OF_SERVICE": { "en": "Terms of Service", "ja": "利用規約", diff --git a/frontend/src/routes/root-layout.tsx b/frontend/src/routes/root-layout.tsx index 876c4d8c11..73da04ea8f 100644 --- a/frontend/src/routes/root-layout.tsx +++ b/frontend/src/routes/root-layout.tsx @@ -15,6 +15,7 @@ import { useConfig } from "#/hooks/query/use-config"; import { Sidebar } from "#/components/features/sidebar/sidebar"; import { AuthModal } from "#/components/features/waitlist/auth-modal"; import { ReauthModal } from "#/components/features/waitlist/reauth-modal"; +import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal"; import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal"; import { useSettings } from "#/hooks/query/use-settings"; import { useMigrateUserConsent } from "#/hooks/use-migrate-user-consent"; @@ -26,6 +27,7 @@ import { useAutoLogin } from "#/hooks/use-auto-login"; import { useAuthCallback } from "#/hooks/use-auth-callback"; import { useReoTracking } from "#/hooks/use-reo-tracking"; import { useSyncPostHogConsent } from "#/hooks/use-sync-posthog-consent"; +import { useEmailVerification } from "#/hooks/use-email-verification"; import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage"; import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard"; import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner"; @@ -91,6 +93,12 @@ export default function MainApp() { const effectiveGitHubAuthUrl = isOnTosPage ? null : gitHubAuthUrl; const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false); + const { + emailVerificationModalOpen, + setEmailVerificationModalOpen, + emailVerified, + hasDuplicatedEmail, + } = useEmailVerification(); // Auto-login if login method is stored in local storage useAutoLogin(); @@ -236,9 +244,18 @@ export default function MainApp() { appMode={config.data?.APP_MODE} providersConfigured={config.data?.PROVIDERS_CONFIGURED} authUrl={config.data?.AUTH_URL} + emailVerified={emailVerified} + hasDuplicatedEmail={hasDuplicatedEmail} /> )} {renderReAuthModal && } + {emailVerificationModalOpen && ( + { + setEmailVerificationModalOpen(false); + }} + /> + )} {config.data?.APP_MODE === "oss" && consentFormIsOpen && ( { diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index 076c478478..1606fc81ae 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -187,7 +187,7 @@ class RemoteSandboxService(SandboxService): return SandboxStatus.MISSING status = None - pod_status = runtime['pod_status'].lower() + pod_status = (runtime.get('pod_status') or '').lower() if pod_status: status = POD_STATUS_MAPPING.get(pod_status, None) @@ -356,7 +356,7 @@ class RemoteSandboxService(SandboxService): StoredRemoteSandbox.id == runtime.get('session_id') ) result = await self.db_session.execute(query) - sandbox = result.first() + sandbox = result.scalar_one_or_none() if sandbox is None: raise ValueError('sandbox_not_found') return self._to_sandbox_info(sandbox, runtime) diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 150fa54925..c786cfd6a3 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -21,6 +21,7 @@ from litellm import completion as litellm_completion from litellm import completion_cost as litellm_completion_cost from litellm.exceptions import ( APIConnectionError, + BadGatewayError, RateLimitError, ServiceUnavailableError, ) @@ -45,6 +46,7 @@ LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = ( APIConnectionError, RateLimitError, ServiceUnavailableError, + BadGatewayError, litellm.Timeout, litellm.InternalServerError, LLMNoResponseError, diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 2a255e2082..b88c2851e2 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -510,6 +510,10 @@ async def delete_conversation( if v1_result is not None: return v1_result + # Close connections + await db_session.close() + await httpx_client.aclose() + # V0 conversation logic return await _delete_v0_conversation(conversation_id, user_id) @@ -551,11 +555,8 @@ async def _try_delete_v1_conversation( httpx_client, ) ) - except (ValueError, TypeError): - # Not a valid UUID, continue with V0 logic - pass except Exception: - # Some other error, continue with V0 logic + # Continue with V0 logic pass return result diff --git a/tests/unit/llm/test_api_connection_error_retry.py b/tests/unit/llm/test_api_connection_error_retry.py index 8bcf15f986..b88c170079 100644 --- a/tests/unit/llm/test_api_connection_error_retry.py +++ b/tests/unit/llm/test_api_connection_error_retry.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from litellm.exceptions import APIConnectionError +from litellm.exceptions import APIConnectionError, BadGatewayError from openhands.core.config import LLMConfig from openhands.llm.llm import LLM @@ -86,3 +86,25 @@ def test_completion_max_retries_api_connection_error( # The exception doesn't contain retry information in the current implementation # Just verify that we got an APIConnectionError assert 'API connection error' in str(excinfo.value) + + +@patch('openhands.llm.llm.litellm_completion') +def test_completion_retries_bad_gateway_error(mock_litellm_completion, default_config): + """Test that BadGatewayError is properly retried.""" + mock_litellm_completion.side_effect = [ + BadGatewayError( + message='Bad gateway', + llm_provider='test_provider', + model='test_model', + ), + {'choices': [{'message': {'content': 'Retry successful'}}]}, + ] + + llm = LLM(config=default_config, service_id='test-service') + response = llm.completion( + messages=[{'role': 'user', 'content': 'Hello!'}], + stream=False, + ) + + assert response['choices'][0]['message']['content'] == 'Retry successful' + assert mock_litellm_completion.call_count == 2 diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 79ff91fa7f..d5e289ecfa 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -946,6 +946,10 @@ async def test_delete_conversation(): # Create a mock sandbox service mock_sandbox_service = MagicMock() + # Create mock db_session and httpx_client + mock_db_session = AsyncMock() + mock_httpx_client = AsyncMock() + # Mock the conversation manager with patch( 'openhands.server.routes.manage_conversations.conversation_manager' @@ -969,6 +973,8 @@ async def test_delete_conversation(): app_conversation_service=mock_app_conversation_service, app_conversation_info_service=mock_app_conversation_info_service, sandbox_service=mock_sandbox_service, + db_session=mock_db_session, + httpx_client=mock_httpx_client, ) # Verify the result @@ -1090,6 +1096,10 @@ async def test_delete_v1_conversation_not_found(): ) mock_service.delete_app_conversation = AsyncMock(return_value=False) + # Create mock db_session and httpx_client + mock_db_session = AsyncMock() + mock_httpx_client = AsyncMock() + # Call delete_conversation with V1 conversation ID result = await delete_conversation( request=MagicMock(), @@ -1098,6 +1108,8 @@ async def test_delete_v1_conversation_not_found(): app_conversation_service=mock_service, app_conversation_info_service=mock_info_service, sandbox_service=mock_sandbox_service, + db_session=mock_db_session, + httpx_client=mock_httpx_client, ) # Verify the result @@ -1171,6 +1183,10 @@ async def test_delete_v1_conversation_invalid_uuid(): mock_sandbox_service = MagicMock() mock_sandbox_service_dep.return_value = mock_sandbox_service + # Create mock db_session and httpx_client + mock_db_session = AsyncMock() + mock_httpx_client = AsyncMock() + # Call delete_conversation result = await delete_conversation( request=MagicMock(), @@ -1179,6 +1195,8 @@ async def test_delete_v1_conversation_invalid_uuid(): app_conversation_service=mock_service, app_conversation_info_service=mock_info_service, sandbox_service=mock_sandbox_service, + db_session=mock_db_session, + httpx_client=mock_httpx_client, ) # Verify the result @@ -1264,6 +1282,10 @@ async def test_delete_v1_conversation_service_error(): mock_runtime_cls.delete = AsyncMock() mock_get_runtime_cls.return_value = mock_runtime_cls + # Create mock db_session and httpx_client + mock_db_session = AsyncMock() + mock_httpx_client = AsyncMock() + # Call delete_conversation result = await delete_conversation( request=MagicMock(), @@ -1272,6 +1294,8 @@ async def test_delete_v1_conversation_service_error(): app_conversation_service=mock_service, app_conversation_info_service=mock_info_service, sandbox_service=mock_sandbox_service, + db_session=mock_db_session, + httpx_client=mock_httpx_client, ) # Verify the result (should fallback to V0)