diff --git a/enterprise/server/auth/email_validation.py b/enterprise/server/auth/email_validation.py new file mode 100644 index 0000000000..94c6f52c2a --- /dev/null +++ b/enterprise/server/auth/email_validation.py @@ -0,0 +1,109 @@ +"""Email validation utilities for preventing duplicate signups with + modifier.""" + +import re + + +def extract_base_email(email: str) -> str | None: + """Extract base email from an email address. + + For emails with + modifier, extracts the base email (local part before + and @, plus domain). + For emails without + modifier, returns the email as-is. + + Examples: + extract_base_email("joe+test@example.com") -> "joe@example.com" + extract_base_email("joe@example.com") -> "joe@example.com" + extract_base_email("joe+openhands+test@example.com") -> "joe@example.com" + + Args: + email: The email address to process + + Returns: + The base email address, or None if email format is invalid + """ + if not email or '@' not in email: + return None + + try: + local_part, domain = email.rsplit('@', 1) + # Extract the part before + if it exists + base_local = local_part.split('+', 1)[0] + return f'{base_local}@{domain}' + except (ValueError, AttributeError): + return None + + +def has_plus_modifier(email: str) -> bool: + """Check if an email address contains a + modifier. + + Args: + email: The email address to check + + Returns: + True if email contains + before @, False otherwise + """ + if not email or '@' not in email: + return False + + try: + local_part, _ = email.rsplit('@', 1) + return '+' in local_part + except (ValueError, AttributeError): + return False + + +def matches_base_email(email: str, base_email: str) -> bool: + """Check if an email matches a base email pattern. + + An email matches if: + - It is exactly the base email (e.g., joe@example.com) + - It has the same base local part and domain, with or without + modifier + (e.g., joe+test@example.com matches base joe@example.com) + + Args: + email: The email address to check + base_email: The base email to match against + + Returns: + True if email matches the base pattern, False otherwise + """ + if not email or not base_email: + return False + + # Extract base from both emails for comparison + email_base = extract_base_email(email) + base_email_normalized = extract_base_email(base_email) + + if not email_base or not base_email_normalized: + return False + + # Emails match if they have the same base + return email_base.lower() == base_email_normalized.lower() + + +def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None: + """Generate a regex pattern to match emails with the same base. + + For base_email "joe@example.com", the pattern will match: + - joe@example.com + - joe+anything@example.com + + Args: + base_email: The base email address + + Returns: + A compiled regex pattern, or None if base_email is invalid + """ + base = extract_base_email(base_email) + if not base: + return None + + try: + local_part, domain = base.rsplit('@', 1) + # Escape special regex characters in local part and domain + escaped_local = re.escape(local_part) + escaped_domain = re.escape(domain) + # Pattern: joe@example.com OR joe+anything@example.com + pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$' + return re.compile(pattern, re.IGNORECASE) + except (ValueError, AttributeError): + return None diff --git a/enterprise/server/auth/token_manager.py b/enterprise/server/auth/token_manager.py index 04bfae0767..6061518cb4 100644 --- a/enterprise/server/auth/token_manager.py +++ b/enterprise/server/auth/token_manager.py @@ -1,3 +1,4 @@ +import asyncio import base64 import hashlib import json @@ -25,6 +26,11 @@ from server.auth.constants import ( KEYCLOAK_SERVER_URL, KEYCLOAK_SERVER_URL_EXT, ) +from server.auth.email_validation import ( + extract_base_email, + get_base_email_regex_pattern, + matches_base_email, +) from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid from server.config import get_config from server.logger import logger @@ -509,6 +515,183 @@ class TokenManager: logger.info(f'Got user ID {keycloak_user_id} from email: {email}') return keycloak_user_id + async def _query_users_by_wildcard_pattern( + self, local_part: str, domain: str + ) -> dict[str, dict]: + """Query Keycloak for users matching a wildcard email pattern. + + Tries multiple query methods to find users with emails matching + the pattern {local_part}*@{domain}. This catches the base email + and all + modifier variants. + + Args: + local_part: The local part of the email (before @) + domain: The domain part of the email (after @) + + Returns: + Dictionary mapping user IDs to user objects + """ + keycloak_admin = get_keycloak_admin(self.external) + all_users = {} + + # Query for users with emails matching the base pattern using wildcard + # Pattern: {local_part}*@{domain} - catches base email and all + variants + # This may also catch unintended matches (e.g., joesmith@example.com), but + # they will be filtered out by the regex pattern check later + # Use 'search' parameter for Keycloak 26+ (better wildcard support) + wildcard_queries = [ + {'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first + {'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter + ] + + for query_params in wildcard_queries: + try: + users = await keycloak_admin.a_get_users(query_params) + for user in users: + all_users[user.get('id')] = user + break # Success, no need to try fallback + except Exception as e: + logger.debug( + f'Wildcard query failed with {list(query_params.keys())[0]}: {e}' + ) + continue # Try next query method + + return all_users + + def _find_duplicate_in_users( + self, users: dict[str, dict], base_email: str, current_user_id: str + ) -> bool: + """Check if any user in the provided list matches the base email pattern. + + Filters users to find duplicates that match the base email pattern, + excluding the current user. + + Args: + users: Dictionary mapping user IDs to user objects + base_email: The base email to match against + current_user_id: The user ID to exclude from the check + + Returns: + True if a duplicate is found, False otherwise + """ + regex_pattern = get_base_email_regex_pattern(base_email) + if not regex_pattern: + logger.warning( + f'Could not generate regex pattern for base email: {base_email}' + ) + # Fallback to simple matching + for user in users.values(): + user_email = user.get('email', '').lower() + if ( + user_email + and user.get('id') != current_user_id + and matches_base_email(user_email, base_email) + ): + logger.info( + f'Found duplicate email: {user_email} matches base {base_email}' + ) + return True + else: + for user in users.values(): + user_email = user.get('email', '') + if ( + user_email + and user.get('id') != current_user_id + and regex_pattern.match(user_email) + ): + logger.info( + f'Found duplicate email: {user_email} matches base {base_email}' + ) + return True + + return False + + @retry( + stop=stop_after_attempt(2), + retry=retry_if_exception_type(KeycloakConnectionError), + before_sleep=_before_sleep_callback, + ) + async def check_duplicate_base_email( + self, email: str, current_user_id: str + ) -> bool: + """Check if a user with the same base email already exists. + + This method checks for duplicate signups using email + modifier. + It checks if any user exists with the same base email, regardless of whether + the provided email has a + modifier or not. + + Examples: + - If email is "joe+test@example.com", it checks for existing users with + base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com") + - If email is "joe@example.com", it checks for existing users with + base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com") + + Args: + email: The email address to check (may or may not contain + modifier) + current_user_id: The user ID of the current user (to exclude from check) + + Returns: + True if a duplicate is found (excluding current user), False otherwise + """ + if not email: + return False + + base_email = extract_base_email(email) + if not base_email: + logger.warning(f'Could not extract base email from: {email}') + return False + + try: + local_part, domain = base_email.rsplit('@', 1) + users = await self._query_users_by_wildcard_pattern(local_part, domain) + return self._find_duplicate_in_users(users, base_email, current_user_id) + + except KeycloakConnectionError: + logger.exception('KeycloakConnectionError when checking duplicate email') + raise + except Exception as e: + logger.exception(f'Unexpected error checking duplicate email: {e}') + # On any error, allow signup to proceed (fail open) + return False + + @retry( + stop=stop_after_attempt(2), + retry=retry_if_exception_type(KeycloakConnectionError), + before_sleep=_before_sleep_callback, + ) + async def delete_keycloak_user(self, user_id: str) -> bool: + """Delete a user from Keycloak. + + This method is used to clean up user accounts that were created + but should not exist (e.g., duplicate email signups). + + Args: + user_id: The Keycloak user ID to delete + + Returns: + True if deletion was successful, False otherwise + """ + try: + keycloak_admin = get_keycloak_admin(self.external) + # Use the sync method (python-keycloak doesn't have async delete_user) + # Run it in a thread executor to avoid blocking the event loop + await asyncio.to_thread(keycloak_admin.delete_user, user_id) + logger.info(f'Successfully deleted Keycloak user {user_id}') + return True + except KeycloakConnectionError: + logger.exception(f'KeycloakConnectionError when deleting user {user_id}') + raise + except KeycloakError as e: + # User might not exist or already deleted + logger.warning( + f'KeycloakError when deleting user {user_id}: {e}', + extra={'user_id': user_id, 'error': str(e)}, + ) + return False + except Exception as e: + logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}') + return False + async def get_user_info_from_user_id(self, user_id: str) -> dict | None: keycloak_admin = get_keycloak_admin(self.external) user = await keycloak_admin.a_get_user(user_id) diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index 2ee50bbd2d..3ea384b403 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -146,9 +146,11 @@ async def keycloak_callback( content={'error': 'Missing user ID or username in response'}, ) - # Check if email domain is blocked email = user_info.get('email') user_id = user_info['sub'] + + # Check if email domain is blocked + email = user_info.get('email') if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): logger.warning( f'Blocked authentication attempt for email: {email}, user_id: {user_id}' @@ -164,6 +166,42 @@ async def keycloak_callback( }, ) + # Check for duplicate email with + modifier + if email: + try: + has_duplicate = await token_manager.check_duplicate_base_email( + email, user_id + ) + if has_duplicate: + logger.warning( + f'Blocked signup attempt for email {email} - duplicate base email found', + extra={'user_id': user_id, 'email': email}, + ) + + # Delete the Keycloak user that was automatically created during OAuth + # This prevents orphaned accounts in Keycloak + # The delete_keycloak_user method already handles all errors internally + deletion_success = await token_manager.delete_keycloak_user(user_id) + if deletion_success: + logger.info( + f'Deleted Keycloak user {user_id} after detecting duplicate email {email}' + ) + else: + logger.warning( + f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. ' + f'User may need to be manually cleaned up.' + ) + + # Redirect to home page with query parameter indicating the issue + home_url = f'{request.base_url}?duplicated_email=true' + return RedirectResponse(home_url, status_code=302) + except Exception as e: + # Log error but allow signup to proceed (fail open) + logger.error( + f'Error checking duplicate email for {email}: {e}', + extra={'user_id': user_id, 'email': email}, + ) + # default to github IDP for now. # TODO: remove default once Keycloak is updated universally with the new attribute. idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value) diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index d3e8f47fbe..0eeca12dcf 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -635,3 +635,219 @@ async def test_keycloak_callback_missing_email(mock_request): assert isinstance(result, RedirectResponse) mock_domain_blocker.is_domain_blocked.assert_not_called() mock_token_manager.disable_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_email_detected(mock_request): + """Test keycloak_callback when duplicate email is detected.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + ): + # Arrange + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True) + mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True) + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'duplicated_email=true' in result.headers['location'] + mock_token_manager.check_duplicate_base_email.assert_called_once_with( + 'joe+test@example.com', 'test_user_id' + ) + mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id') + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request): + """Test keycloak_callback when duplicate is detected but deletion fails.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + ): + # Arrange + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True) + mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False) + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + assert 'duplicated_email=true' in result.headers['location'] + mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id') + + +@pytest.mark.asyncio +async def test_keycloak_callback_duplicate_check_exception(mock_request): + """Test keycloak_callback when duplicate check raises exception.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock( + side_effect=Exception('Check failed') + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + # Should proceed with normal flow despite exception (fail open) + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + +@pytest.mark.asyncio +async def test_keycloak_callback_no_duplicate_email(mock_request): + """Test keycloak_callback when no duplicate email is found.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'joe+test@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + mock_token_manager.check_duplicate_base_email.assert_called_once_with( + 'joe+test@example.com', 'test_user_id' + ) + # Should not delete user when no duplicate found + mock_token_manager.delete_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_no_email_in_user_info(mock_request): + """Test keycloak_callback when email is not in user_info.""" + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + # Arrange + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + # No email field + 'identity_provider': 'github', + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + # Should not check for duplicate when email is missing + mock_token_manager.check_duplicate_base_email.assert_not_called() diff --git a/enterprise/tests/unit/test_email_validation.py b/enterprise/tests/unit/test_email_validation.py new file mode 100644 index 0000000000..320c5d4699 --- /dev/null +++ b/enterprise/tests/unit/test_email_validation.py @@ -0,0 +1,294 @@ +"""Tests for email validation utilities.""" + +import re + +from server.auth.email_validation import ( + extract_base_email, + get_base_email_regex_pattern, + has_plus_modifier, + matches_base_email, +) + + +class TestExtractBaseEmail: + """Test cases for extract_base_email function.""" + + def test_extract_base_email_with_plus_modifier(self): + """Test extracting base email from email with + modifier.""" + # Arrange + email = 'joe+test@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_without_plus_modifier(self): + """Test that email without + modifier is returned as-is.""" + # Arrange + email = 'joe@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_multiple_plus_signs(self): + """Test extracting base email when multiple + signs exist.""" + # Arrange + email = 'joe+openhands+test@example.com' + + # Act + result = extract_base_email(email) + + # Assert + assert result == 'joe@example.com' + + def test_extract_base_email_invalid_no_at_symbol(self): + """Test that invalid email without @ returns None.""" + # Arrange + email = 'invalid-email' + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + def test_extract_base_email_empty_string(self): + """Test that empty string returns None.""" + # Arrange + email = '' + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + def test_extract_base_email_none(self): + """Test that None input returns None.""" + # Arrange + email = None + + # Act + result = extract_base_email(email) + + # Assert + assert result is None + + +class TestHasPlusModifier: + """Test cases for has_plus_modifier function.""" + + def test_has_plus_modifier_true(self): + """Test detecting + modifier in email.""" + # Arrange + email = 'joe+test@example.com' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is True + + def test_has_plus_modifier_false(self): + """Test that email without + modifier returns False.""" + # Arrange + email = 'joe@example.com' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + def test_has_plus_modifier_invalid_no_at_symbol(self): + """Test that invalid email without @ returns False.""" + # Arrange + email = 'invalid-email' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + def test_has_plus_modifier_empty_string(self): + """Test that empty string returns False.""" + # Arrange + email = '' + + # Act + result = has_plus_modifier(email) + + # Assert + assert result is False + + +class TestMatchesBaseEmail: + """Test cases for matches_base_email function.""" + + def test_matches_base_email_exact_match(self): + """Test that exact base email matches.""" + # Arrange + email = 'joe@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_with_plus_variant(self): + """Test that email with + variant matches base email.""" + # Arrange + email = 'joe+test@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_different_base(self): + """Test that different base emails do not match.""" + # Arrange + email = 'jane@example.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + def test_matches_base_email_different_domain(self): + """Test that same local part but different domain does not match.""" + # Arrange + email = 'joe@other.com' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + def test_matches_base_email_case_insensitive(self): + """Test that matching is case-insensitive.""" + # Arrange + email = 'JOE+TEST@EXAMPLE.COM' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is True + + def test_matches_base_email_empty_strings(self): + """Test that empty strings return False.""" + # Arrange + email = '' + base_email = 'joe@example.com' + + # Act + result = matches_base_email(email, base_email) + + # Assert + assert result is False + + +class TestGetBaseEmailRegexPattern: + """Test cases for get_base_email_regex_pattern function.""" + + def test_get_base_email_regex_pattern_valid(self): + """Test generating valid regex pattern for base email.""" + # Arrange + base_email = 'joe@example.com' + + # Act + pattern = get_base_email_regex_pattern(base_email) + + # Assert + assert pattern is not None + assert isinstance(pattern, re.Pattern) + assert pattern.match('joe@example.com') is not None + assert pattern.match('joe+test@example.com') is not None + assert pattern.match('joe+openhands@example.com') is not None + + def test_get_base_email_regex_pattern_matches_plus_variant(self): + """Test that regex pattern matches + variant.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('joe+test@example.com') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_rejects_different_base(self): + """Test that regex pattern rejects different base email.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('jane@example.com') + + # Assert + assert match is None + + def test_get_base_email_regex_pattern_rejects_different_domain(self): + """Test that regex pattern rejects different domain.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('joe@other.com') + + # Assert + assert match is None + + def test_get_base_email_regex_pattern_case_insensitive(self): + """Test that regex pattern is case-insensitive.""" + # Arrange + base_email = 'joe@example.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('JOE+TEST@EXAMPLE.COM') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_special_characters(self): + """Test that regex pattern handles special characters in email.""" + # Arrange + base_email = 'user.name+tag@example-site.com' + pattern = get_base_email_regex_pattern(base_email) + + # Act + match = pattern.match('user.name+test@example-site.com') + + # Assert + assert match is not None + + def test_get_base_email_regex_pattern_invalid_base_email(self): + """Test that invalid base email returns None.""" + # Arrange + base_email = 'invalid-email' + + # Act + pattern = get_base_email_regex_pattern(base_email) + + # Assert + assert pattern is None diff --git a/enterprise/tests/unit/test_token_manager.py b/enterprise/tests/unit/test_token_manager.py index 413962d60c..0498ff1cb5 100644 --- a/enterprise/tests/unit/test_token_manager.py +++ b/enterprise/tests/unit/test_token_manager.py @@ -1,6 +1,8 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from keycloak.exceptions import KeycloakConnectionError, KeycloakError +from server.auth.token_manager import TokenManager from sqlalchemy.orm import Session from storage.offline_token_store import OfflineTokenStore from storage.stored_offline_token import StoredOfflineToken @@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config): return OfflineTokenStore('test_user_id', mock_session_maker, mock_config) +@pytest.fixture +def token_manager(): + with patch('server.config.get_config') as mock_get_config: + mock_config = mock_get_config.return_value + mock_config.jwt_secret.get_secret_value.return_value = 'test_secret' + return TokenManager(external=False) + + @pytest.mark.asyncio async def test_store_token_new_record(token_store, mock_session): # Setup @@ -109,3 +119,419 @@ async def test_get_instance(mock_config): assert isinstance(result, OfflineTokenStore) assert result.user_id == test_user_id assert result.config == mock_config + + +class TestCheckDuplicateBaseEmail: + """Test cases for check_duplicate_base_email method.""" + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager): + """Test that emails without + modifier are still checked for duplicates.""" + # Arrange + email = 'joe@example.com' + current_user_id = 'user123' + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = False + mock_query.return_value = {} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + mock_query.assert_called_once() + mock_find.assert_called_once() + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_empty_email(self, token_manager): + """Test that empty email returns False.""" + # Arrange + email = '' + current_user_id = 'user123' + + # Act + result = await token_manager.check_duplicate_base_email(email, current_user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_invalid_email(self, token_manager): + """Test that invalid email returns False.""" + # Arrange + email = 'invalid-email' + current_user_id = 'user123' + + # Act + result = await token_manager.check_duplicate_base_email(email, current_user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_duplicate_found(self, token_manager): + """Test that duplicate email is detected when found.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + existing_user = { + 'id': 'existing_user_id', + 'email': 'joe@example.com', + } + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = True + mock_query.return_value = {'existing_user_id': existing_user} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is True + mock_query.assert_called_once() + mock_find.assert_called_once() + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_no_duplicate(self, token_manager): + """Test that no duplicate is found when none exists.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with ( + patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query, + patch.object(token_manager, '_find_duplicate_in_users') as mock_find, + ): + mock_find.return_value = False + mock_query.return_value = {} + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_keycloak_connection_error( + self, token_manager + ): + """Test that KeycloakConnectionError triggers retry and raises RetryError.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query: + mock_query.side_effect = KeycloakConnectionError('Connection failed') + + # Act & Assert + # KeycloakConnectionError is re-raised, which triggers retry decorator + # After retries exhaust (2 attempts), it raises RetryError + from tenacity import RetryError + + with pytest.raises(RetryError): + await token_manager.check_duplicate_base_email(email, current_user_id) + + @pytest.mark.asyncio + async def test_check_duplicate_base_email_general_exception(self, token_manager): + """Test that general exceptions are handled gracefully.""" + # Arrange + email = 'joe+test@example.com' + current_user_id = 'user123' + + with patch.object( + token_manager, '_query_users_by_wildcard_pattern' + ) as mock_query: + mock_query.side_effect = Exception('Unexpected error') + + # Act + result = await token_manager.check_duplicate_base_email( + email, current_user_id + ) + + # Assert + assert result is False + + +class TestQueryUsersByWildcardPattern: + """Test cases for _query_users_by_wildcard_pattern method.""" + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_success_with_search( + self, token_manager + ): + """Test successful query using search parameter.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + mock_users = [ + {'id': 'user1', 'email': 'joe@example.com'}, + {'id': 'user2', 'email': 'joe+test@example.com'}, + ] + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_users = AsyncMock(return_value=mock_users) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert len(result) == 2 + assert 'user1' in result + assert 'user2' in result + mock_admin.a_get_users.assert_called_once_with( + {'search': 'joe*@example.com'} + ) + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager): + """Test fallback to q parameter when search fails.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + mock_users = [{'id': 'user1', 'email': 'joe@example.com'}] + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + # First call fails, second succeeds + mock_admin.a_get_users = AsyncMock( + side_effect=[Exception('Search failed'), mock_users] + ) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert len(result) == 1 + assert 'user1' in result + assert mock_admin.a_get_users.call_count == 2 + + @pytest.mark.asyncio + async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager): + """Test query returns empty dict when no users found.""" + # Arrange + local_part = 'joe' + domain = 'example.com' + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_users = AsyncMock(return_value=[]) + mock_get_admin.return_value = mock_admin + + # Act + result = await token_manager._query_users_by_wildcard_pattern( + local_part, domain + ) + + # Assert + assert result == {} + + +class TestFindDuplicateInUsers: + """Test cases for _find_duplicate_in_users method.""" + + def test_find_duplicate_in_users_with_regex_match(self, token_manager): + """Test finding duplicate using regex pattern.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + 'user2': {'id': 'user2', 'email': 'joe+test@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user3' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is True + + def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager): + """Test fallback to simple matching when regex pattern is None.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + } + base_email = 'invalid-email' # Will cause regex pattern to be None + current_user_id = 'user2' + + with patch( + 'server.auth.token_manager.get_base_email_regex_pattern', return_value=None + ): + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + # Should use fallback matching, but invalid base_email won't match + assert result is False + + def test_find_duplicate_in_users_excludes_current_user(self, token_manager): + """Test that current user is excluded from duplicate check.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'joe@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user1' # Same as user in users dict + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + def test_find_duplicate_in_users_no_match(self, token_manager): + """Test that no duplicate is found when emails don't match.""" + # Arrange + users = { + 'user1': {'id': 'user1', 'email': 'jane@example.com'}, + } + base_email = 'joe@example.com' + current_user_id = 'user2' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + def test_find_duplicate_in_users_empty_dict(self, token_manager): + """Test that empty users dict returns False.""" + # Arrange + users: dict[str, dict] = {} + base_email = 'joe@example.com' + current_user_id = 'user1' + + # Act + result = token_manager._find_duplicate_in_users( + users, base_email, current_user_id + ) + + # Assert + assert result is False + + +class TestDeleteKeycloakUser: + """Test cases for delete_keycloak_user method.""" + + @pytest.mark.asyncio + async def test_delete_keycloak_user_success(self, token_manager): + """Test successful deletion of Keycloak user.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.return_value = None + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is True + mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id) + + @pytest.mark.asyncio + async def test_delete_keycloak_user_connection_error(self, token_manager): + """Test handling of KeycloakConnectionError triggers retry and raises RetryError.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = KeycloakConnectionError('Connection failed') + + # Act & Assert + # KeycloakConnectionError triggers retry decorator + # After retries exhaust (2 attempts), it raises RetryError + from tenacity import RetryError + + with pytest.raises(RetryError): + await token_manager.delete_keycloak_user(user_id) + + @pytest.mark.asyncio + async def test_delete_keycloak_user_keycloak_error(self, token_manager): + """Test handling of KeycloakError (e.g., user not found).""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = KeycloakError('User not found') + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is False + + @pytest.mark.asyncio + async def test_delete_keycloak_user_general_exception(self, token_manager): + """Test handling of general exceptions.""" + # Arrange + user_id = 'test_user_id' + + with ( + patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin, + patch('asyncio.to_thread') as mock_to_thread, + ): + mock_admin = MagicMock() + mock_admin.delete_user = MagicMock() + mock_get_admin.return_value = mock_admin + mock_to_thread.side_effect = Exception('Unexpected error') + + # Act + result = await token_manager.delete_keycloak_user(user_id) + + # Assert + assert result is False diff --git a/frontend/__tests__/components/features/auth-modal.test.tsx b/frontend/__tests__/components/features/auth-modal.test.tsx index 32b682d506..4f32841b12 100644 --- a/frontend/__tests__/components/features/auth-modal.test.tsx +++ b/frontend/__tests__/components/features/auth-modal.test.tsx @@ -1,6 +1,7 @@ import { render, screen } from "@testing-library/react"; import { it, describe, expect, vi, beforeEach, afterEach } from "vitest"; import userEvent from "@testing-library/user-event"; +import { MemoryRouter } from "react-router"; import { AuthModal } from "#/components/features/waitlist/auth-modal"; // Mock the useAuthUrl hook @@ -27,11 +28,13 @@ describe("AuthModal", () => { it("should render the GitHub and GitLab buttons", () => { render( - , + + + , ); const githubButton = screen.getByRole("button", { @@ -49,11 +52,13 @@ describe("AuthModal", () => { const user = userEvent.setup(); const mockUrl = "https://github.com/login/oauth/authorize"; render( - , + + + , ); const githubButton = screen.getByRole("button", { @@ -65,7 +70,11 @@ describe("AuthModal", () => { }); it("should render Terms of Service and Privacy Policy text with correct links", () => { - render(); + render( + + + , + ); // Find the terms of service section using data-testid const termsSection = screen.getByTestId("auth-modal-terms-of-service"); @@ -106,7 +115,11 @@ describe("AuthModal", () => { }); it("should open Terms of Service link in new tab", () => { - render(); + render( + + + , + ); const tosLink = screen.getByRole("link", { name: "COMMON$TERMS_OF_SERVICE", @@ -115,11 +128,53 @@ describe("AuthModal", () => { }); it("should open Privacy Policy link in new tab", () => { - render(); + render( + + + , + ); const privacyLink = screen.getByRole("link", { name: "COMMON$PRIVACY_POLICY", }); expect(privacyLink).toHaveAttribute("target", "_blank"); }); + + describe("Duplicate email error message", () => { + const renderAuthModalWithRouter = (initialEntries: string[]) => { + return render( + + + , + ); + }; + + it("should display error message when duplicated_email query parameter is true", () => { + // Arrange + const initialEntries = ["/?duplicated_email=true"]; + + // Act + renderAuthModalWithRouter(initialEntries); + + // Assert + const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR"); + expect(errorMessage).toBeInTheDocument(); + }); + + it("should not display error message when duplicated_email query parameter is missing", () => { + // Arrange + const initialEntries = ["/"]; + + // Act + renderAuthModalWithRouter(initialEntries); + + // Assert + const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR"); + expect(errorMessage).not.toBeInTheDocument(); + }); + }); }); diff --git a/frontend/src/components/features/waitlist/auth-modal.tsx b/frontend/src/components/features/waitlist/auth-modal.tsx index 2c431fbd95..e1d52a7965 100644 --- a/frontend/src/components/features/waitlist/auth-modal.tsx +++ b/frontend/src/components/features/waitlist/auth-modal.tsx @@ -1,5 +1,6 @@ import React from "react"; import { useTranslation } from "react-i18next"; +import { useSearchParams } from "react-router"; import { I18nKey } from "#/i18n/declaration"; import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react"; import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; @@ -29,6 +30,8 @@ export function AuthModal({ }: AuthModalProps) { const { t } = useTranslation(); const { trackLoginButtonClick } = useTracking(); + const [searchParams] = useSearchParams(); + const hasDuplicatedEmail = searchParams.get("duplicated_email") === "true"; const gitlabAuthUrl = useAuthUrl({ appMode: appMode || null, @@ -123,6 +126,11 @@ export function AuthModal({ + {hasDuplicatedEmail && ( +
+ {t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)} +
+ )}

