From 5553d3ca2e17f5e59fc40088ab41bfa7f6ec8a88 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Sun, 21 Dec 2025 19:49:11 +0700 Subject: [PATCH] feat: support blocking specific email domains (#12115) --- enterprise/server/auth/constants.py | 5 + enterprise/server/auth/domain_blocker.py | 56 +++++ enterprise/server/auth/saas_user_auth.py | 11 + enterprise/server/auth/token_manager.py | 43 ++++ enterprise/server/routes/auth.py | 18 ++ enterprise/tests/unit/test_auth_routes.py | 193 ++++++++++++++++++ enterprise/tests/unit/test_domain_blocker.py | 181 ++++++++++++++++ enterprise/tests/unit/test_saas_user_auth.py | 101 ++++++++- .../tests/unit/test_token_manager_extended.py | 102 ++++++++- 9 files changed, 708 insertions(+), 2 deletions(-) create mode 100644 enterprise/server/auth/domain_blocker.py create mode 100644 enterprise/tests/unit/test_domain_blocker.py diff --git a/enterprise/server/auth/constants.py b/enterprise/server/auth/constants.py index 15d3b0f704..242237e93d 100644 --- a/enterprise/server/auth/constants.py +++ b/enterprise/server/auth/constants.py @@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in ( 'y', 'on', ) +BLOCKED_EMAIL_DOMAINS = [ + domain.strip().lower() + for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',') + if domain.strip() +] diff --git a/enterprise/server/auth/domain_blocker.py b/enterprise/server/auth/domain_blocker.py new file mode 100644 index 0000000000..169545ae2d --- /dev/null +++ b/enterprise/server/auth/domain_blocker.py @@ -0,0 +1,56 @@ +from server.auth.constants import BLOCKED_EMAIL_DOMAINS + +from openhands.core.logger import openhands_logger as logger + + +class DomainBlocker: + def __init__(self) -> None: + logger.debug('Initializing DomainBlocker') + self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS + if self.blocked_domains: + logger.info( + f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}' + ) + + def is_active(self) -> bool: + """Check if domain blocking is enabled""" + return bool(self.blocked_domains) + + def _extract_domain(self, email: str) -> str | None: + """Extract and normalize email domain from email address""" + if not email: + return None + try: + # Extract domain part after @ + if '@' not in email: + return None + domain = email.split('@')[1].strip().lower() + return domain if domain else None + except Exception: + logger.debug(f'Error extracting domain from email: {email}', exc_info=True) + return None + + def is_domain_blocked(self, email: str) -> bool: + """Check if email domain is blocked""" + if not self.is_active(): + return False + + if not email: + logger.debug('No email provided for domain check') + return False + + domain = self._extract_domain(email) + if not domain: + logger.debug(f'Could not extract domain from email: {email}') + return False + + is_blocked = domain in self.blocked_domains + if is_blocked: + logger.warning(f'Email domain {domain} is blocked for email: {email}') + else: + logger.debug(f'Email domain {domain} is not blocked') + + return is_blocked + + +domain_blocker = DomainBlocker() diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 2f399a74cf..b51d336997 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -13,6 +13,7 @@ from server.auth.auth_error import ( ExpiredError, NoCredentialsError, ) +from server.auth.domain_blocker import domain_blocker from server.auth.token_manager import TokenManager from server.config import get_config from server.logger import logger @@ -312,6 +313,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth: user_id = access_token_payload['sub'] email = access_token_payload['email'] email_verified = access_token_payload['email_verified'] + + # Check if email domain is blocked + if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): + logger.warning( + f'Blocked authentication attempt for existing user with email: {email}' + ) + raise AuthError( + 'Access denied: Your email domain is not allowed to access this service' + ) + logger.debug('saas_user_auth_from_signed_token:return') return SaasUserAuth( diff --git a/enterprise/server/auth/token_manager.py b/enterprise/server/auth/token_manager.py index 0b873bc7fc..04bfae0767 100644 --- a/enterprise/server/auth/token_manager.py +++ b/enterprise/server/auth/token_manager.py @@ -527,6 +527,49 @@ class TokenManager: github_id = github_ids[0] return github_id + async def disable_keycloak_user( + self, user_id: str, email: str | None = None + ) -> None: + """Disable a Keycloak user account. + + Args: + user_id: The Keycloak user ID to disable + email: Optional email address for logging purposes + + This method attempts to disable the user account but will not raise exceptions. + Errors are logged but do not prevent the operation from completing. + """ + try: + keycloak_admin = get_keycloak_admin(self.external) + # Get current user to preserve other fields + user = await keycloak_admin.a_get_user(user_id) + if user: + # Update user with enabled=False to disable the account + await keycloak_admin.a_update_user( + user_id=user_id, + payload={ + 'enabled': False, + 'username': user.get('username', ''), + 'email': user.get('email', ''), + 'emailVerified': user.get('emailVerified', False), + }, + ) + email_str = f', email: {email}' if email else '' + logger.info( + f'Disabled Keycloak account for user_id: {user_id}{email_str}' + ) + else: + logger.warning( + f'User not found in Keycloak when attempting to disable: {user_id}' + ) + except Exception as e: + # Log error but don't raise - the caller should handle the blocking regardless + email_str = f', email: {email}' if email else '' + logger.error( + f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}', + exc_info=True, + ) + def store_org_token(self, installation_id: int, installation_token: str): """Store a GitHub App installation token. diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index ba7aadb883..2ee50bbd2d 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -14,6 +14,7 @@ from server.auth.constants import ( KEYCLOAK_SERVER_URL_EXT, ROLE_CHECK_ENABLED, ) +from server.auth.domain_blocker import domain_blocker from server.auth.gitlab_sync import schedule_gitlab_repo_sync from server.auth.saas_user_auth import SaasUserAuth from server.auth.token_manager import TokenManager @@ -145,7 +146,24 @@ async def keycloak_callback( content={'error': 'Missing user ID or username in response'}, ) + # Check if email domain is blocked + email = user_info.get('email') user_id = user_info['sub'] + if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): + logger.warning( + f'Blocked authentication attempt for email: {email}, user_id: {user_id}' + ) + + # Disable the Keycloak account + await token_manager.disable_keycloak_user(user_id, email) + + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + 'error': 'Access denied: Your email domain is not allowed to access this service' + }, + ) + # default to github IDP for now. # TODO: remove default once Keycloak is updated universally with the new attribute. idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value) diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index bf74f0055c..d3e8f47fbe 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -442,3 +442,196 @@ async def test_logout_without_refresh_token(): mock_token_manager.logout.assert_not_called() assert 'set-cookie' in result.headers + + +@pytest.mark.asyncio +async def test_keycloak_callback_blocked_email_domain(mock_request): + """Test keycloak_callback when email domain is blocked.""" + # Arrange + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.domain_blocker') as mock_domain_blocker, + ): + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'user@colsch.us', + 'identity_provider': 'github', + } + ) + mock_token_manager.disable_keycloak_user = AsyncMock() + + mock_domain_blocker.is_active.return_value = True + mock_domain_blocker.is_domain_blocked.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, JSONResponse) + assert result.status_code == status.HTTP_401_UNAUTHORIZED + assert 'error' in result.body.decode() + assert 'email domain is not allowed' in result.body.decode() + mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us') + mock_token_manager.disable_keycloak_user.assert_called_once_with( + 'test_user_id', 'user@colsch.us' + ) + + +@pytest.mark.asyncio +async def test_keycloak_callback_allowed_email_domain(mock_request): + """Test keycloak_callback when email domain is not blocked.""" + # Arrange + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.domain_blocker') as mock_domain_blocker, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'user@example.com', + 'identity_provider': 'github', + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_domain_blocker.is_active.return_value = True + mock_domain_blocker.is_domain_blocked.return_value = False + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + mock_domain_blocker.is_domain_blocked.assert_called_once_with( + 'user@example.com' + ) + mock_token_manager.disable_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_domain_blocking_inactive(mock_request): + """Test keycloak_callback when domain blocking is not active.""" + # Arrange + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.domain_blocker') as mock_domain_blocker, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'email': 'user@colsch.us', + 'identity_provider': 'github', + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_domain_blocker.is_active.return_value = False + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + mock_domain_blocker.is_domain_blocked.assert_not_called() + mock_token_manager.disable_keycloak_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_keycloak_callback_missing_email(mock_request): + """Test keycloak_callback when user info does not contain email.""" + # Arrange + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.domain_blocker') as mock_domain_blocker, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.session_maker') as mock_session_maker, + ): + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + + mock_user_settings = MagicMock() + mock_user_settings.accepted_tos = '2025-01-01' + mock_query.first.return_value = mock_user_settings + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'identity_provider': 'github', + # No email field + } + ) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + + mock_domain_blocker.is_active.return_value = True + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + # Act + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + # Assert + assert isinstance(result, RedirectResponse) + mock_domain_blocker.is_domain_blocked.assert_not_called() + mock_token_manager.disable_keycloak_user.assert_not_called() diff --git a/enterprise/tests/unit/test_domain_blocker.py b/enterprise/tests/unit/test_domain_blocker.py new file mode 100644 index 0000000000..e199a80b9b --- /dev/null +++ b/enterprise/tests/unit/test_domain_blocker.py @@ -0,0 +1,181 @@ +"""Unit tests for DomainBlocker class.""" + +import pytest +from server.auth.domain_blocker import DomainBlocker + + +@pytest.fixture +def domain_blocker(): + """Create a DomainBlocker instance for testing.""" + return DomainBlocker() + + +@pytest.mark.parametrize( + 'blocked_domains,expected', + [ + (['colsch.us', 'other-domain.com'], True), + (['example.com'], True), + ([], False), + ], +) +def test_is_active(domain_blocker, blocked_domains, expected): + """Test that is_active returns correct value based on blocked domains configuration.""" + # Arrange + domain_blocker.blocked_domains = blocked_domains + + # Act + result = domain_blocker.is_active() + + # Assert + assert result == expected + + +@pytest.mark.parametrize( + 'email,expected_domain', + [ + ('user@example.com', 'example.com'), + ('test@colsch.us', 'colsch.us'), + ('user.name@other-domain.com', 'other-domain.com'), + ('USER@EXAMPLE.COM', 'example.com'), # Case insensitive + ('user@EXAMPLE.COM', 'example.com'), + (' user@example.com ', 'example.com'), # Whitespace handling + ], +) +def test_extract_domain_valid_emails(domain_blocker, email, expected_domain): + """Test that _extract_domain correctly extracts and normalizes domains from valid emails.""" + # Act + result = domain_blocker._extract_domain(email) + + # Assert + assert result == expected_domain + + +@pytest.mark.parametrize( + 'email,expected', + [ + (None, None), + ('', None), + ('invalid-email', None), + ('user@', None), # Empty domain after @ + ('no-at-sign', None), + ], +) +def test_extract_domain_invalid_emails(domain_blocker, email, expected): + """Test that _extract_domain returns None for invalid email formats.""" + # Act + result = domain_blocker._extract_domain(email) + + # Assert + assert result == expected + + +def test_is_domain_blocked_when_inactive(domain_blocker): + """Test that is_domain_blocked returns False when blocking is not active.""" + # Arrange + domain_blocker.blocked_domains = [] + + # Act + result = domain_blocker.is_domain_blocked('user@colsch.us') + + # Assert + assert result is False + + +def test_is_domain_blocked_with_none_email(domain_blocker): + """Test that is_domain_blocked returns False when email is None.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us'] + + # Act + result = domain_blocker.is_domain_blocked(None) + + # Assert + assert result is False + + +def test_is_domain_blocked_with_empty_email(domain_blocker): + """Test that is_domain_blocked returns False when email is empty.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us'] + + # Act + result = domain_blocker.is_domain_blocked('') + + # Assert + assert result is False + + +def test_is_domain_blocked_with_invalid_email(domain_blocker): + """Test that is_domain_blocked returns False when email format is invalid.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us'] + + # Act + result = domain_blocker.is_domain_blocked('invalid-email') + + # Assert + assert result is False + + +def test_is_domain_blocked_domain_not_blocked(domain_blocker): + """Test that is_domain_blocked returns False when domain is not in blocked list.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com'] + + # Act + result = domain_blocker.is_domain_blocked('user@example.com') + + # Assert + assert result is False + + +def test_is_domain_blocked_domain_blocked(domain_blocker): + """Test that is_domain_blocked returns True when domain is in blocked list.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com'] + + # Act + result = domain_blocker.is_domain_blocked('user@colsch.us') + + # Assert + assert result is True + + +def test_is_domain_blocked_case_insensitive(domain_blocker): + """Test that is_domain_blocked performs case-insensitive domain matching.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us'] + + # Act + result = domain_blocker.is_domain_blocked('user@COLSCH.US') + + # Assert + assert result is True + + +def test_is_domain_blocked_multiple_blocked_domains(domain_blocker): + """Test that is_domain_blocked correctly checks against multiple blocked domains.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org'] + + # Act + result1 = domain_blocker.is_domain_blocked('user@other-domain.com') + result2 = domain_blocker.is_domain_blocked('user@blocked.org') + result3 = domain_blocker.is_domain_blocked('user@allowed.com') + + # Assert + assert result1 is True + assert result2 is True + assert result3 is False + + +def test_is_domain_blocked_with_whitespace(domain_blocker): + """Test that is_domain_blocked handles emails with whitespace correctly.""" + # Arrange + domain_blocker.blocked_domains = ['colsch.us'] + + # Act + result = domain_blocker.is_domain_blocked(' user@colsch.us ') + + # Assert + assert result is True diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index d4ba902677..a518beb28e 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -5,7 +5,12 @@ import jwt import pytest from fastapi import Request from pydantic import SecretStr -from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError +from server.auth.auth_error import ( + AuthError, + BearerTokenError, + CookieError, + NoCredentialsError, +) from server.auth.saas_user_auth import ( SaasUserAuth, get_api_key_from_header, @@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token(): # Assert that empty string from Bearer is returned (current behavior) # This tests the current implementation behavior assert api_key == '' + + +@pytest.mark.asyncio +async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config): + """Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked.""" + # Arrange + access_payload = { + 'sub': 'test_user_id', + 'exp': int(time.time()) + 3600, + 'email': 'user@colsch.us', + 'email_verified': True, + } + access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256') + + token_payload = { + 'access_token': access_token, + 'refresh_token': 'test_refresh_token', + } + signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') + + with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: + mock_domain_blocker.is_active.return_value = True + mock_domain_blocker.is_domain_blocked.return_value = True + + # Act & Assert + with pytest.raises(AuthError) as exc_info: + await saas_user_auth_from_signed_token(signed_token) + + assert 'email domain is not allowed' in str(exc_info.value) + mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us') + + +@pytest.mark.asyncio +async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config): + """Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked.""" + # Arrange + access_payload = { + 'sub': 'test_user_id', + 'exp': int(time.time()) + 3600, + 'email': 'user@example.com', + 'email_verified': True, + } + access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256') + + token_payload = { + 'access_token': access_token, + 'refresh_token': 'test_refresh_token', + } + signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') + + with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: + mock_domain_blocker.is_active.return_value = True + mock_domain_blocker.is_domain_blocked.return_value = False + + # Act + result = await saas_user_auth_from_signed_token(signed_token) + + # Assert + assert isinstance(result, SaasUserAuth) + assert result.user_id == 'test_user_id' + assert result.email == 'user@example.com' + mock_domain_blocker.is_domain_blocked.assert_called_once_with( + 'user@example.com' + ) + + +@pytest.mark.asyncio +async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config): + """Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active.""" + # Arrange + access_payload = { + 'sub': 'test_user_id', + 'exp': int(time.time()) + 3600, + 'email': 'user@colsch.us', + 'email_verified': True, + } + access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256') + + token_payload = { + 'access_token': access_token, + 'refresh_token': 'test_refresh_token', + } + signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') + + with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: + mock_domain_blocker.is_active.return_value = False + + # Act + result = await saas_user_auth_from_signed_token(signed_token) + + # Assert + assert isinstance(result, SaasUserAuth) + assert result.user_id == 'test_user_id' + mock_domain_blocker.is_domain_blocked.assert_not_called() diff --git a/enterprise/tests/unit/test_token_manager_extended.py b/enterprise/tests/unit/test_token_manager_extended.py index 744f208b02..c3b09434a3 100644 --- a/enterprise/tests/unit/test_token_manager_extended.py +++ b/enterprise/tests/unit/test_token_manager_extended.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from server.auth.token_manager import TokenManager, create_encryption_utility @@ -246,3 +246,103 @@ async def test_refresh(token_manager): mock_keycloak.return_value.a_refresh_token.assert_called_once_with( 'test_refresh_token' ) + + +@pytest.mark.asyncio +async def test_disable_keycloak_user_success(token_manager): + """Test successful disabling of a Keycloak user account.""" + # Arrange + user_id = 'test_user_id' + email = 'user@colsch.us' + mock_user = { + 'id': user_id, + 'username': 'testuser', + 'email': email, + 'emailVerified': True, + } + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_user = AsyncMock(return_value=mock_user) + mock_admin.a_update_user = AsyncMock() + mock_get_admin.return_value = mock_admin + + # Act + await token_manager.disable_keycloak_user(user_id, email) + + # Assert + mock_admin.a_get_user.assert_called_once_with(user_id) + mock_admin.a_update_user.assert_called_once_with( + user_id=user_id, + payload={ + 'enabled': False, + 'username': 'testuser', + 'email': email, + 'emailVerified': True, + }, + ) + + +@pytest.mark.asyncio +async def test_disable_keycloak_user_without_email(token_manager): + """Test disabling Keycloak user without providing email.""" + # Arrange + user_id = 'test_user_id' + mock_user = { + 'id': user_id, + 'username': 'testuser', + 'email': 'user@example.com', + 'emailVerified': False, + } + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_user = AsyncMock(return_value=mock_user) + mock_admin.a_update_user = AsyncMock() + mock_get_admin.return_value = mock_admin + + # Act + await token_manager.disable_keycloak_user(user_id) + + # Assert + mock_admin.a_get_user.assert_called_once_with(user_id) + mock_admin.a_update_user.assert_called_once() + + +@pytest.mark.asyncio +async def test_disable_keycloak_user_not_found(token_manager): + """Test disabling Keycloak user when user is not found.""" + # Arrange + user_id = 'nonexistent_user_id' + email = 'user@colsch.us' + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_user = AsyncMock(return_value=None) + mock_get_admin.return_value = mock_admin + + # Act + await token_manager.disable_keycloak_user(user_id, email) + + # Assert + mock_admin.a_get_user.assert_called_once_with(user_id) + mock_admin.a_update_user.assert_not_called() + + +@pytest.mark.asyncio +async def test_disable_keycloak_user_exception_handling(token_manager): + """Test that disable_keycloak_user handles exceptions gracefully without raising.""" + # Arrange + user_id = 'test_user_id' + email = 'user@colsch.us' + + with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin: + mock_admin = MagicMock() + mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error')) + mock_get_admin.return_value = mock_admin + + # Act & Assert - should not raise exception + await token_manager.disable_keycloak_user(user_id, email) + + # Verify the method was called + mock_admin.a_get_user.assert_called_once_with(user_id)