mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Refactor user authorization: Replace domain blocklist with flexible whitelist/blacklist pattern matching (#13207)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""Create user_authorizations table and migrate blocked_email_domains
|
||||
|
||||
Revision ID: 099
|
||||
Revises: 098
|
||||
Create Date: 2025-03-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '099'
|
||||
down_revision: Union[str, None] = '098'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _seed_from_environment() -> None:
|
||||
"""Seed user_authorizations table from environment variables.
|
||||
|
||||
Reads EMAIL_PATTERN_BLACKLIST and EMAIL_PATTERN_WHITELIST environment variables.
|
||||
Each should be a comma-separated list of SQL LIKE patterns (e.g., '%@example.com').
|
||||
|
||||
If the environment variables are not set or empty, this function does nothing.
|
||||
|
||||
This allows us to set up feature deployments with particular patterns already
|
||||
blacklisted or whitelisted. (For example, you could blacklist everything with
|
||||
`%`, and then whitelist certain email accounts.)
|
||||
"""
|
||||
blacklist_patterns = os.environ.get('EMAIL_PATTERN_BLACKLIST', '').strip()
|
||||
whitelist_patterns = os.environ.get('EMAIL_PATTERN_WHITELIST', '').strip()
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
if blacklist_patterns:
|
||||
for pattern in blacklist_patterns.split(','):
|
||||
pattern = pattern.strip()
|
||||
if pattern:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
INSERT INTO user_authorizations
|
||||
(email_pattern, provider_type, type)
|
||||
VALUES
|
||||
(:pattern, NULL, 'blacklist')
|
||||
"""),
|
||||
{'pattern': pattern},
|
||||
)
|
||||
|
||||
if whitelist_patterns:
|
||||
for pattern in whitelist_patterns.split(','):
|
||||
pattern = pattern.strip()
|
||||
if pattern:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
INSERT INTO user_authorizations
|
||||
(email_pattern, provider_type, type)
|
||||
VALUES
|
||||
(:pattern, NULL, 'whitelist')
|
||||
"""),
|
||||
{'pattern': pattern},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create user_authorizations table, migrate data, and drop blocked_email_domains."""
|
||||
# Create user_authorizations table
|
||||
op.create_table(
|
||||
'user_authorizations',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('email_pattern', sa.String(), nullable=True),
|
||||
sa.Column('provider_type', sa.String(), nullable=True),
|
||||
sa.Column('type', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create index on email_pattern for efficient LIKE queries
|
||||
op.create_index(
|
||||
'ix_user_authorizations_email_pattern',
|
||||
'user_authorizations',
|
||||
['email_pattern'],
|
||||
)
|
||||
|
||||
# Create index on type for efficient filtering
|
||||
op.create_index(
|
||||
'ix_user_authorizations_type',
|
||||
'user_authorizations',
|
||||
['type'],
|
||||
)
|
||||
|
||||
# Migrate existing blocked_email_domains to user_authorizations as blacklist entries
|
||||
# The domain patterns are converted to SQL LIKE patterns:
|
||||
# - 'example.com' becomes '%@example.com' (matches user@example.com)
|
||||
# - '.us' becomes '%@%.us' (matches user@anything.us)
|
||||
# We also add '%.' prefix for subdomain matching
|
||||
op.execute("""
|
||||
INSERT INTO user_authorizations (email_pattern, provider_type, type, created_at, updated_at)
|
||||
SELECT
|
||||
CASE
|
||||
WHEN domain LIKE '.%' THEN '%' || domain
|
||||
ELSE '%@%' || domain
|
||||
END as email_pattern,
|
||||
NULL as provider_type,
|
||||
'blacklist' as type,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM blocked_email_domains
|
||||
""")
|
||||
|
||||
# Seed additional patterns from environment variables (if set)
|
||||
_seed_from_environment()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate blocked_email_domains table and migrate data back."""
|
||||
# Drop user_authorizations table
|
||||
op.drop_index('ix_user_authorizations_type', table_name='user_authorizations')
|
||||
op.drop_index(
|
||||
'ix_user_authorizations_email_pattern', table_name='user_authorizations'
|
||||
)
|
||||
op.drop_table('user_authorizations')
|
||||
@@ -1,53 +0,0 @@
|
||||
import os
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured."""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.debug('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip().lower() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f'Error reading user list file {waitlist}')
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users)
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration."""
|
||||
logger.debug(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username.lower() in self.file_users:
|
||||
logger.debug(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
user_verifier = UserVerifier()
|
||||
@@ -1,66 +0,0 @@
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self, store: BlockedEmailDomainStore) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.store = store
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
async def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
- Exact domains: 'example.com' blocks 'user@example.com'
|
||||
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
|
||||
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
|
||||
|
||||
The blocking logic is handled efficiently in SQL, avoiding the need to load
|
||||
all blocked domains into memory.
|
||||
"""
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = await self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error checking if domain is blocked for email {email}: {e}',
|
||||
exc_info=True,
|
||||
)
|
||||
# Fail-safe: if database query fails, don't block (allow auth to proceed)
|
||||
return False
|
||||
|
||||
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore()
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
@@ -13,7 +13,6 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -24,6 +23,8 @@ from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
from storage.user_authorization_store import UserAuthorizationStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
@@ -326,14 +327,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
# Check if email is blacklisted (whitelist takes precedence)
|
||||
if email:
|
||||
auth_type = await UserAuthorizationStore.get_authorization_type(email, None)
|
||||
if auth_type == UserAuthorizationType.BLACKLIST:
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
|
||||
0
enterprise/server/auth/user/__init__.py
Normal file
0
enterprise/server/auth/user/__init__.py
Normal file
98
enterprise/server/auth/user/default_user_authorizer.py
Normal file
98
enterprise/server/auth/user/default_user_authorizer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import Field
|
||||
from server.auth.email_validation import extract_base_email
|
||||
from server.auth.token_manager import KeycloakUserInfo, TokenManager
|
||||
from server.auth.user.user_authorizer import (
|
||||
UserAuthorizationResponse,
|
||||
UserAuthorizer,
|
||||
UserAuthorizerInjector,
|
||||
)
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
from storage.user_authorization_store import UserAuthorizationStore
|
||||
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
token_manager = TokenManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultUserAuthorizer(UserAuthorizer):
|
||||
"""Class determining whether a user may be authorized.
|
||||
|
||||
Uses the user_authorizations database table to check whitelist/blacklist rules.
|
||||
"""
|
||||
|
||||
prevent_duplicates: bool
|
||||
|
||||
async def authorize_user(
|
||||
self, user_info: KeycloakUserInfo
|
||||
) -> UserAuthorizationResponse:
|
||||
user_id = user_info.sub
|
||||
email = user_info.email
|
||||
provider_type = user_info.identity_provider
|
||||
try:
|
||||
if not email:
|
||||
logger.warning(f'No email provided for user_id: {user_id}')
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='missing_email'
|
||||
)
|
||||
|
||||
if self.prevent_duplicates:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='duplicate_email'
|
||||
)
|
||||
|
||||
# Check authorization rules (whitelist takes precedence over blacklist)
|
||||
base_email = extract_base_email(email)
|
||||
if base_email is None:
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='invalid_email'
|
||||
)
|
||||
auth_type = await UserAuthorizationStore.get_authorization_type(
|
||||
base_email, provider_type
|
||||
)
|
||||
|
||||
if auth_type == UserAuthorizationType.WHITELIST:
|
||||
logger.debug(
|
||||
f'User {email} matched whitelist rule',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
return UserAuthorizationResponse(success=True)
|
||||
|
||||
if auth_type == UserAuthorizationType.BLACKLIST:
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
return UserAuthorizationResponse(success=False, error_detail='blocked')
|
||||
|
||||
return UserAuthorizationResponse(success=True)
|
||||
except Exception:
|
||||
logger.exception('error authorizing user', extra={'user_id': user_id})
|
||||
return UserAuthorizationResponse(success=False)
|
||||
|
||||
|
||||
class DefaultUserAuthorizerInjector(UserAuthorizerInjector):
|
||||
prevent_duplicates: bool = Field(
|
||||
default=True,
|
||||
description='Whether duplicate emails (containing +) are filtered',
|
||||
)
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[UserAuthorizer, None]:
|
||||
yield DefaultUserAuthorizer(
|
||||
prevent_duplicates=self.prevent_duplicates,
|
||||
)
|
||||
48
enterprise/server/auth/user/user_authorizer.py
Normal file
48
enterprise/server/auth/user/user_authorizer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from server.auth.token_manager import KeycloakUserInfo
|
||||
|
||||
from openhands.agent_server.env_parser import from_env
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserAuthorizationResponse(BaseModel):
|
||||
success: bool
|
||||
error_detail: str | None = None
|
||||
|
||||
|
||||
class UserAuthorizer(ABC):
|
||||
"""Class determining whether a user may be authorized."""
|
||||
|
||||
@abstractmethod
|
||||
async def authorize_user(
|
||||
self, user_info: KeycloakUserInfo
|
||||
) -> UserAuthorizationResponse:
|
||||
"""Determine whether the info given is permitted."""
|
||||
|
||||
|
||||
class UserAuthorizerInjector(DiscriminatedUnionMixin, Injector[UserAuthorizer], ABC):
|
||||
pass
|
||||
|
||||
|
||||
def depends_user_authorizer():
|
||||
from server.auth.user.default_user_authorizer import (
|
||||
DefaultUserAuthorizerInjector,
|
||||
)
|
||||
|
||||
try:
|
||||
injector: UserAuthorizerInjector = from_env(
|
||||
UserAuthorizerInjector, 'OH_USER_AUTHORIZER'
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
logger.info('Using default UserAuthorizer')
|
||||
injector = DefaultUserAuthorizerInjector()
|
||||
|
||||
return Depends(injector.depends)
|
||||
@@ -4,14 +4,13 @@ import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote, urlencode
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_utils import user_verifier
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -19,11 +18,14 @@ from server.auth.constants import (
|
||||
RECAPTCHA_SITE_KEY,
|
||||
ROLE_CHECK_ENABLED,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.recaptcha_service import recaptcha_service
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.auth.user.user_authorizer import (
|
||||
UserAuthorizer,
|
||||
depends_user_authorizer,
|
||||
)
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
@@ -40,6 +42,7 @@ from storage.database import a_session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType, TokenResponse
|
||||
@@ -157,11 +160,16 @@ async def keycloak_callback(
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
user_authorizer: UserAuthorizer = depends_user_authorizer(),
|
||||
):
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
if redirect_url is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Missing state in request params',
|
||||
)
|
||||
|
||||
if not code:
|
||||
# check if this is a forward from the account linking page
|
||||
@@ -170,36 +178,40 @@ async def keycloak_callback(
|
||||
and error_description == 'authentication_expired'
|
||||
):
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
return JSONResponse(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'Missing code in request params'},
|
||||
detail='Missing code in request params',
|
||||
)
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}{request.url.path}'
|
||||
logger.debug(f'code: {code}, redirect_uri: {redirect_uri}')
|
||||
|
||||
web_url = get_global_config().web_url
|
||||
if not web_url:
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
web_url = f'{scheme}://{request.url.netloc}'
|
||||
redirect_uri = web_url + request.url.path
|
||||
|
||||
(
|
||||
keycloak_access_token,
|
||||
keycloak_refresh_token,
|
||||
) = await token_manager.get_keycloak_tokens(code, redirect_uri)
|
||||
if not keycloak_access_token or not keycloak_refresh_token:
|
||||
return JSONResponse(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'Problem retrieving Keycloak tokens'},
|
||||
detail='Problem retrieving Keycloak tokens',
|
||||
)
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
logger.debug(f'user_info: {user_info}')
|
||||
if ROLE_CHECK_ENABLED and user_info.roles is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Missing required role'},
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail='Missing required role'
|
||||
)
|
||||
|
||||
if user_info.preferred_username is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
authorization = await user_authorizer.authorize_user(user_info)
|
||||
if not authorization.success:
|
||||
# Return unauthorized
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=authorization.error_detail,
|
||||
)
|
||||
|
||||
email = user_info.email
|
||||
@@ -214,12 +226,10 @@ async def keycloak_callback(
|
||||
await UserStore.backfill_user_email(user_id, user_info_dict)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info.preferred_username}')
|
||||
return JSONResponse(
|
||||
logger.error(f'Failed to authenticate user {user_info.email}')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info.preferred_username}'
|
||||
},
|
||||
detail=f'Failed to authenticate user {user_info.email}',
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
@@ -234,7 +244,7 @@ async def keycloak_callback(
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
error_url = f'{request.base_url}login?recaptcha_blocked=true'
|
||||
error_url = f'{web_url}/login?recaptcha_blocked=true'
|
||||
return RedirectResponse(error_url, status_code=302)
|
||||
|
||||
user_ip = request.client.host if request.client else 'unknown'
|
||||
@@ -265,65 +275,13 @@ async def keycloak_callback(
|
||||
},
|
||||
)
|
||||
# Redirect to home with error parameter
|
||||
error_url = f'{request.base_url}login?recaptcha_blocked=true'
|
||||
error_url = f'{web_url}/login?recaptcha_blocked=true'
|
||||
return RedirectResponse(error_url, status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'reCAPTCHA verification error at callback: {e}')
|
||||
# Fail open - continue with login if reCAPTCHA service unavailable
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
# Disable the Keycloak account
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}/login?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Check email verification status
|
||||
email_verified = user_info.email_verified or False
|
||||
if not email_verified:
|
||||
@@ -358,6 +316,7 @@ async def keycloak_callback(
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
if rate_limited:
|
||||
verification_redirect_url = f'{verification_redirect_url}&rate_limited=true'
|
||||
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
@@ -379,13 +338,6 @@ async def keycloak_callback(
|
||||
ProviderType(idp), user_id, keycloak_access_token
|
||||
)
|
||||
|
||||
username = user_info.preferred_username
|
||||
if user_verifier.is_active() and not user_verifier.is_user_allowed(username):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Not authorized via waitlist'},
|
||||
)
|
||||
|
||||
valid_offline_token = (
|
||||
await token_manager.validate_offline_token(user_id=user_info.sub)
|
||||
if idp_type != 'saml'
|
||||
@@ -431,13 +383,19 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
if not valid_offline_token:
|
||||
param_str = urlencode(
|
||||
{
|
||||
'client_id': KEYCLOAK_CLIENT_ID,
|
||||
'response_type': 'code',
|
||||
'kc_idp_hint': idp,
|
||||
'redirect_uri': f'{web_url}/oauth/keycloak/offline/callback',
|
||||
'scope': 'openid email profile offline_access',
|
||||
'state': state,
|
||||
}
|
||||
)
|
||||
redirect_url = (
|
||||
f'{KEYCLOAK_SERVER_URL_EXT}/realms/{KEYCLOAK_REALM_NAME}/protocol/openid-connect/auth'
|
||||
f'?client_id={KEYCLOAK_CLIENT_ID}&response_type=code'
|
||||
f'&kc_idp_hint={idp}'
|
||||
f'&redirect_uri={scheme}%3A%2F%2F{request.url.netloc}%2Foauth%2Fkeycloak%2Foffline%2Fcallback'
|
||||
f'&scope=openid%20email%20profile%20offline_access'
|
||||
f'&state={state}'
|
||||
f'?{param_str}'
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
@@ -532,7 +490,7 @@ async def keycloak_callback(
|
||||
response=response,
|
||||
keycloak_access_token=keycloak_access_token,
|
||||
keycloak_refresh_token=keycloak_refresh_token,
|
||||
secure=True if scheme == 'https' else False,
|
||||
secure=True if redirect_url.startswith('https') else False,
|
||||
accepted_tos=has_accepted_tos,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class BlockedEmailDomain(Base): # type: ignore
|
||||
"""Stores blocked email domain patterns.
|
||||
|
||||
Supports blocking:
|
||||
- Exact domains: 'example.com' blocks 'user@example.com'
|
||||
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
|
||||
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
|
||||
"""
|
||||
|
||||
__tablename__ = 'blocked_email_domains'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
domain = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,43 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text
|
||||
from storage.database import a_session_maker
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockedEmailDomainStore:
|
||||
async def is_domain_blocked(self, domain: str) -> bool:
|
||||
"""Check if a domain is blocked by querying the database directly.
|
||||
|
||||
This method uses SQL to efficiently check if the domain matches any blocked pattern:
|
||||
- TLD patterns (e.g., '.us'): checks if domain ends with the pattern
|
||||
- Full domain patterns (e.g., 'example.com'): checks for exact match or subdomain match
|
||||
|
||||
Args:
|
||||
domain: The extracted domain from the email (e.g., 'example.com' or 'subdomain.example.com')
|
||||
|
||||
Returns:
|
||||
True if the domain is blocked, False otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
# SQL query that handles both TLD patterns and full domain patterns
|
||||
# TLD patterns (starting with '.'): check if domain ends with it (case-insensitive)
|
||||
# Full domain patterns: check for exact match or subdomain match
|
||||
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
|
||||
query = text("""
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM blocked_email_domains
|
||||
WHERE
|
||||
-- TLD pattern (e.g., '.us') - check if domain ends with it (case-insensitive)
|
||||
(LOWER(domain) LIKE '.%' AND LOWER(:domain) LIKE '%' || LOWER(domain)) OR
|
||||
-- Full domain pattern (e.g., 'example.com')
|
||||
-- Block exact match or subdomains (case-insensitive)
|
||||
(LOWER(domain) NOT LIKE '.%' AND (
|
||||
LOWER(:domain) = LOWER(domain) OR
|
||||
LOWER(:domain) LIKE '%.' || LOWER(domain)
|
||||
))
|
||||
)
|
||||
""")
|
||||
result = await session.execute(query, {'domain': domain})
|
||||
return bool(result.scalar())
|
||||
45
enterprise/storage/user_authorization.py
Normal file
45
enterprise/storage/user_authorization.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""User authorization model for managing email/provider based access control."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class UserAuthorizationType(str, Enum):
|
||||
"""Type of user authorization rule."""
|
||||
|
||||
WHITELIST = 'whitelist'
|
||||
BLACKLIST = 'blacklist'
|
||||
|
||||
|
||||
class UserAuthorization(Base): # type: ignore
|
||||
"""Stores user authorization rules based on email patterns and provider types.
|
||||
|
||||
Supports:
|
||||
- Email pattern matching using SQL LIKE (e.g., '%@openhands.dev')
|
||||
- Provider type filtering (e.g., 'github', 'gitlab')
|
||||
- Whitelist/Blacklist rules
|
||||
|
||||
When email_pattern is NULL, the rule matches all emails.
|
||||
When provider_type is NULL, the rule matches all providers.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_authorizations'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
email_pattern = Column(String, nullable=True)
|
||||
provider_type = Column(String, nullable=True)
|
||||
type = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
203
enterprise/storage/user_authorization_store.py
Normal file
203
enterprise/storage/user_authorization_store.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Store class for managing user authorizations."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.database import a_session_maker
|
||||
from storage.user_authorization import UserAuthorization, UserAuthorizationType
|
||||
|
||||
|
||||
class UserAuthorizationStore:
|
||||
"""Store for managing user authorization rules."""
|
||||
|
||||
@staticmethod
|
||||
async def _get_matching_authorizations(
|
||||
email: str,
|
||||
provider_type: str | None,
|
||||
session: AsyncSession,
|
||||
) -> list[UserAuthorization]:
|
||||
"""Get all authorization rules that match the given email and provider.
|
||||
|
||||
Uses SQL LIKE for pattern matching:
|
||||
- email_pattern is NULL matches all emails
|
||||
- provider_type is NULL matches all providers
|
||||
- email LIKE email_pattern for pattern matching
|
||||
|
||||
Args:
|
||||
email: The user's email address
|
||||
provider_type: The identity provider type (e.g., 'github', 'gitlab')
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
List of matching UserAuthorization objects
|
||||
"""
|
||||
# Build query using SQLAlchemy ORM
|
||||
# We need: (email_pattern IS NULL OR LOWER(email) LIKE LOWER(email_pattern))
|
||||
# AND (provider_type IS NULL OR provider_type = :provider_type)
|
||||
email_condition = or_(
|
||||
UserAuthorization.email_pattern.is_(None),
|
||||
func.lower(email).like(func.lower(UserAuthorization.email_pattern)),
|
||||
)
|
||||
provider_condition = or_(
|
||||
UserAuthorization.provider_type.is_(None),
|
||||
UserAuthorization.provider_type == provider_type,
|
||||
)
|
||||
|
||||
query = select(UserAuthorization).where(email_condition, provider_condition)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def get_matching_authorizations(
|
||||
email: str,
|
||||
provider_type: str | None,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> list[UserAuthorization]:
|
||||
"""Get all authorization rules that match the given email and provider.
|
||||
|
||||
Args:
|
||||
email: The user's email address
|
||||
provider_type: The identity provider type (e.g., 'github', 'gitlab')
|
||||
session: Optional database session
|
||||
|
||||
Returns:
|
||||
List of matching UserAuthorization objects
|
||||
"""
|
||||
if session is not None:
|
||||
return await UserAuthorizationStore._get_matching_authorizations(
|
||||
email, provider_type, session
|
||||
)
|
||||
async with a_session_maker() as new_session:
|
||||
return await UserAuthorizationStore._get_matching_authorizations(
|
||||
email, provider_type, new_session
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_authorization_type(
|
||||
email: str,
|
||||
provider_type: str | None,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> UserAuthorizationType | None:
|
||||
"""Get the authorization type for the given email and provider.
|
||||
|
||||
Checks matching authorization rules and returns the effective authorization
|
||||
type. Whitelist rules take precedence over blacklist rules.
|
||||
|
||||
Args:
|
||||
email: The user's email address
|
||||
provider_type: The identity provider type (e.g., 'github', 'gitlab')
|
||||
session: Optional database session
|
||||
|
||||
Returns:
|
||||
UserAuthorizationType.WHITELIST if a whitelist rule matches,
|
||||
UserAuthorizationType.BLACKLIST if a blacklist rule matches (and no whitelist),
|
||||
None if no rules match
|
||||
"""
|
||||
authorizations = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email, provider_type, session
|
||||
)
|
||||
|
||||
has_whitelist = any(
|
||||
auth.type == UserAuthorizationType.WHITELIST.value
|
||||
for auth in authorizations
|
||||
)
|
||||
if has_whitelist:
|
||||
return UserAuthorizationType.WHITELIST
|
||||
|
||||
has_blacklist = any(
|
||||
auth.type == UserAuthorizationType.BLACKLIST.value
|
||||
for auth in authorizations
|
||||
)
|
||||
if has_blacklist:
|
||||
return UserAuthorizationType.BLACKLIST
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _create_authorization(
|
||||
email_pattern: str | None,
|
||||
provider_type: str | None,
|
||||
auth_type: UserAuthorizationType,
|
||||
session: AsyncSession,
|
||||
) -> UserAuthorization:
|
||||
"""Create a new user authorization rule."""
|
||||
authorization = UserAuthorization(
|
||||
email_pattern=email_pattern,
|
||||
provider_type=provider_type,
|
||||
type=auth_type.value,
|
||||
)
|
||||
session.add(authorization)
|
||||
await session.flush()
|
||||
await session.refresh(authorization)
|
||||
return authorization
|
||||
|
||||
@staticmethod
|
||||
async def create_authorization(
|
||||
email_pattern: str | None,
|
||||
provider_type: str | None,
|
||||
auth_type: UserAuthorizationType,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> UserAuthorization:
|
||||
"""Create a new user authorization rule.
|
||||
|
||||
Args:
|
||||
email_pattern: SQL LIKE pattern for email matching (e.g., '%@openhands.dev')
|
||||
provider_type: Provider type to match (e.g., 'github'), or None for all
|
||||
auth_type: WHITELIST or BLACKLIST
|
||||
session: Optional database session
|
||||
|
||||
Returns:
|
||||
The created UserAuthorization object
|
||||
"""
|
||||
if session is not None:
|
||||
return await UserAuthorizationStore._create_authorization(
|
||||
email_pattern, provider_type, auth_type, session
|
||||
)
|
||||
async with a_session_maker() as new_session:
|
||||
auth = await UserAuthorizationStore._create_authorization(
|
||||
email_pattern, provider_type, auth_type, new_session
|
||||
)
|
||||
await new_session.commit()
|
||||
return auth
|
||||
|
||||
@staticmethod
|
||||
async def _delete_authorization(
|
||||
authorization_id: int,
|
||||
session: AsyncSession,
|
||||
) -> bool:
|
||||
"""Delete an authorization rule by ID."""
|
||||
result = await session.execute(
|
||||
select(UserAuthorization).where(UserAuthorization.id == authorization_id)
|
||||
)
|
||||
authorization = result.scalars().first()
|
||||
if authorization:
|
||||
await session.delete(authorization)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def delete_authorization(
|
||||
authorization_id: int,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> bool:
|
||||
"""Delete an authorization rule by ID.
|
||||
|
||||
Args:
|
||||
authorization_id: The ID of the authorization to delete
|
||||
session: Optional database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
if session is not None:
|
||||
return await UserAuthorizationStore._delete_authorization(
|
||||
authorization_id, session
|
||||
)
|
||||
async with a_session_maker() as new_session:
|
||||
deleted = await UserAuthorizationStore._delete_authorization(
|
||||
authorization_id, new_session
|
||||
)
|
||||
if deleted:
|
||||
await new_session.commit()
|
||||
return deleted
|
||||
635
enterprise/tests/unit/storage/test_user_authorization_store.py
Normal file
635
enterprise/tests/unit/storage/test_user_authorization_store.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""Unit tests for UserAuthorizationStore using SQLite in-memory database."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.user_authorization import UserAuthorization, UserAuthorizationType
|
||||
from storage.user_authorization_store import UserAuthorizationStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker bound to the async engine."""
|
||||
session_maker = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
return session_maker
|
||||
|
||||
|
||||
class TestGetMatchingAuthorizations:
|
||||
"""Tests for get_matching_authorizations method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_authorizations_returns_empty_list(self, async_session_maker):
|
||||
"""Test returns empty list when no authorizations exist."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='test@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_email_pattern_match(self, async_session_maker):
|
||||
"""Test matching with exact email pattern."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create a whitelist rule for exact email
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='test@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='test@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].email_pattern == 'test@example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_domain_suffix_pattern_match(self, async_session_maker):
|
||||
"""Test matching with domain suffix pattern (e.g., %@example.com)."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create a whitelist rule for domain
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
# Should match
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should also match different user
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='another.user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should not match different domain
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@other.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_email_pattern_matches_all_emails(self, async_session_maker):
|
||||
"""Test that NULL email_pattern matches all emails."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create a rule with NULL email pattern
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern=None,
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
# Should match any email with github provider
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='any@email.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should not match different provider
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='any@email.com',
|
||||
provider_type='gitlab',
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_provider_type_matches_all_providers(self, async_session_maker):
|
||||
"""Test that NULL provider_type matches all providers."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create a rule with NULL provider type
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@blocked.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
# Should match any provider
|
||||
for provider in ['github', 'gitlab', 'bitbucket', None]:
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@blocked.com',
|
||||
provider_type=provider,
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_type_filter(self, async_session_maker):
|
||||
"""Test filtering by provider type."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create rules for different providers
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type='gitlab',
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
# Check github
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].type == UserAuthorizationType.WHITELIST.value
|
||||
|
||||
# Check gitlab
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='gitlab',
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].type == UserAuthorizationType.BLACKLIST.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_case_insensitive_email_matching(self, async_session_maker):
|
||||
"""Test that email matching is case insensitive."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@Example.COM',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
# Should match regardless of case
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='USER@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_matching_rules(self, async_session_maker):
|
||||
"""Test that multiple matching rules are returned."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create multiple rules that match
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern=None, # Matches all emails
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_provided_session(self, async_session_maker):
|
||||
"""Test using a provided session instead of creating one."""
|
||||
async with async_session_maker() as session:
|
||||
# Create authorization within session
|
||||
auth = UserAuthorization(
|
||||
email_pattern='%@test.com',
|
||||
provider_type=None,
|
||||
type=UserAuthorizationType.WHITELIST.value,
|
||||
)
|
||||
session.add(auth)
|
||||
await session.flush()
|
||||
|
||||
# Query within same session
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@test.com',
|
||||
provider_type='github',
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestGetAuthorizationType:
|
||||
"""Tests for get_authorization_type method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_whitelist_when_whitelist_match_exists(
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns WHITELIST when a whitelist rule matches."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@allowed.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='user@allowed.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result == UserAuthorizationType.WHITELIST
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_blacklist_when_blacklist_match_exists(
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns BLACKLIST when a blacklist rule matches."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@blocked.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='user@blocked.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result == UserAuthorizationType.BLACKLIST
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_rules_exist(self, async_session_maker):
|
||||
"""Test returns None when no authorization rules exist."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_only_non_matching_rules_exist(
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns None when rules exist but don't match."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@other.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitelist_takes_precedence_over_blacklist(self, async_session_maker):
|
||||
"""Test whitelist takes precedence when both match."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create both whitelist and blacklist rules that match
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result == UserAuthorizationType.WHITELIST
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_blacklist_for_domain_block(self, async_session_maker):
|
||||
"""Test blacklist match for domain-based blocking."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@disposable-email.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_authorization_type(
|
||||
email='spammer@disposable-email.com',
|
||||
provider_type='github',
|
||||
)
|
||||
|
||||
assert result == UserAuthorizationType.BLACKLIST
|
||||
|
||||
|
||||
class TestCreateAuthorization:
|
||||
"""Tests for create_authorization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_whitelist_authorization(self, async_session_maker):
|
||||
"""Test creating a whitelist authorization."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
assert auth.id is not None
|
||||
assert auth.email_pattern == '%@example.com'
|
||||
assert auth.provider_type == 'github'
|
||||
assert auth.type == UserAuthorizationType.WHITELIST.value
|
||||
assert auth.created_at is not None
|
||||
assert auth.updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_blacklist_authorization(self, async_session_maker):
|
||||
"""Test creating a blacklist authorization."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@blocked.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
assert auth.id is not None
|
||||
assert auth.email_pattern == '%@blocked.com'
|
||||
assert auth.provider_type is None
|
||||
assert auth.type == UserAuthorizationType.BLACKLIST.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_with_null_email_pattern(self, async_session_maker):
|
||||
"""Test creating authorization with NULL email pattern."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern=None,
|
||||
provider_type='github',
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
assert auth.email_pattern is None
|
||||
assert auth.provider_type == 'github'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_with_provided_session(self, async_session_maker):
|
||||
"""Test creating authorization with a provided session."""
|
||||
async with async_session_maker() as session:
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@test.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert auth.id is not None
|
||||
|
||||
# Verify it exists in session
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@test.com',
|
||||
provider_type='github',
|
||||
session=session,
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestDeleteAuthorization:
|
||||
"""Tests for delete_authorization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_existing_authorization(self, async_session_maker):
|
||||
"""Test deleting an existing authorization."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
# Create an authorization
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
# Delete it
|
||||
deleted = await UserAuthorizationStore.delete_authorization(auth.id)
|
||||
|
||||
assert deleted is True
|
||||
|
||||
# Verify it's gone
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_for_nonexistent_authorization(
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns False when authorization doesn't exist."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
deleted = await UserAuthorizationStore.delete_authorization(99999)
|
||||
|
||||
assert deleted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_with_provided_session(self, async_session_maker):
|
||||
"""Test deleting authorization with a provided session."""
|
||||
async with async_session_maker() as session:
|
||||
# Create an authorization
|
||||
auth = await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@test.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
session=session,
|
||||
)
|
||||
auth_id = auth.id
|
||||
|
||||
# Flush to persist to database before delete
|
||||
await session.flush()
|
||||
|
||||
# Delete within same session
|
||||
deleted = await UserAuthorizationStore.delete_authorization(
|
||||
auth_id, session=session
|
||||
)
|
||||
|
||||
assert deleted is True
|
||||
|
||||
# Flush delete to database
|
||||
await session.flush()
|
||||
|
||||
# Verify it's gone
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@test.com',
|
||||
provider_type='github',
|
||||
session=session,
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestPatternMatchingEdgeCases:
|
||||
"""Tests for edge cases in pattern matching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_prefix_pattern(self, async_session_maker):
|
||||
"""Test pattern with wildcard prefix (e.g., admin%)."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='admin%',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
# Should match
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='admin@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should also match
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='administrator@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should not match
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@admin.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_character_wildcard(self, async_session_maker):
|
||||
"""Test pattern with single character wildcard (underscore in SQL LIKE)."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='user_@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
# Should match user1@example.com
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user1@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should not match user12@example.com
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user12@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_email_with_plus_sign(self, async_session_maker):
|
||||
"""Test matching emails with plus signs (common for email aliases)."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.WHITELIST,
|
||||
)
|
||||
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user+alias@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subdomain_email(self, async_session_maker):
|
||||
"""Test that subdomain emails don't match parent domain patterns."""
|
||||
with patch(
|
||||
'storage.user_authorization_store.a_session_maker', async_session_maker
|
||||
):
|
||||
await UserAuthorizationStore.create_authorization(
|
||||
email_pattern='%@example.com',
|
||||
provider_type=None,
|
||||
auth_type=UserAuthorizationType.BLACKLIST,
|
||||
)
|
||||
|
||||
# Should match exact domain
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Should NOT match subdomain
|
||||
result = await UserAuthorizationStore.get_matching_authorizations(
|
||||
email='user@sub.example.com',
|
||||
provider_type='github',
|
||||
)
|
||||
assert len(result) == 0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,429 +0,0 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock BlockedEmailDomainStore for testing."""
|
||||
store = MagicMock()
|
||||
store.is_domain_blocked = AsyncMock()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker(mock_store):
|
||||
"""Create a DomainBlocker instance for testing with a mocked store."""
|
||||
return DomainBlocker(store=mock_store)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected_domain',
|
||||
[
|
||||
('user@example.com', 'example.com'),
|
||||
('test@colsch.us', 'colsch.us'),
|
||||
('user.name@other-domain.com', 'other-domain.com'),
|
||||
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||
('user@EXAMPLE.COM', 'example.com'),
|
||||
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||
],
|
||||
)
|
||||
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected_domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected',
|
||||
[
|
||||
(None, None),
|
||||
('', None),
|
||||
('invalid-email', None),
|
||||
('user@', None), # Empty domain after @
|
||||
('no-at-sign', None),
|
||||
],
|
||||
)
|
||||
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
"""Test that _extract_domain returns None for invalid email formats."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when domain is not blocked."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns True when domain is blocked."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain extraction."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked correctly checks multiple domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked = AsyncMock(
|
||||
side_effect=lambda domain: domain
|
||||
in [
|
||||
'other-domain.com',
|
||||
'blocked.org',
|
||||
]
|
||||
)
|
||||
|
||||
# Act
|
||||
result1 = await domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = await domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = await domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
assert mock_store.is_domain_blocked.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks domains ending with that TLD."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks subdomains with that TLD."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@subdomain.company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern does not block domains with different TLD."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@company.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_case_insensitive(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern matching is case-insensitive."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@COMPANY.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')
|
||||
|
||||
# Act
|
||||
result_match = await domain_blocker.is_domain_blocked('user@example.co.uk')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.example.co.uk')
|
||||
result_no_match = await domain_blocker.is_domain_blocked('user@example.uk')
|
||||
|
||||
# Assert
|
||||
assert result_match is True
|
||||
assert result_subdomain is True
|
||||
assert result_no_match is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks exact domain match."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_subdomain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks subdomains of that domain."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@subdomain.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks multi-level subdomains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@api.v2.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@notexample.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block same domain with different TLD."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.org')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.org')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that blocking a subdomain also blocks its nested subdomains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = (
|
||||
lambda domain: 'api.example.com' in domain
|
||||
)
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@api.example.com')
|
||||
result_nested = await domain_blocker.is_domain_blocked('user@v1.api.example.com')
|
||||
result_parent = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
assert result_nested is True
|
||||
assert result_parent is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with hyphenated domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@my-company.com')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.my-company.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
assert result_subdomain is True
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with numeric domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@test123.com')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.test123.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
assert result_subdomain is True
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
|
||||
"""Test that blocking works with very long subdomain chains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(
|
||||
'user@level4.level3.level2.level1.example.com'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with(
|
||||
'level4.level3.level2.level1.example.com'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when store raises an exception."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
@@ -18,6 +18,7 @@ from server.auth.saas_user_auth import (
|
||||
saas_user_auth_from_cookie,
|
||||
saas_user_auth_from_signed_token,
|
||||
)
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
|
||||
@@ -493,14 +494,20 @@ async def test_saas_user_auth_from_signed_token(mock_config):
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
# Mock UserAuthorizationStore to avoid database access
|
||||
with patch(
|
||||
'server.auth.saas_user_auth.UserAuthorizationStore'
|
||||
) as mock_user_auth_store:
|
||||
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
|
||||
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.access_token.get_secret_value() == access_token
|
||||
assert result.refresh_token.get_secret_value() == 'test_refresh_token'
|
||||
assert result.email == 'test@example.com'
|
||||
assert result.email_verified is True
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.access_token.get_secret_value() == access_token
|
||||
assert result.refresh_token.get_secret_value() == 'test_refresh_token'
|
||||
assert result.email == 'test@example.com'
|
||||
assert result.email_verified is True
|
||||
|
||||
|
||||
def test_get_api_key_from_header_with_authorization_header():
|
||||
@@ -701,15 +708,21 @@ async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.auth.saas_user_auth.UserAuthorizationStore'
|
||||
) as mock_user_auth_store:
|
||||
mock_user_auth_store.get_authorization_type = AsyncMock(
|
||||
return_value=UserAuthorizationType.BLACKLIST
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AuthError) as exc_info:
|
||||
await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert 'email domain is not allowed' in str(exc_info.value)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_user_auth_store.get_authorization_type.assert_called_once_with(
|
||||
'user@colsch.us', None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -730,8 +743,10 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.auth.saas_user_auth.UserAuthorizationStore'
|
||||
) as mock_user_auth_store:
|
||||
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
@@ -740,8 +755,8 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.email == 'user@example.com'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
mock_user_auth_store.get_authorization_type.assert_called_once_with(
|
||||
'user@example.com', None
|
||||
)
|
||||
|
||||
|
||||
@@ -763,8 +778,10 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.auth.saas_user_auth.UserAuthorizationStore'
|
||||
) as mock_user_auth_store:
|
||||
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
@@ -772,4 +789,6 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_user_auth_store.get_authorization_type.assert_called_once_with(
|
||||
'user@colsch.us', None
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user