mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
feat: support blocking specific email domains (#12115)
This commit is contained in:
parent
6605070d05
commit
5553d3ca2e
@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
|||||||
'y',
|
'y',
|
||||||
'on',
|
'on',
|
||||||
)
|
)
|
||||||
|
BLOCKED_EMAIL_DOMAINS = [
|
||||||
|
domain.strip().lower()
|
||||||
|
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||||
|
if domain.strip()
|
||||||
|
]
|
||||||
|
|||||||
56
enterprise/server/auth/domain_blocker.py
Normal file
56
enterprise/server/auth/domain_blocker.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||||
|
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
|
||||||
|
|
||||||
|
class DomainBlocker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
logger.debug('Initializing DomainBlocker')
|
||||||
|
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||||
|
if self.blocked_domains:
|
||||||
|
logger.info(
|
||||||
|
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if domain blocking is enabled"""
|
||||||
|
return bool(self.blocked_domains)
|
||||||
|
|
||||||
|
def _extract_domain(self, email: str) -> str | None:
|
||||||
|
"""Extract and normalize email domain from email address"""
|
||||||
|
if not email:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
# Extract domain part after @
|
||||||
|
if '@' not in email:
|
||||||
|
return None
|
||||||
|
domain = email.split('@')[1].strip().lower()
|
||||||
|
return domain if domain else None
|
||||||
|
except Exception:
|
||||||
|
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_domain_blocked(self, email: str) -> bool:
|
||||||
|
"""Check if email domain is blocked"""
|
||||||
|
if not self.is_active():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not email:
|
||||||
|
logger.debug('No email provided for domain check')
|
||||||
|
return False
|
||||||
|
|
||||||
|
domain = self._extract_domain(email)
|
||||||
|
if not domain:
|
||||||
|
logger.debug(f'Could not extract domain from email: {email}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
is_blocked = domain in self.blocked_domains
|
||||||
|
if is_blocked:
|
||||||
|
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||||
|
else:
|
||||||
|
logger.debug(f'Email domain {domain} is not blocked')
|
||||||
|
|
||||||
|
return is_blocked
|
||||||
|
|
||||||
|
|
||||||
|
domain_blocker = DomainBlocker()
|
||||||
@ -13,6 +13,7 @@ from server.auth.auth_error import (
|
|||||||
ExpiredError,
|
ExpiredError,
|
||||||
NoCredentialsError,
|
NoCredentialsError,
|
||||||
)
|
)
|
||||||
|
from server.auth.domain_blocker import domain_blocker
|
||||||
from server.auth.token_manager import TokenManager
|
from server.auth.token_manager import TokenManager
|
||||||
from server.config import get_config
|
from server.config import get_config
|
||||||
from server.logger import logger
|
from server.logger import logger
|
||||||
@ -312,6 +313,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
|||||||
user_id = access_token_payload['sub']
|
user_id = access_token_payload['sub']
|
||||||
email = access_token_payload['email']
|
email = access_token_payload['email']
|
||||||
email_verified = access_token_payload['email_verified']
|
email_verified = access_token_payload['email_verified']
|
||||||
|
|
||||||
|
# Check if email domain is blocked
|
||||||
|
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||||
|
logger.warning(
|
||||||
|
f'Blocked authentication attempt for existing user with email: {email}'
|
||||||
|
)
|
||||||
|
raise AuthError(
|
||||||
|
'Access denied: Your email domain is not allowed to access this service'
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug('saas_user_auth_from_signed_token:return')
|
logger.debug('saas_user_auth_from_signed_token:return')
|
||||||
|
|
||||||
return SaasUserAuth(
|
return SaasUserAuth(
|
||||||
|
|||||||
@ -527,6 +527,49 @@ class TokenManager:
|
|||||||
github_id = github_ids[0]
|
github_id = github_ids[0]
|
||||||
return github_id
|
return github_id
|
||||||
|
|
||||||
|
async def disable_keycloak_user(
|
||||||
|
self, user_id: str, email: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Disable a Keycloak user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The Keycloak user ID to disable
|
||||||
|
email: Optional email address for logging purposes
|
||||||
|
|
||||||
|
This method attempts to disable the user account but will not raise exceptions.
|
||||||
|
Errors are logged but do not prevent the operation from completing.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keycloak_admin = get_keycloak_admin(self.external)
|
||||||
|
# Get current user to preserve other fields
|
||||||
|
user = await keycloak_admin.a_get_user(user_id)
|
||||||
|
if user:
|
||||||
|
# Update user with enabled=False to disable the account
|
||||||
|
await keycloak_admin.a_update_user(
|
||||||
|
user_id=user_id,
|
||||||
|
payload={
|
||||||
|
'enabled': False,
|
||||||
|
'username': user.get('username', ''),
|
||||||
|
'email': user.get('email', ''),
|
||||||
|
'emailVerified': user.get('emailVerified', False),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
email_str = f', email: {email}' if email else ''
|
||||||
|
logger.info(
|
||||||
|
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f'User not found in Keycloak when attempting to disable: {user_id}'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log error but don't raise - the caller should handle the blocking regardless
|
||||||
|
email_str = f', email: {email}' if email else ''
|
||||||
|
logger.error(
|
||||||
|
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
def store_org_token(self, installation_id: int, installation_token: str):
|
def store_org_token(self, installation_id: int, installation_token: str):
|
||||||
"""Store a GitHub App installation token.
|
"""Store a GitHub App installation token.
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from server.auth.constants import (
|
|||||||
KEYCLOAK_SERVER_URL_EXT,
|
KEYCLOAK_SERVER_URL_EXT,
|
||||||
ROLE_CHECK_ENABLED,
|
ROLE_CHECK_ENABLED,
|
||||||
)
|
)
|
||||||
|
from server.auth.domain_blocker import domain_blocker
|
||||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||||
from server.auth.saas_user_auth import SaasUserAuth
|
from server.auth.saas_user_auth import SaasUserAuth
|
||||||
from server.auth.token_manager import TokenManager
|
from server.auth.token_manager import TokenManager
|
||||||
@ -145,7 +146,24 @@ async def keycloak_callback(
|
|||||||
content={'error': 'Missing user ID or username in response'},
|
content={'error': 'Missing user ID or username in response'},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if email domain is blocked
|
||||||
|
email = user_info.get('email')
|
||||||
user_id = user_info['sub']
|
user_id = user_info['sub']
|
||||||
|
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||||
|
logger.warning(
|
||||||
|
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable the Keycloak account
|
||||||
|
await token_manager.disable_keycloak_user(user_id, email)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content={
|
||||||
|
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# default to github IDP for now.
|
# default to github IDP for now.
|
||||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||||
|
|||||||
@ -442,3 +442,196 @@ async def test_logout_without_refresh_token():
|
|||||||
|
|
||||||
mock_token_manager.logout.assert_not_called()
|
mock_token_manager.logout.assert_not_called()
|
||||||
assert 'set-cookie' in result.headers
|
assert 'set-cookie' in result.headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||||
|
"""Test keycloak_callback when email domain is blocked."""
|
||||||
|
# Arrange
|
||||||
|
with (
|
||||||
|
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||||
|
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||||
|
):
|
||||||
|
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||||
|
return_value=('test_access_token', 'test_refresh_token')
|
||||||
|
)
|
||||||
|
mock_token_manager.get_user_info = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'preferred_username': 'test_user',
|
||||||
|
'email': 'user@colsch.us',
|
||||||
|
'identity_provider': 'github',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||||
|
|
||||||
|
mock_domain_blocker.is_active.return_value = True
|
||||||
|
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await keycloak_callback(
|
||||||
|
code='test_code', state='test_state', request=mock_request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, JSONResponse)
|
||||||
|
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert 'error' in result.body.decode()
|
||||||
|
assert 'email domain is not allowed' in result.body.decode()
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||||
|
mock_token_manager.disable_keycloak_user.assert_called_once_with(
|
||||||
|
'test_user_id', 'user@colsch.us'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||||
|
"""Test keycloak_callback when email domain is not blocked."""
|
||||||
|
# Arrange
|
||||||
|
with (
|
||||||
|
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||||
|
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||||
|
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||||
|
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||||
|
):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
|
||||||
|
mock_user_settings = MagicMock()
|
||||||
|
mock_user_settings.accepted_tos = '2025-01-01'
|
||||||
|
mock_query.first.return_value = mock_user_settings
|
||||||
|
|
||||||
|
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||||
|
return_value=('test_access_token', 'test_refresh_token')
|
||||||
|
)
|
||||||
|
mock_token_manager.get_user_info = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'preferred_username': 'test_user',
|
||||||
|
'email': 'user@example.com',
|
||||||
|
'identity_provider': 'github',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||||
|
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_domain_blocker.is_active.return_value = True
|
||||||
|
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||||
|
|
||||||
|
mock_verifier.is_active.return_value = True
|
||||||
|
mock_verifier.is_user_allowed.return_value = True
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await keycloak_callback(
|
||||||
|
code='test_code', state='test_state', request=mock_request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, RedirectResponse)
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||||
|
'user@example.com'
|
||||||
|
)
|
||||||
|
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||||
|
"""Test keycloak_callback when domain blocking is not active."""
|
||||||
|
# Arrange
|
||||||
|
with (
|
||||||
|
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||||
|
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||||
|
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||||
|
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||||
|
):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
|
||||||
|
mock_user_settings = MagicMock()
|
||||||
|
mock_user_settings.accepted_tos = '2025-01-01'
|
||||||
|
mock_query.first.return_value = mock_user_settings
|
||||||
|
|
||||||
|
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||||
|
return_value=('test_access_token', 'test_refresh_token')
|
||||||
|
)
|
||||||
|
mock_token_manager.get_user_info = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'preferred_username': 'test_user',
|
||||||
|
'email': 'user@colsch.us',
|
||||||
|
'identity_provider': 'github',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||||
|
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_domain_blocker.is_active.return_value = False
|
||||||
|
|
||||||
|
mock_verifier.is_active.return_value = True
|
||||||
|
mock_verifier.is_user_allowed.return_value = True
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await keycloak_callback(
|
||||||
|
code='test_code', state='test_state', request=mock_request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, RedirectResponse)
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||||
|
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keycloak_callback_missing_email(mock_request):
|
||||||
|
"""Test keycloak_callback when user info does not contain email."""
|
||||||
|
# Arrange
|
||||||
|
with (
|
||||||
|
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||||
|
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||||
|
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||||
|
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||||
|
):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
|
||||||
|
mock_user_settings = MagicMock()
|
||||||
|
mock_user_settings.accepted_tos = '2025-01-01'
|
||||||
|
mock_query.first.return_value = mock_user_settings
|
||||||
|
|
||||||
|
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||||
|
return_value=('test_access_token', 'test_refresh_token')
|
||||||
|
)
|
||||||
|
mock_token_manager.get_user_info = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'preferred_username': 'test_user',
|
||||||
|
'identity_provider': 'github',
|
||||||
|
# No email field
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||||
|
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_domain_blocker.is_active.return_value = True
|
||||||
|
|
||||||
|
mock_verifier.is_active.return_value = True
|
||||||
|
mock_verifier.is_user_allowed.return_value = True
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await keycloak_callback(
|
||||||
|
code='test_code', state='test_state', request=mock_request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, RedirectResponse)
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||||
|
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||||
|
|||||||
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
"""Unit tests for DomainBlocker class."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from server.auth.domain_blocker import DomainBlocker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def domain_blocker():
|
||||||
|
"""Create a DomainBlocker instance for testing."""
|
||||||
|
return DomainBlocker()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'blocked_domains,expected',
|
||||||
|
[
|
||||||
|
(['colsch.us', 'other-domain.com'], True),
|
||||||
|
(['example.com'], True),
|
||||||
|
([], False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||||
|
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = blocked_domains
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_active()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'email,expected_domain',
|
||||||
|
[
|
||||||
|
('user@example.com', 'example.com'),
|
||||||
|
('test@colsch.us', 'colsch.us'),
|
||||||
|
('user.name@other-domain.com', 'other-domain.com'),
|
||||||
|
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||||
|
('user@EXAMPLE.COM', 'example.com'),
|
||||||
|
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||||
|
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||||
|
# Act
|
||||||
|
result = domain_blocker._extract_domain(email)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected_domain
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'email,expected',
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
('', None),
|
||||||
|
('invalid-email', None),
|
||||||
|
('user@', None), # Empty domain after @
|
||||||
|
('no-at-sign', None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||||
|
"""Test that _extract_domain returns None for invalid email formats."""
|
||||||
|
# Act
|
||||||
|
result = domain_blocker._extract_domain(email)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = []
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns False when email is None."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked(None)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns False when email is empty."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||||
|
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||||
|
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result1 is True
|
||||||
|
assert result2 is True
|
||||||
|
assert result3 is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||||
|
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||||
|
# Arrange
|
||||||
|
domain_blocker.blocked_domains = ['colsch.us']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
@ -5,7 +5,12 @@ import jwt
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
|
from server.auth.auth_error import (
|
||||||
|
AuthError,
|
||||||
|
BearerTokenError,
|
||||||
|
CookieError,
|
||||||
|
NoCredentialsError,
|
||||||
|
)
|
||||||
from server.auth.saas_user_auth import (
|
from server.auth.saas_user_auth import (
|
||||||
SaasUserAuth,
|
SaasUserAuth,
|
||||||
get_api_key_from_header,
|
get_api_key_from_header,
|
||||||
@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
|
|||||||
# Assert that empty string from Bearer is returned (current behavior)
|
# Assert that empty string from Bearer is returned (current behavior)
|
||||||
# This tests the current implementation behavior
|
# This tests the current implementation behavior
|
||||||
assert api_key == ''
|
assert api_key == ''
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||||
|
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
|
||||||
|
# Arrange
|
||||||
|
access_payload = {
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'exp': int(time.time()) + 3600,
|
||||||
|
'email': 'user@colsch.us',
|
||||||
|
'email_verified': True,
|
||||||
|
}
|
||||||
|
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
token_payload = {
|
||||||
|
'access_token': access_token,
|
||||||
|
'refresh_token': 'test_refresh_token',
|
||||||
|
}
|
||||||
|
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||||
|
mock_domain_blocker.is_active.return_value = True
|
||||||
|
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(AuthError) as exc_info:
|
||||||
|
await saas_user_auth_from_signed_token(signed_token)
|
||||||
|
|
||||||
|
assert 'email domain is not allowed' in str(exc_info.value)
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||||
|
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||||
|
# Arrange
|
||||||
|
access_payload = {
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'exp': int(time.time()) + 3600,
|
||||||
|
'email': 'user@example.com',
|
||||||
|
'email_verified': True,
|
||||||
|
}
|
||||||
|
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
token_payload = {
|
||||||
|
'access_token': access_token,
|
||||||
|
'refresh_token': 'test_refresh_token',
|
||||||
|
}
|
||||||
|
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||||
|
mock_domain_blocker.is_active.return_value = True
|
||||||
|
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await saas_user_auth_from_signed_token(signed_token)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, SaasUserAuth)
|
||||||
|
assert result.user_id == 'test_user_id'
|
||||||
|
assert result.email == 'user@example.com'
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||||
|
'user@example.com'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||||
|
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||||
|
# Arrange
|
||||||
|
access_payload = {
|
||||||
|
'sub': 'test_user_id',
|
||||||
|
'exp': int(time.time()) + 3600,
|
||||||
|
'email': 'user@colsch.us',
|
||||||
|
'email_verified': True,
|
||||||
|
}
|
||||||
|
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
token_payload = {
|
||||||
|
'access_token': access_token,
|
||||||
|
'refresh_token': 'test_refresh_token',
|
||||||
|
}
|
||||||
|
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||||
|
|
||||||
|
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||||
|
mock_domain_blocker.is_active.return_value = False
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await saas_user_auth_from_signed_token(signed_token)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, SaasUserAuth)
|
||||||
|
assert result.user_id == 'test_user_id'
|
||||||
|
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from server.auth.token_manager import TokenManager, create_encryption_utility
|
from server.auth.token_manager import TokenManager, create_encryption_utility
|
||||||
@ -246,3 +246,103 @@ async def test_refresh(token_manager):
|
|||||||
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
|
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
|
||||||
'test_refresh_token'
|
'test_refresh_token'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disable_keycloak_user_success(token_manager):
|
||||||
|
"""Test successful disabling of a Keycloak user account."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'test_user_id'
|
||||||
|
email = 'user@colsch.us'
|
||||||
|
mock_user = {
|
||||||
|
'id': user_id,
|
||||||
|
'username': 'testuser',
|
||||||
|
'email': email,
|
||||||
|
'emailVerified': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||||
|
mock_admin = MagicMock()
|
||||||
|
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||||
|
mock_admin.a_update_user = AsyncMock()
|
||||||
|
mock_get_admin.return_value = mock_admin
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await token_manager.disable_keycloak_user(user_id, email)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||||
|
mock_admin.a_update_user.assert_called_once_with(
|
||||||
|
user_id=user_id,
|
||||||
|
payload={
|
||||||
|
'enabled': False,
|
||||||
|
'username': 'testuser',
|
||||||
|
'email': email,
|
||||||
|
'emailVerified': True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disable_keycloak_user_without_email(token_manager):
|
||||||
|
"""Test disabling Keycloak user without providing email."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'test_user_id'
|
||||||
|
mock_user = {
|
||||||
|
'id': user_id,
|
||||||
|
'username': 'testuser',
|
||||||
|
'email': 'user@example.com',
|
||||||
|
'emailVerified': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||||
|
mock_admin = MagicMock()
|
||||||
|
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||||
|
mock_admin.a_update_user = AsyncMock()
|
||||||
|
mock_get_admin.return_value = mock_admin
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await token_manager.disable_keycloak_user(user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||||
|
mock_admin.a_update_user.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disable_keycloak_user_not_found(token_manager):
|
||||||
|
"""Test disabling Keycloak user when user is not found."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'nonexistent_user_id'
|
||||||
|
email = 'user@colsch.us'
|
||||||
|
|
||||||
|
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||||
|
mock_admin = MagicMock()
|
||||||
|
mock_admin.a_get_user = AsyncMock(return_value=None)
|
||||||
|
mock_get_admin.return_value = mock_admin
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await token_manager.disable_keycloak_user(user_id, email)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||||
|
mock_admin.a_update_user.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disable_keycloak_user_exception_handling(token_manager):
|
||||||
|
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'test_user_id'
|
||||||
|
email = 'user@colsch.us'
|
||||||
|
|
||||||
|
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||||
|
mock_admin = MagicMock()
|
||||||
|
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
|
||||||
|
mock_get_admin.return_value = mock_admin
|
||||||
|
|
||||||
|
# Act & Assert - should not raise exception
|
||||||
|
await token_manager.disable_keycloak_user(user_id, email)
|
||||||
|
|
||||||
|
# Verify the method was called
|
||||||
|
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user