mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
feat: prevent signups using email addresses with a plus sign and enforce the existing email pattern (#12124)
This commit is contained in:
parent
fae83230ee
commit
f6e7628bff
109
enterprise/server/auth/email_validation.py
Normal file
109
enterprise/server/auth/email_validation.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Email validation utilities for preventing duplicate signups with + modifier."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_base_email(email: str) -> str | None:
|
||||
"""Extract base email from an email address.
|
||||
|
||||
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
|
||||
For emails without + modifier, returns the email as-is.
|
||||
|
||||
Examples:
|
||||
extract_base_email("joe+test@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
|
||||
|
||||
Args:
|
||||
email: The email address to process
|
||||
|
||||
Returns:
|
||||
The base email address, or None if email format is invalid
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = email.rsplit('@', 1)
|
||||
# Extract the part before + if it exists
|
||||
base_local = local_part.split('+', 1)[0]
|
||||
return f'{base_local}@{domain}'
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def has_plus_modifier(email: str) -> bool:
|
||||
"""Check if an email address contains a + modifier.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if email contains + before @, False otherwise
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, _ = email.rsplit('@', 1)
|
||||
return '+' in local_part
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def matches_base_email(email: str, base_email: str) -> bool:
|
||||
"""Check if an email matches a base email pattern.
|
||||
|
||||
An email matches if:
|
||||
- It is exactly the base email (e.g., joe@example.com)
|
||||
- It has the same base local part and domain, with or without + modifier
|
||||
(e.g., joe+test@example.com matches base joe@example.com)
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
base_email: The base email to match against
|
||||
|
||||
Returns:
|
||||
True if email matches the base pattern, False otherwise
|
||||
"""
|
||||
if not email or not base_email:
|
||||
return False
|
||||
|
||||
# Extract base from both emails for comparison
|
||||
email_base = extract_base_email(email)
|
||||
base_email_normalized = extract_base_email(base_email)
|
||||
|
||||
if not email_base or not base_email_normalized:
|
||||
return False
|
||||
|
||||
# Emails match if they have the same base
|
||||
return email_base.lower() == base_email_normalized.lower()
|
||||
|
||||
|
||||
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
|
||||
"""Generate a regex pattern to match emails with the same base.
|
||||
|
||||
For base_email "joe@example.com", the pattern will match:
|
||||
- joe@example.com
|
||||
- joe+anything@example.com
|
||||
|
||||
Args:
|
||||
base_email: The base email address
|
||||
|
||||
Returns:
|
||||
A compiled regex pattern, or None if base_email is invalid
|
||||
"""
|
||||
base = extract_base_email(base_email)
|
||||
if not base:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = base.rsplit('@', 1)
|
||||
# Escape special regex characters in local part and domain
|
||||
escaped_local = re.escape(local_part)
|
||||
escaped_domain = re.escape(domain)
|
||||
# Pattern: joe@example.com OR joe+anything@example.com
|
||||
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@ -25,6 +26,11 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
)
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
matches_base_email,
|
||||
)
|
||||
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@ -509,6 +515,183 @@ class TokenManager:
|
||||
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
|
||||
return keycloak_user_id
|
||||
|
||||
async def _query_users_by_wildcard_pattern(
|
||||
self, local_part: str, domain: str
|
||||
) -> dict[str, dict]:
|
||||
"""Query Keycloak for users matching a wildcard email pattern.
|
||||
|
||||
Tries multiple query methods to find users with emails matching
|
||||
the pattern {local_part}*@{domain}. This catches the base email
|
||||
and all + modifier variants.
|
||||
|
||||
Args:
|
||||
local_part: The local part of the email (before @)
|
||||
domain: The domain part of the email (after @)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user IDs to user objects
|
||||
"""
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
all_users = {}
|
||||
|
||||
# Query for users with emails matching the base pattern using wildcard
|
||||
# Pattern: {local_part}*@{domain} - catches base email and all + variants
|
||||
# This may also catch unintended matches (e.g., joesmith@example.com), but
|
||||
# they will be filtered out by the regex pattern check later
|
||||
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
|
||||
wildcard_queries = [
|
||||
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
|
||||
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
|
||||
]
|
||||
|
||||
for query_params in wildcard_queries:
|
||||
try:
|
||||
users = await keycloak_admin.a_get_users(query_params)
|
||||
for user in users:
|
||||
all_users[user.get('id')] = user
|
||||
break # Success, no need to try fallback
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
|
||||
)
|
||||
continue # Try next query method
|
||||
|
||||
return all_users
|
||||
|
||||
def _find_duplicate_in_users(
|
||||
self, users: dict[str, dict], base_email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if any user in the provided list matches the base email pattern.
|
||||
|
||||
Filters users to find duplicates that match the base email pattern,
|
||||
excluding the current user.
|
||||
|
||||
Args:
|
||||
users: Dictionary mapping user IDs to user objects
|
||||
base_email: The base email to match against
|
||||
current_user_id: The user ID to exclude from the check
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found, False otherwise
|
||||
"""
|
||||
regex_pattern = get_base_email_regex_pattern(base_email)
|
||||
if not regex_pattern:
|
||||
logger.warning(
|
||||
f'Could not generate regex pattern for base email: {base_email}'
|
||||
)
|
||||
# Fallback to simple matching
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '').lower()
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and matches_base_email(user_email, base_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
else:
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '')
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and regex_pattern.match(user_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def check_duplicate_base_email(
|
||||
self, email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if a user with the same base email already exists.
|
||||
|
||||
This method checks for duplicate signups using email + modifier.
|
||||
It checks if any user exists with the same base email, regardless of whether
|
||||
the provided email has a + modifier or not.
|
||||
|
||||
Examples:
|
||||
- If email is "joe+test@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
|
||||
- If email is "joe@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
|
||||
|
||||
Args:
|
||||
email: The email address to check (may or may not contain + modifier)
|
||||
current_user_id: The user ID of the current user (to exclude from check)
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found (excluding current user), False otherwise
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, domain = base_email.rsplit('@', 1)
|
||||
users = await self._query_users_by_wildcard_pattern(local_part, domain)
|
||||
return self._find_duplicate_in_users(users, base_email, current_user_id)
|
||||
|
||||
except KeycloakConnectionError:
|
||||
logger.exception('KeycloakConnectionError when checking duplicate email')
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error checking duplicate email: {e}')
|
||||
# On any error, allow signup to proceed (fail open)
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def delete_keycloak_user(self, user_id: str) -> bool:
|
||||
"""Delete a user from Keycloak.
|
||||
|
||||
This method is used to clean up user accounts that were created
|
||||
but should not exist (e.g., duplicate email signups).
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to delete
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Use the sync method (python-keycloak doesn't have async delete_user)
|
||||
# Run it in a thread executor to avoid blocking the event loop
|
||||
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
|
||||
logger.info(f'Successfully deleted Keycloak user {user_id}')
|
||||
return True
|
||||
except KeycloakConnectionError:
|
||||
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
|
||||
raise
|
||||
except KeycloakError as e:
|
||||
# User might not exist or already deleted
|
||||
logger.warning(
|
||||
f'KeycloakError when deleting user {user_id}: {e}',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
|
||||
return False
|
||||
|
||||
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
|
||||
@ -146,9 +146,11 @@ async def keycloak_callback(
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
@ -164,6 +166,42 @@ async def keycloak_callback(
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@ -635,3 +635,219 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
"""Test keycloak_callback when duplicate check raises exception."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(
|
||||
side_effect=Exception('Check failed')
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should proceed with normal flow despite exception (fail open)
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
"""Test keycloak_callback when no duplicate email is found."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
# Should not delete user when no duplicate found
|
||||
mock_token_manager.delete_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
"""Test keycloak_callback when email is not in user_info."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
# No email field
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
# Should not check for duplicate when email is missing
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
294
enterprise/tests/unit/test_email_validation.py
Normal file
294
enterprise/tests/unit/test_email_validation.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Tests for email validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
has_plus_modifier,
|
||||
matches_base_email,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseEmail:
|
||||
"""Test cases for extract_base_email function."""
|
||||
|
||||
def test_extract_base_email_with_plus_modifier(self):
|
||||
"""Test extracting base email from email with + modifier."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_without_plus_modifier(self):
|
||||
"""Test that email without + modifier is returned as-is."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_multiple_plus_signs(self):
|
||||
"""Test extracting base email when multiple + signs exist."""
|
||||
# Arrange
|
||||
email = 'joe+openhands+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns None."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_empty_string(self):
|
||||
"""Test that empty string returns None."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_none(self):
|
||||
"""Test that None input returns None."""
|
||||
# Arrange
|
||||
email = None
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHasPlusModifier:
|
||||
"""Test cases for has_plus_modifier function."""
|
||||
|
||||
def test_has_plus_modifier_true(self):
|
||||
"""Test detecting + modifier in email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_has_plus_modifier_false(self):
|
||||
"""Test that email without + modifier returns False."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_empty_string(self):
|
||||
"""Test that empty string returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMatchesBaseEmail:
|
||||
"""Test cases for matches_base_email function."""
|
||||
|
||||
def test_matches_base_email_exact_match(self):
|
||||
"""Test that exact base email matches."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_with_plus_variant(self):
|
||||
"""Test that email with + variant matches base email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_different_base(self):
|
||||
"""Test that different base emails do not match."""
|
||||
# Arrange
|
||||
email = 'jane@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_different_domain(self):
|
||||
"""Test that same local part but different domain does not match."""
|
||||
# Arrange
|
||||
email = 'joe@other.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_case_insensitive(self):
|
||||
"""Test that matching is case-insensitive."""
|
||||
# Arrange
|
||||
email = 'JOE+TEST@EXAMPLE.COM'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_empty_strings(self):
|
||||
"""Test that empty strings return False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBaseEmailRegexPattern:
|
||||
"""Test cases for get_base_email_regex_pattern function."""
|
||||
|
||||
def test_get_base_email_regex_pattern_valid(self):
|
||||
"""Test generating valid regex pattern for base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is not None
|
||||
assert isinstance(pattern, re.Pattern)
|
||||
assert pattern.match('joe@example.com') is not None
|
||||
assert pattern.match('joe+test@example.com') is not None
|
||||
assert pattern.match('joe+openhands@example.com') is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_matches_plus_variant(self):
|
||||
"""Test that regex pattern matches + variant."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe+test@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_base(self):
|
||||
"""Test that regex pattern rejects different base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('jane@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_domain(self):
|
||||
"""Test that regex pattern rejects different domain."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe@other.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_case_insensitive(self):
|
||||
"""Test that regex pattern is case-insensitive."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('JOE+TEST@EXAMPLE.COM')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_special_characters(self):
|
||||
"""Test that regex pattern handles special characters in email."""
|
||||
# Arrange
|
||||
base_email = 'user.name+tag@example-site.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('user.name+test@example-site.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_invalid_base_email(self):
|
||||
"""Test that invalid base email returns None."""
|
||||
# Arrange
|
||||
base_email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is None
|
||||
@ -1,6 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
|
||||
from server.auth.token_manager import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
|
||||
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager():
|
||||
with patch('server.config.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
return TokenManager(external=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_new_record(token_store, mock_session):
|
||||
# Setup
|
||||
@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
|
||||
assert isinstance(result, OfflineTokenStore)
|
||||
assert result.user_id == test_user_id
|
||||
assert result.config == mock_config
|
||||
|
||||
|
||||
class TestCheckDuplicateBaseEmail:
|
||||
"""Test cases for check_duplicate_base_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
|
||||
"""Test that emails without + modifier are still checked for duplicates."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_empty_email(self, token_manager):
|
||||
"""Test that empty email returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
|
||||
"""Test that invalid email returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
|
||||
"""Test that duplicate email is detected when found."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
existing_user = {
|
||||
'id': 'existing_user_id',
|
||||
'email': 'joe@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = True
|
||||
mock_query.return_value = {'existing_user_id': existing_user}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
|
||||
"""Test that no duplicate is found when none exists."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_keycloak_connection_error(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError is re-raised, which triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_general_exception(self, token_manager):
|
||||
"""Test that general exceptions are handled gracefully."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestQueryUsersByWildcardPattern:
|
||||
"""Test cases for _query_users_by_wildcard_pattern method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_success_with_search(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test successful query using search parameter."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [
|
||||
{'id': 'user1', 'email': 'joe@example.com'},
|
||||
{'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert 'user1' in result
|
||||
assert 'user2' in result
|
||||
mock_admin.a_get_users.assert_called_once_with(
|
||||
{'search': 'joe*@example.com'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
|
||||
"""Test fallback to q parameter when search fails."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
# First call fails, second succeeds
|
||||
mock_admin.a_get_users = AsyncMock(
|
||||
side_effect=[Exception('Search failed'), mock_users]
|
||||
)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert 'user1' in result
|
||||
assert mock_admin.a_get_users.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
|
||||
"""Test query returns empty dict when no users found."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=[])
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestFindDuplicateInUsers:
|
||||
"""Test cases for _find_duplicate_in_users method."""
|
||||
|
||||
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
|
||||
"""Test finding duplicate using regex pattern."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user3'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
|
||||
"""Test fallback to simple matching when regex pattern is None."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'invalid-email' # Will cause regex pattern to be None
|
||||
current_user_id = 'user2'
|
||||
|
||||
with patch(
|
||||
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
|
||||
):
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should use fallback matching, but invalid base_email won't match
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
|
||||
"""Test that current user is excluded from duplicate check."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1' # Same as user in users dict
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_no_match(self, token_manager):
|
||||
"""Test that no duplicate is found when emails don't match."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'jane@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user2'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_empty_dict(self, token_manager):
|
||||
"""Test that empty users dict returns False."""
|
||||
# Arrange
|
||||
users: dict[str, dict] = {}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestDeleteKeycloakUser:
|
||||
"""Test cases for delete_keycloak_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_success(self, token_manager):
|
||||
"""Test successful deletion of Keycloak user."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.return_value = None
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_connection_error(self, token_manager):
|
||||
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
|
||||
"""Test handling of KeycloakError (e.g., user not found)."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakError('User not found')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_general_exception(self, token_manager):
|
||||
"""Test handling of general exceptions."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AuthModal } from "#/components/features/waitlist/auth-modal";
|
||||
|
||||
// Mock the useAuthUrl hook
|
||||
@ -27,11 +28,13 @@ describe("AuthModal", () => {
|
||||
|
||||
it("should render the GitHub and GitLab buttons", () => {
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@ -49,11 +52,13 @@ describe("AuthModal", () => {
|
||||
const user = userEvent.setup();
|
||||
const mockUrl = "https://github.com/login/oauth/authorize";
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@ -65,7 +70,11 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should render Terms of Service and Privacy Policy text with correct links", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Find the terms of service section using data-testid
|
||||
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
|
||||
@ -106,7 +115,11 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should open Terms of Service link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
@ -115,11 +128,53 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should open Privacy Policy link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
|
||||
describe("Duplicate email error message", () => {
|
||||
const renderAuthModalWithRouter = (initialEntries: string[]) => {
|
||||
return render(
|
||||
<MemoryRouter initialEntries={initialEntries}>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
};
|
||||
|
||||
it("should display error message when duplicated_email query parameter is true", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/?duplicated_email=true"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display error message when duplicated_email query parameter is missing", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useSearchParams } from "react-router";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
@ -29,6 +30,8 @@ export function AuthModal({
|
||||
}: AuthModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { trackLoginButtonClick } = useTracking();
|
||||
const [searchParams] = useSearchParams();
|
||||
const hasDuplicatedEmail = searchParams.get("duplicated_email") === "true";
|
||||
|
||||
const gitlabAuthUrl = useAuthUrl({
|
||||
appMode: appMode || null,
|
||||
@ -123,6 +126,11 @@ export function AuthModal({
|
||||
<ModalBackdrop>
|
||||
<ModalBody className="border border-tertiary">
|
||||
<OpenHandsLogo width={68} height={46} />
|
||||
{hasDuplicatedEmail && (
|
||||
<div className="text-center text-danger text-sm mt-2 mb-2">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-col gap-2 w-full items-center text-center">
|
||||
<h1 className="text-2xl font-bold">
|
||||
{t(I18nKey.AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER)}
|
||||
|
||||
@ -730,6 +730,7 @@ export enum I18nKey {
|
||||
MICROAGENT_MANAGEMENT$USE_MICROAGENTS = "MICROAGENT_MANAGEMENT$USE_MICROAGENTS",
|
||||
AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR = "AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
|
||||
AUTH$NO_PROVIDERS_CONFIGURED = "AUTH$NO_PROVIDERS_CONFIGURED",
|
||||
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
|
||||
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
|
||||
COMMON$AND = "COMMON$AND",
|
||||
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",
|
||||
|
||||
@ -11679,6 +11679,22 @@
|
||||
"de": "Mindestens ein Identitätsanbieter muss konfiguriert werden (z.B. GitHub)",
|
||||
"uk": "Принаймні один постачальник ідентифікації має бути налаштований (наприклад, GitHub)"
|
||||
},
|
||||
"AUTH$DUPLICATE_EMAIL_ERROR": {
|
||||
"en": "Your account is unable to be created. Please use a different login or try again.",
|
||||
"ja": "アカウントを作成できません。別のログインを使用するか、もう一度お試しください。",
|
||||
"zh-CN": "无法创建您的账户。请使用其他登录方式或重试。",
|
||||
"zh-TW": "無法建立您的帳戶。請使用其他登入方式或重試。",
|
||||
"ko-KR": "계정을 생성할 수 없습니다. 다른 로그인을 사용하거나 다시 시도해 주세요.",
|
||||
"no": "Kontoen din kan ikke opprettes. Vennligst bruk en annen innlogging eller prøv igjen.",
|
||||
"it": "Impossibile creare il tuo account. Utilizza un altro accesso o riprova.",
|
||||
"pt": "Não foi possível criar sua conta. Use um login diferente ou tente novamente.",
|
||||
"es": "No se puede crear su cuenta. Utilice un inicio de sesión diferente o inténtelo de nuevo.",
|
||||
"ar": "لا يمكن إنشاء حسابك. يرجى استخدام تسجيل دخول مختلف أو المحاولة مرة أخرى.",
|
||||
"fr": "Votre compte ne peut pas être créé. Veuillez utiliser une autre connexion ou réessayer.",
|
||||
"tr": "Hesabınız oluşturulamadı. Lütfen farklı bir giriş kullanın veya tekrar deneyin.",
|
||||
"de": "Ihr Konto kann nicht erstellt werden. Bitte verwenden Sie eine andere Anmeldung oder versuchen Sie es erneut.",
|
||||
"uk": "Ваш обліковий запис не може бути створений. Будь ласка, використовуйте інший спосіб входу або спробуйте ще раз."
|
||||
},
|
||||
"COMMON$TERMS_OF_SERVICE": {
|
||||
"en": "Terms of Service",
|
||||
"ja": "利用規約",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user