{t(I18nKey.AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER)} diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index 1b330730d9..0dd668cacc 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -730,6 +730,7 @@ export enum I18nKey { MICROAGENT_MANAGEMENT$USE_MICROAGENTS = "MICROAGENT_MANAGEMENT$USE_MICROAGENTS", AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR = "AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR", AUTH$NO_PROVIDERS_CONFIGURED = "AUTH$NO_PROVIDERS_CONFIGURED", + AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR", COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE", COMMON$AND = "COMMON$AND", COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index a421de5ddf..2950b3ab72 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -11679,6 +11679,22 @@ "de": "Mindestens ein Identitätsanbieter muss konfiguriert werden (z.B. GitHub)", "uk": "Принаймні один постачальник ідентифікації має бути налаштований (наприклад, GitHub)" }, + "AUTH$DUPLICATE_EMAIL_ERROR": { + "en": "Your account is unable to be created. Please use a different login or try again.", + "ja": "アカウントを作成できません。別のログインを使用するか、もう一度お試しください。", + "zh-CN": "无法创建您的账户。请使用其他登录方式或重试。", + "zh-TW": "無法建立您的帳戶。請使用其他登入方式或重試。", + "ko-KR": "계정을 생성할 수 없습니다. 다른 로그인을 사용하거나 다시 시도해 주세요.", + "no": "Kontoen din kan ikke opprettes. Vennligst bruk en annen innlogging eller prøv igjen.", + "it": "Impossibile creare il tuo account. Utilizza un altro accesso o riprova.", + "pt": "Não foi possível criar sua conta. Use um login diferente ou tente novamente.", + "es": "No se puede crear su cuenta. Utilice un inicio de sesión diferente o inténtelo de nuevo.", + "ar": "لا يمكن إنشاء حسابك. يرجى استخدام تسجيل دخول مختلف أو المحاولة مرة أخرى.", + "fr": "Votre compte ne peut pas être créé. Veuillez utiliser une autre connexion ou réessayer.", + "tr": "Hesabınız oluşturulamadı. Lütfen farklı bir giriş kullanın veya tekrar deneyin.", + "de": "Ihr Konto kann nicht erstellt werden. Bitte verwenden Sie eine andere Anmeldung oder versuchen Sie es erneut.", + "uk": "Ваш обліковий запис не може бути створений. Будь ласка, використовуйте інший спосіб входу або спробуйте ще раз." + }, "COMMON$TERMS_OF_SERVICE": { "en": "Terms of Service", "ja": "利用規約",