mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
feat: support blocking specific email domains (#12115)
This commit is contained in:
parent
6605070d05
commit
5553d3ca2e
@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'y',
|
||||
'on',
|
||||
)
|
||||
BLOCKED_EMAIL_DOMAINS = [
|
||||
domain.strip().lower()
|
||||
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||
if domain.strip()
|
||||
]
|
||||
|
||||
56
enterprise/server/auth/domain_blocker.py
Normal file
56
enterprise/server/auth/domain_blocker.py
Normal file
@ -0,0 +1,56 @@
|
||||
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||
if self.blocked_domains:
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||
)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if domain blocking is enabled"""
|
||||
return bool(self.blocked_domains)
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked"""
|
||||
if not self.is_active():
|
||||
return False
|
||||
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
is_blocked = domain in self.blocked_domains
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
|
||||
|
||||
domain_blocker = DomainBlocker()
|
||||
@ -13,6 +13,7 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@ -312,6 +313,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
user_id = access_token_payload['sub']
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
return SaasUserAuth(
|
||||
|
||||
@ -527,6 +527,49 @@ class TokenManager:
|
||||
github_id = github_ids[0]
|
||||
return github_id
|
||||
|
||||
async def disable_keycloak_user(
|
||||
self, user_id: str, email: str | None = None
|
||||
) -> None:
|
||||
"""Disable a Keycloak user account.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to disable
|
||||
email: Optional email address for logging purposes
|
||||
|
||||
This method attempts to disable the user account but will not raise exceptions.
|
||||
Errors are logged but do not prevent the operation from completing.
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Get current user to preserve other fields
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
if user:
|
||||
# Update user with enabled=False to disable the account
|
||||
await keycloak_admin.a_update_user(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': user.get('username', ''),
|
||||
'email': user.get('email', ''),
|
||||
'emailVerified': user.get('emailVerified', False),
|
||||
},
|
||||
)
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.info(
|
||||
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'User not found in Keycloak when attempting to disable: {user_id}'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't raise - the caller should handle the blocking regardless
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.error(
|
||||
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
ROLE_CHECK_ENABLED,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
@ -145,7 +146,24 @@ async def keycloak_callback(
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
# Disable the Keycloak account
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||
},
|
||||
)
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@ -442,3 +442,196 @@ async def test_logout_without_refresh_token():
|
||||
|
||||
mock_token_manager.logout.assert_not_called()
|
||||
assert 'set-cookie' in result.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'error' in result.body.decode()
|
||||
assert 'email domain is not allowed' in result.body.decode()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_token_manager.disable_keycloak_user.assert_called_once_with(
|
||||
'test_user_id', 'user@colsch.us'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is not blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
"""Test keycloak_callback when domain blocking is not active."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_missing_email(mock_request):
|
||||
"""Test keycloak_callback when user info does not contain email."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
# No email field
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker():
|
||||
"""Create a DomainBlocker instance for testing."""
|
||||
return DomainBlocker()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'blocked_domains,expected',
|
||||
[
|
||||
(['colsch.us', 'other-domain.com'], True),
|
||||
(['example.com'], True),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = blocked_domains
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_active()
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected_domain',
|
||||
[
|
||||
('user@example.com', 'example.com'),
|
||||
('test@colsch.us', 'colsch.us'),
|
||||
('user.name@other-domain.com', 'other-domain.com'),
|
||||
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||
('user@EXAMPLE.COM', 'example.com'),
|
||||
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||
],
|
||||
)
|
||||
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected_domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected',
|
||||
[
|
||||
(None, None),
|
||||
('', None),
|
||||
('invalid-email', None),
|
||||
('user@', None), # Empty domain after @
|
||||
('no-at-sign', None),
|
||||
],
|
||||
)
|
||||
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
"""Test that _extract_domain returns None for invalid email formats."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = []
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||
|
||||
# Act
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
@ -5,7 +5,12 @@ import jwt
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
|
||||
from server.auth.auth_error import (
|
||||
AuthError,
|
||||
BearerTokenError,
|
||||
CookieError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.saas_user_auth import (
|
||||
SaasUserAuth,
|
||||
get_api_key_from_header,
|
||||
@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AuthError) as exc_info:
|
||||
await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert 'email domain is not allowed' in str(exc_info.value)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.email == 'user@example.com'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.auth.token_manager import TokenManager, create_encryption_utility
|
||||
@ -246,3 +246,103 @@ async def test_refresh(token_manager):
|
||||
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
|
||||
'test_refresh_token'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_success(token_manager):
|
||||
"""Test successful disabling of a Keycloak user account."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_without_email(token_manager):
|
||||
"""Test disabling Keycloak user without providing email."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': 'user@example.com',
|
||||
'emailVerified': False,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_not_found(token_manager):
|
||||
"""Test disabling Keycloak user when user is not found."""
|
||||
# Arrange
|
||||
user_id = 'nonexistent_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=None)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_exception_handling(token_manager):
|
||||
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act & Assert - should not raise exception
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Verify the method was called
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user