From 8ddb815a8928e0d22912a3b21a2d454d24a5ad5a Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Wed, 7 Jan 2026 13:41:43 +0700 Subject: [PATCH] refactor(backend): enhance storage and retrieval of blocked domains (#12273) --- .../086_create_blocked_email_domains_table.py | 54 +++ enterprise/server/auth/constants.py | 5 - enterprise/server/auth/domain_blocker.py | 60 ++-- enterprise/server/auth/saas_user_auth.py | 2 +- enterprise/server/routes/auth.py | 2 +- enterprise/storage/blocked_email_domain.py | 30 ++ .../storage/blocked_email_domain_store.py | 45 +++ enterprise/tests/unit/test_auth_routes.py | 10 +- enterprise/tests/unit/test_domain_blocker.py | 332 ++++++------------ enterprise/tests/unit/test_saas_user_auth.py | 8 +- 10 files changed, 280 insertions(+), 268 deletions(-) create mode 100644 enterprise/migrations/versions/086_create_blocked_email_domains_table.py create mode 100644 enterprise/storage/blocked_email_domain.py create mode 100644 enterprise/storage/blocked_email_domain_store.py diff --git a/enterprise/migrations/versions/086_create_blocked_email_domains_table.py b/enterprise/migrations/versions/086_create_blocked_email_domains_table.py new file mode 100644 index 0000000000..7333f6afb2 --- /dev/null +++ b/enterprise/migrations/versions/086_create_blocked_email_domains_table.py @@ -0,0 +1,54 @@ +"""create blocked_email_domains table + +Revision ID: 086 +Revises: 085 +Create Date: 2025-01-27 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '086' +down_revision: Union[str, None] = '085' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create blocked_email_domains table for storing blocked email domain patterns.""" + op.create_table( + 'blocked_email_domains', + sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True), + sa.Column('domain', sa.String(), nullable=False), + sa.Column( + 'created_at', + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text('CURRENT_TIMESTAMP'), + ), + sa.Column( + 'updated_at', + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text('CURRENT_TIMESTAMP'), + ), + sa.PrimaryKeyConstraint('id'), + ) + + # Create unique index on domain column + op.create_index( + 'ix_blocked_email_domains_domain', + 'blocked_email_domains', + ['domain'], + unique=True, + ) + + +def downgrade() -> None: + """Drop blocked_email_domains table.""" + op.drop_index('ix_blocked_email_domains_domain', table_name='blocked_email_domains') + op.drop_table('blocked_email_domains') diff --git a/enterprise/server/auth/constants.py b/enterprise/server/auth/constants.py index 242237e93d..15d3b0f704 100644 --- a/enterprise/server/auth/constants.py +++ b/enterprise/server/auth/constants.py @@ -38,8 +38,3 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in ( 'y', 'on', ) -BLOCKED_EMAIL_DOMAINS = [ - domain.strip().lower() - for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',') - if domain.strip() -] diff --git a/enterprise/server/auth/domain_blocker.py b/enterprise/server/auth/domain_blocker.py index 9c8164dfc6..3844f1bf85 100644 --- a/enterprise/server/auth/domain_blocker.py +++ b/enterprise/server/auth/domain_blocker.py @@ -1,20 +1,13 @@ -from server.auth.constants import BLOCKED_EMAIL_DOMAINS +from storage.blocked_email_domain_store import BlockedEmailDomainStore +from storage.database import session_maker from openhands.core.logger import openhands_logger as logger class DomainBlocker: - def __init__(self) -> None: + def __init__(self, store: BlockedEmailDomainStore) -> None: logger.debug('Initializing DomainBlocker') - self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS - if self.blocked_domains: - logger.info( - f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}' - ) - - def is_active(self) -> bool: - """Check if domain blocking is enabled""" - return bool(self.blocked_domains) + self.store = store def _extract_domain(self, email: str) -> str | None: """Extract and normalize email domain from email address""" @@ -31,16 +24,16 @@ class DomainBlocker: return None def is_domain_blocked(self, email: str) -> bool: - """Check if email domain is blocked + """Check if email domain is blocked by querying the database directly via SQL. Supports blocking: - Exact domains: 'example.com' blocks 'user@example.com' - Subdomains: 'example.com' blocks 'user@subdomain.example.com' - TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us' - """ - if not self.is_active(): - return False + The blocking logic is handled efficiently in SQL, avoiding the need to load + all blocked domains into memory. + """ if not email: logger.debug('No email provided for domain check') return False @@ -50,26 +43,25 @@ class DomainBlocker: logger.debug(f'Could not extract domain from email: {email}') return False - # Check if domain matches any blocked pattern - for blocked_pattern in self.blocked_domains: - if blocked_pattern.startswith('.'): - # TLD pattern (e.g., '.us') - check if domain ends with it - if domain.endswith(blocked_pattern): - logger.warning( - f'Email domain {domain} is blocked by TLD pattern {blocked_pattern} for email: {email}' - ) - return True + try: + # Query database directly via SQL to check if domain is blocked + is_blocked = self.store.is_domain_blocked(domain) + + if is_blocked: + logger.warning(f'Email domain {domain} is blocked for email: {email}') else: - # Full domain pattern (e.g., 'example.com') - # Block exact match or subdomains - if domain == blocked_pattern or domain.endswith(f'.{blocked_pattern}'): - logger.warning( - f'Email domain {domain} is blocked by domain pattern {blocked_pattern} for email: {email}' - ) - return True + logger.debug(f'Email domain {domain} is not blocked') - logger.debug(f'Email domain {domain} is not blocked') - return False + return is_blocked + except Exception as e: + logger.error( + f'Error checking if domain is blocked for email {email}: {e}', + exc_info=True, + ) + # Fail-safe: if database query fails, don't block (allow auth to proceed) + return False -domain_blocker = DomainBlocker() +# Initialize store and domain blocker +_store = BlockedEmailDomainStore(session_maker=session_maker) +domain_blocker = DomainBlocker(store=_store) diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 73a7217fd2..bdef2f8bf4 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -317,7 +317,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth: email_verified = access_token_payload['email_verified'] # Check if email domain is blocked - if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): + if email and domain_blocker.is_domain_blocked(email): logger.warning( f'Blocked authentication attempt for existing user with email: {email}' ) diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index dac7d6871a..c434d2a3d2 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -151,7 +151,7 @@ async def keycloak_callback( # Check if email domain is blocked email = user_info.get('email') - if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email): + if email and domain_blocker.is_domain_blocked(email): logger.warning( f'Blocked authentication attempt for email: {email}, user_id: {user_id}' ) diff --git a/enterprise/storage/blocked_email_domain.py b/enterprise/storage/blocked_email_domain.py new file mode 100644 index 0000000000..59783ba975 --- /dev/null +++ b/enterprise/storage/blocked_email_domain.py @@ -0,0 +1,30 @@ +from datetime import UTC, datetime + +from sqlalchemy import Column, DateTime, Identity, Integer, String +from storage.base import Base + + +class BlockedEmailDomain(Base): # type: ignore + """Stores blocked email domain patterns. + + Supports blocking: + - Exact domains: 'example.com' blocks 'user@example.com' + - Subdomains: 'example.com' blocks 'user@subdomain.example.com' + - TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us' + """ + + __tablename__ = 'blocked_email_domains' + + id = Column(Integer, Identity(), primary_key=True) + domain = Column(String, nullable=False, unique=True) + created_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + nullable=False, + ) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + nullable=False, + ) diff --git a/enterprise/storage/blocked_email_domain_store.py b/enterprise/storage/blocked_email_domain_store.py new file mode 100644 index 0000000000..2b1fae212d --- /dev/null +++ b/enterprise/storage/blocked_email_domain_store.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from sqlalchemy import text +from sqlalchemy.orm import sessionmaker + + +@dataclass +class BlockedEmailDomainStore: + session_maker: sessionmaker + + def is_domain_blocked(self, domain: str) -> bool: + """Check if a domain is blocked by querying the database directly. + + This method uses SQL to efficiently check if the domain matches any blocked pattern: + - TLD patterns (e.g., '.us'): checks if domain ends with the pattern + - Full domain patterns (e.g., 'example.com'): checks for exact match or subdomain match + + Args: + domain: The extracted domain from the email (e.g., 'example.com' or 'subdomain.example.com') + + Returns: + True if the domain is blocked, False otherwise + """ + with self.session_maker() as session: + # SQL query that handles both TLD patterns and full domain patterns + # TLD patterns (starting with '.'): check if domain ends with the pattern + # Full domain patterns: check for exact match or subdomain match + # All comparisons are case-insensitive using LOWER() to ensure consistent matching + query = text(""" + SELECT EXISTS( + SELECT 1 + FROM blocked_email_domains + WHERE + -- TLD pattern (e.g., '.us') - check if domain ends with it (case-insensitive) + (LOWER(domain) LIKE '.%' AND LOWER(:domain) LIKE '%' || LOWER(domain)) OR + -- Full domain pattern (e.g., 'example.com') + -- Block exact match or subdomains (case-insensitive) + (LOWER(domain) NOT LIKE '.%' AND ( + LOWER(:domain) = LOWER(domain) OR + LOWER(:domain) LIKE '%.' || LOWER(domain) + )) + ) + """) + result = session.execute(query, {'domain': domain}).scalar() + return bool(result) diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index 362940b1f7..eabc499911 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -546,7 +546,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request): ) mock_token_manager.disable_keycloak_user = AsyncMock() - mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = True # Act @@ -600,7 +599,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request): mock_token_manager.store_idp_tokens = AsyncMock() mock_token_manager.validate_offline_token = AsyncMock(return_value=True) - mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = False mock_verifier.is_active.return_value = True @@ -621,7 +619,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request): @pytest.mark.asyncio async def test_keycloak_callback_domain_blocking_inactive(mock_request): - """Test keycloak_callback when domain blocking is not active.""" + """Test keycloak_callback when email domain is not blocked.""" # Arrange with ( patch('server.routes.auth.token_manager') as mock_token_manager, @@ -654,7 +652,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request): mock_token_manager.store_idp_tokens = AsyncMock() mock_token_manager.validate_offline_token = AsyncMock(return_value=True) - mock_domain_blocker.is_active.return_value = False + mock_domain_blocker.is_domain_blocked.return_value = False mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -666,7 +664,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request): # Assert assert isinstance(result, RedirectResponse) - mock_domain_blocker.is_domain_blocked.assert_not_called() + mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us') mock_token_manager.disable_keycloak_user.assert_not_called() @@ -705,8 +703,6 @@ async def test_keycloak_callback_missing_email(mock_request): mock_token_manager.store_idp_tokens = AsyncMock() mock_token_manager.validate_offline_token = AsyncMock(return_value=True) - mock_domain_blocker.is_active.return_value = True - mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True diff --git a/enterprise/tests/unit/test_domain_blocker.py b/enterprise/tests/unit/test_domain_blocker.py index 7d8aa8ed60..cae944e949 100644 --- a/enterprise/tests/unit/test_domain_blocker.py +++ b/enterprise/tests/unit/test_domain_blocker.py @@ -1,33 +1,21 @@ """Unit tests for DomainBlocker class.""" +from unittest.mock import MagicMock + import pytest from server.auth.domain_blocker import DomainBlocker @pytest.fixture -def domain_blocker(): - """Create a DomainBlocker instance for testing.""" - return DomainBlocker() +def mock_store(): + """Create a mock BlockedEmailDomainStore for testing.""" + return MagicMock() -@pytest.mark.parametrize( - 'blocked_domains,expected', - [ - (['colsch.us', 'other-domain.com'], True), - (['example.com'], True), - ([], False), - ], -) -def test_is_active(domain_blocker, blocked_domains, expected): - """Test that is_active returns correct value based on blocked domains configuration.""" - # Arrange - domain_blocker.blocked_domains = blocked_domains - - # Act - result = domain_blocker.is_active() - - # Assert - assert result == expected +@pytest.fixture +def domain_blocker(mock_store): + """Create a DomainBlocker instance for testing with a mocked store.""" + return DomainBlocker(store=mock_store) @pytest.mark.parametrize( @@ -69,94 +57,104 @@ def test_extract_domain_invalid_emails(domain_blocker, email, expected): assert result == expected -def test_is_domain_blocked_when_inactive(domain_blocker): - """Test that is_domain_blocked returns False when blocking is not active.""" - # Arrange - domain_blocker.blocked_domains = [] - - # Act - result = domain_blocker.is_domain_blocked('user@colsch.us') - - # Assert - assert result is False - - -def test_is_domain_blocked_with_none_email(domain_blocker): +def test_is_domain_blocked_with_none_email(domain_blocker, mock_store): """Test that is_domain_blocked returns False when email is None.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked(None) # Assert assert result is False + mock_store.is_domain_blocked.assert_not_called() -def test_is_domain_blocked_with_empty_email(domain_blocker): +def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store): """Test that is_domain_blocked returns False when email is empty.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('') # Assert assert result is False + mock_store.is_domain_blocked.assert_not_called() -def test_is_domain_blocked_with_invalid_email(domain_blocker): +def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store): """Test that is_domain_blocked returns False when email format is invalid.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('invalid-email') # Assert assert result is False + mock_store.is_domain_blocked.assert_not_called() -def test_is_domain_blocked_domain_not_blocked(domain_blocker): - """Test that is_domain_blocked returns False when domain is not in blocked list.""" +def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store): + """Test that is_domain_blocked returns False when domain is not blocked.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com'] + mock_store.is_domain_blocked.return_value = False # Act result = domain_blocker.is_domain_blocked('user@example.com') # Assert assert result is False + mock_store.is_domain_blocked.assert_called_once_with('example.com') -def test_is_domain_blocked_domain_blocked(domain_blocker): - """Test that is_domain_blocked returns True when domain is in blocked list.""" +def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store): + """Test that is_domain_blocked returns True when domain is blocked.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@colsch.us') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('colsch.us') -def test_is_domain_blocked_case_insensitive(domain_blocker): - """Test that is_domain_blocked performs case-insensitive domain matching.""" +def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store): + """Test that is_domain_blocked performs case-insensitive domain extraction.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@COLSCH.US') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('colsch.us') -def test_is_domain_blocked_multiple_blocked_domains(domain_blocker): - """Test that is_domain_blocked correctly checks against multiple blocked domains.""" +def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store): + """Test that is_domain_blocked handles emails with whitespace correctly.""" # Arrange - domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org'] + mock_store.is_domain_blocked.return_value = True + + # Act + result = domain_blocker.is_domain_blocked(' user@colsch.us ') + + # Assert + assert result is True + mock_store.is_domain_blocked.assert_called_once_with('colsch.us') + + +def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store): + """Test that is_domain_blocked correctly checks multiple domains.""" + # Arrange + mock_store.is_domain_blocked.side_effect = lambda domain: domain in [ + 'other-domain.com', + 'blocked.org', + ] # Act result1 = domain_blocker.is_domain_blocked('user@other-domain.com') @@ -167,109 +165,71 @@ def test_is_domain_blocked_multiple_blocked_domains(domain_blocker): assert result1 is True assert result2 is True assert result3 is False + assert mock_store.is_domain_blocked.call_count == 3 -def test_is_domain_blocked_with_whitespace(domain_blocker): - """Test that is_domain_blocked handles emails with whitespace correctly.""" - # Arrange - domain_blocker.blocked_domains = ['colsch.us'] - - # Act - result = domain_blocker.is_domain_blocked(' user@colsch.us ') - - # Assert - assert result is True - - -# ============================================================================ -# TLD Blocking Tests (patterns starting with '.') -# ============================================================================ - - -def test_is_domain_blocked_tld_pattern_blocks_matching_domain(domain_blocker): +def test_is_domain_blocked_tld_pattern_blocks_matching_domain( + domain_blocker, mock_store +): """Test that TLD pattern blocks domains ending with that TLD.""" # Arrange - domain_blocker.blocked_domains = ['.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@company.us') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('company.us') -def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(domain_blocker): +def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld( + domain_blocker, mock_store +): """Test that TLD pattern blocks subdomains with that TLD.""" # Arrange - domain_blocker.blocked_domains = ['.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@subdomain.company.us') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us') -def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(domain_blocker): +def test_is_domain_blocked_tld_pattern_does_not_block_different_tld( + domain_blocker, mock_store +): """Test that TLD pattern does not block domains with different TLD.""" # Arrange - domain_blocker.blocked_domains = ['.us'] + mock_store.is_domain_blocked.return_value = False # Act result = domain_blocker.is_domain_blocked('user@company.com') # Assert assert result is False + mock_store.is_domain_blocked.assert_called_once_with('company.com') -def test_is_domain_blocked_tld_pattern_does_not_block_substring_match( - domain_blocker, -): - """Test that TLD pattern does not block domains that contain but don't end with the TLD.""" - # Arrange - domain_blocker.blocked_domains = ['.us'] - - # Act - result = domain_blocker.is_domain_blocked('user@focus.com') - - # Assert - assert result is False - - -def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker): +def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker, mock_store): """Test that TLD pattern matching is case-insensitive.""" # Arrange - domain_blocker.blocked_domains = ['.us'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@COMPANY.US') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('company.us') -def test_is_domain_blocked_multiple_tld_patterns(domain_blocker): - """Test blocking with multiple TLD patterns.""" - # Arrange - domain_blocker.blocked_domains = ['.us', '.vn', '.com'] - - # Act - result_us = domain_blocker.is_domain_blocked('user@test.us') - result_vn = domain_blocker.is_domain_blocked('user@test.vn') - result_com = domain_blocker.is_domain_blocked('user@test.com') - result_org = domain_blocker.is_domain_blocked('user@test.org') - - # Assert - assert result_us is True - assert result_vn is True - assert result_com is True - assert result_org is False - - -def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker): +def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker, mock_store): """Test that TLD pattern works with multi-level TLDs like .co.uk.""" # Arrange - domain_blocker.blocked_domains = ['.co.uk'] + mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk') # Act result_match = domain_blocker.is_domain_blocked('user@example.co.uk') @@ -282,81 +242,87 @@ def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker): assert result_no_match is False -# ============================================================================ -# Subdomain Blocking Tests (domain patterns now block subdomains) -# ============================================================================ - - -def test_is_domain_blocked_domain_pattern_blocks_exact_match(domain_blocker): +def test_is_domain_blocked_domain_pattern_blocks_exact_match( + domain_blocker, mock_store +): """Test that domain pattern blocks exact domain match.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@example.com') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('example.com') -def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker): +def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker, mock_store): """Test that domain pattern blocks subdomains of that domain.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@subdomain.example.com') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com') def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain( - domain_blocker, + domain_blocker, mock_store ): """Test that domain pattern blocks multi-level subdomains.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked('user@api.v2.example.com') # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com') def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain( - domain_blocker, + domain_blocker, mock_store ): """Test that domain pattern does not block domains that contain but don't match the pattern.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = False # Act result = domain_blocker.is_domain_blocked('user@notexample.com') # Assert assert result is False + mock_store.is_domain_blocked.assert_called_once_with('notexample.com') def test_is_domain_blocked_domain_pattern_does_not_block_different_tld( - domain_blocker, + domain_blocker, mock_store ): """Test that domain pattern does not block same domain with different TLD.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = False # Act result = domain_blocker.is_domain_blocked('user@example.org') # Assert assert result is False + mock_store.is_domain_blocked.assert_called_once_with('example.org') -def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_blocker): +def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested( + domain_blocker, mock_store +): """Test that blocking a subdomain also blocks its nested subdomains.""" # Arrange - domain_blocker.blocked_domains = ['api.example.com'] + mock_store.is_domain_blocked.side_effect = ( + lambda domain: 'api.example.com' in domain + ) # Act result_exact = domain_blocker.is_domain_blocked('user@api.example.com') @@ -369,80 +335,10 @@ def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_bloc assert result_parent is False -# ============================================================================ -# Mixed Pattern Tests (TLD + domain patterns together) -# ============================================================================ - - -def test_is_domain_blocked_mixed_patterns_tld_and_domain(domain_blocker): - """Test blocking with both TLD and domain patterns.""" - # Arrange - domain_blocker.blocked_domains = ['.us', 'openhands.dev'] - - # Act - result_tld = domain_blocker.is_domain_blocked('user@company.us') - result_domain = domain_blocker.is_domain_blocked('user@openhands.dev') - result_subdomain = domain_blocker.is_domain_blocked('user@api.openhands.dev') - result_allowed = domain_blocker.is_domain_blocked('user@example.com') - - # Assert - assert result_tld is True - assert result_domain is True - assert result_subdomain is True - assert result_allowed is False - - -def test_is_domain_blocked_overlapping_patterns(domain_blocker): - """Test that overlapping patterns (TLD and specific domain) both work.""" - # Arrange - domain_blocker.blocked_domains = ['.us', 'test.us'] - - # Act - result_specific = domain_blocker.is_domain_blocked('user@test.us') - result_other_us = domain_blocker.is_domain_blocked('user@other.us') - - # Assert - assert result_specific is True - assert result_other_us is True - - -def test_is_domain_blocked_complex_multi_pattern_scenario(domain_blocker): - """Test complex scenario with multiple TLD and domain patterns.""" - # Arrange - domain_blocker.blocked_domains = [ - '.us', - '.vn', - 'test.com', - 'openhands.dev', - ] - - # Act & Assert - # TLD patterns - assert domain_blocker.is_domain_blocked('user@anything.us') is True - assert domain_blocker.is_domain_blocked('user@company.vn') is True - - # Domain patterns (exact) - assert domain_blocker.is_domain_blocked('user@test.com') is True - assert domain_blocker.is_domain_blocked('user@openhands.dev') is True - - # Domain patterns (subdomains) - assert domain_blocker.is_domain_blocked('user@api.test.com') is True - assert domain_blocker.is_domain_blocked('user@staging.openhands.dev') is True - - # Not blocked - assert domain_blocker.is_domain_blocked('user@allowed.com') is False - assert domain_blocker.is_domain_blocked('user@example.org') is False - - -# ============================================================================ -# Edge Case Tests -# ============================================================================ - - -def test_is_domain_blocked_domain_with_hyphens(domain_blocker): +def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store): """Test that domain patterns work with hyphenated domains.""" # Arrange - domain_blocker.blocked_domains = ['my-company.com'] + mock_store.is_domain_blocked.return_value = True # Act result_exact = domain_blocker.is_domain_blocked('user@my-company.com') @@ -451,12 +347,13 @@ def test_is_domain_blocked_domain_with_hyphens(domain_blocker): # Assert assert result_exact is True assert result_subdomain is True + assert mock_store.is_domain_blocked.call_count == 2 -def test_is_domain_blocked_domain_with_numbers(domain_blocker): +def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store): """Test that domain patterns work with numeric domains.""" # Arrange - domain_blocker.blocked_domains = ['test123.com'] + mock_store.is_domain_blocked.return_value = True # Act result_exact = domain_blocker.is_domain_blocked('user@test123.com') @@ -465,24 +362,13 @@ def test_is_domain_blocked_domain_with_numbers(domain_blocker): # Assert assert result_exact is True assert result_subdomain is True + assert mock_store.is_domain_blocked.call_count == 2 -def test_is_domain_blocked_short_tld(domain_blocker): - """Test that short TLD patterns work correctly.""" - # Arrange - domain_blocker.blocked_domains = ['.io'] - - # Act - result = domain_blocker.is_domain_blocked('user@company.io') - - # Assert - assert result is True - - -def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker): +def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store): """Test that blocking works with very long subdomain chains.""" # Arrange - domain_blocker.blocked_domains = ['example.com'] + mock_store.is_domain_blocked.return_value = True # Act result = domain_blocker.is_domain_blocked( @@ -491,3 +377,19 @@ def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker): # Assert assert result is True + mock_store.is_domain_blocked.assert_called_once_with( + 'level4.level3.level2.level1.example.com' + ) + + +def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store): + """Test that is_domain_blocked returns False when store raises an exception.""" + # Arrange + mock_store.is_domain_blocked.side_effect = Exception('Database connection error') + + # Act + result = domain_blocker.is_domain_blocked('user@example.com') + + # Assert + assert result is False + mock_store.is_domain_blocked.assert_called_once_with('example.com') diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index a518beb28e..6d9ced0057 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -673,7 +673,6 @@ async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config): signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: - mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = True # Act & Assert @@ -703,7 +702,6 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config): signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: - mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = False # Act @@ -720,7 +718,7 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config): @pytest.mark.asyncio async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config): - """Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active.""" + """Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked.""" # Arrange access_payload = { 'sub': 'test_user_id', @@ -737,7 +735,7 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256') with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker: - mock_domain_blocker.is_active.return_value = False + mock_domain_blocker.is_domain_blocked.return_value = False # Act result = await saas_user_auth_from_signed_token(signed_token) @@ -745,4 +743,4 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co # Assert assert isinstance(result, SaasUserAuth) assert result.user_id == 'test_user_id' - mock_domain_blocker.is_domain_blocked.assert_not_called() + mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')