Refactor user authorization: Replace domain blocklist with flexible whitelist/blacklist pattern matching (#13207)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-06 09:10:59 -07:00
committed by GitHub
parent 2d7362bf26
commit 6186685ebc
16 changed files with 1490 additions and 1106 deletions

View File

@@ -0,0 +1,136 @@
"""Create user_authorizations table and migrate blocked_email_domains
Revision ID: 099
Revises: 098
Create Date: 2025-03-05 00:00:00.000000
"""
import os
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '099'
down_revision: Union[str, None] = '098'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _seed_from_environment() -> None:
    """Insert authorization rules listed in environment variables.

    ``EMAIL_PATTERN_BLACKLIST`` and ``EMAIL_PATTERN_WHITELIST`` are each read
    as a comma-separated list of SQL LIKE patterns (e.g. ``%@example.com``).
    Every non-blank entry becomes one row in ``user_authorizations`` with a
    NULL ``provider_type``. Unset or empty variables insert nothing.

    This lets a feature deployment start with rules already in place, for
    example blacklisting everything with ``%`` and whitelisting a few accounts.
    """
    blacklist_insert = sa.text("""
        INSERT INTO user_authorizations
            (email_pattern, provider_type, type)
        VALUES
            (:pattern, NULL, 'blacklist')
    """)
    whitelist_insert = sa.text("""
        INSERT INTO user_authorizations
            (email_pattern, provider_type, type)
        VALUES
            (:pattern, NULL, 'whitelist')
    """)

    raw_blacklist = os.environ.get('EMAIL_PATTERN_BLACKLIST', '').strip()
    raw_whitelist = os.environ.get('EMAIL_PATTERN_WHITELIST', '').strip()

    connection = op.get_bind()

    # Blacklist entries are inserted before whitelist entries.
    for raw_patterns, statement in (
        (raw_blacklist, blacklist_insert),
        (raw_whitelist, whitelist_insert),
    ):
        # Splitting an empty string yields one blank piece, which is skipped.
        trimmed = (piece.strip() for piece in raw_patterns.split(','))
        for pattern in filter(None, trimmed):
            connection.execute(statement, {'pattern': pattern})
def upgrade() -> None:
    """Create user_authorizations and copy blocked_email_domains into it.

    Steps:
      1. Create the ``user_authorizations`` table and its two indexes.
      2. Copy every row of ``blocked_email_domains`` into the new table as
         ``blacklist`` rules expressed as SQL LIKE patterns.
      3. Seed extra rules from environment variables.

    The legacy ``blocked_email_domains`` table is only read here; this
    function does not drop it.
    """
    # Create user_authorizations table
    op.create_table(
        'user_authorizations',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('email_pattern', sa.String(), nullable=True),
        sa.Column('provider_type', sa.String(), nullable=True),
        sa.Column('type', sa.String(), nullable=False),
        sa.Column(
            'created_at',
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text('CURRENT_TIMESTAMP'),
        ),
        sa.Column(
            'updated_at',
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text('CURRENT_TIMESTAMP'),
        ),
        sa.PrimaryKeyConstraint('id'),
    )

    # Index the email_pattern column
    op.create_index(
        'ix_user_authorizations_email_pattern',
        'user_authorizations',
        ['email_pattern'],
    )

    # Create index on type for efficient filtering
    op.create_index(
        'ix_user_authorizations_type',
        'user_authorizations',
        ['type'],
    )

    # Migrate existing blocked_email_domains to user_authorizations as
    # blacklist entries, converting each domain to SQL LIKE patterns:
    #   - '.us'         -> '%@%.us'          (any address under that suffix)
    #   - 'example.com' -> '%@example.com'   (exact domain)
    #                      '%@%.example.com' (any subdomain)
    #
    # A full domain needs TWO patterns. A single '%@%example.com' would also
    # match 'user@notexample.com', blocking unrelated domains.
    op.execute("""
        INSERT INTO user_authorizations (email_pattern, provider_type, type, created_at, updated_at)
        SELECT
            CASE
                WHEN domain LIKE '.%' THEN '%@%' || domain
                ELSE '%@' || domain
            END as email_pattern,
            NULL as provider_type,
            'blacklist' as type,
            created_at,
            updated_at
        FROM blocked_email_domains
    """)

    # Subdomain companion rule for full (non-suffix) domains.
    op.execute("""
        INSERT INTO user_authorizations (email_pattern, provider_type, type, created_at, updated_at)
        SELECT
            '%@%.' || domain as email_pattern,
            NULL as provider_type,
            'blacklist' as type,
            created_at,
            updated_at
        FROM blocked_email_domains
        WHERE domain NOT LIKE '.%'
    """)

    # Seed additional patterns from environment variables (if set)
    _seed_from_environment()
def downgrade() -> None:
    """Drop the user_authorizations table together with its indexes.

    Only the objects created for ``user_authorizations`` are removed; no rows
    are copied back into ``blocked_email_domains``.
    """
    # Indexes go first, then the table they belong to.
    for index_name in (
        'ix_user_authorizations_type',
        'ix_user_authorizations_email_pattern',
    ):
        op.drop_index(index_name, table_name='user_authorizations')

    op.drop_table('user_authorizations')

View File

@@ -1,53 +0,0 @@
import os
from openhands.core.logger import openhands_logger as logger
class UserVerifier:
    """Allowlist check for GitHub usernames backed by an optional text file.

    The file path comes from the ``GITHUB_USER_LIST_FILE`` environment
    variable and is read once, when the instance is constructed.
    """

    def __init__(self) -> None:
        logger.debug('Initializing UserVerifier')
        # Lower-cased usernames from the list file; None until a file is loaded.
        self.file_users: list[str] | None = None

        # Initialize from environment variables
        self._init_file_users()

    def _init_file_users(self) -> None:
        """Populate ``file_users`` from the configured list file, if any.

        Raises:
            FileNotFoundError: the variable is set but the path does not exist.
        """
        list_path = os.getenv('GITHUB_USER_LIST_FILE')
        if not list_path:
            logger.debug('GITHUB_USER_LIST_FILE not configured')
            return

        if not os.path.exists(list_path):
            logger.error(f'User list file not found: {list_path}')
            raise FileNotFoundError(f'User list file not found: {list_path}')

        try:
            with open(list_path, 'r') as handle:
                names = []
                for raw_line in handle:
                    entry = raw_line.strip()
                    # Blank lines are skipped.
                    if entry:
                        names.append(entry.lower())
                self.file_users = names
            logger.info(
                f'Successfully loaded {len(self.file_users)} users from {list_path}'
            )
        except Exception:
            # A read failure is logged and otherwise ignored.
            logger.exception(f'Error reading user list file {list_path}')

    def is_active(self) -> bool:
        """Return True when a non-empty allowlist is loaded and not disabled."""
        waitlist_disabled = os.getenv('DISABLE_WAITLIST', '').lower() == 'true'
        if waitlist_disabled:
            logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
            return False
        return bool(self.file_users)

    def is_user_allowed(self, username: str) -> bool:
        """Return True when ``username`` (case-insensitive) is in the allowlist."""
        logger.debug(f'Checking if GitHub user {username} is allowed')
        allowlist = self.file_users

        if allowlist and username.lower() in allowlist:
            logger.debug(f'User {username} found in text file allowlist')
            return True
        if allowlist:
            logger.debug(f'User {username} not found in text file allowlist')

        logger.debug(f'User {username} not found in any allowlist')
        return False
# Shared module-level instance. Constructing it runs _init_file_users(), so
# GITHUB_USER_LIST_FILE is read (and may raise FileNotFoundError) at import time.
user_verifier = UserVerifier()

View File

@@ -1,66 +0,0 @@
from storage.blocked_email_domain_store import BlockedEmailDomainStore
from openhands.core.logger import openhands_logger as logger
class DomainBlocker:
def __init__(self, store: BlockedEmailDomainStore) -> None:
logger.debug('Initializing DomainBlocker')
self.store = store
def _extract_domain(self, email: str) -> str | None:
"""Extract and normalize email domain from email address"""
if not email:
return None
try:
# Extract domain part after @
if '@' not in email:
return None
domain = email.split('@')[1].strip().lower()
return domain if domain else None
except Exception:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
async def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked by querying the database directly via SQL.
Supports blocking:
- Exact domains: 'example.com' blocks 'user@example.com'
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
The blocking logic is handled efficiently in SQL, avoiding the need to load
all blocked domains into memory.
"""
if not email:
logger.debug('No email provided for domain check')
return False
domain = self._extract_domain(email)
if not domain:
logger.debug(f'Could not extract domain from email: {email}')
return False
try:
# Query database directly via SQL to check if domain is blocked
is_blocked = await self.store.is_domain_blocked(domain)
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
else:
logger.debug(f'Email domain {domain} is not blocked')
return is_blocked
except Exception as e:
logger.error(
f'Error checking if domain is blocked for email {email}: {e}',
exc_info=True,
)
# Fail-safe: if database query fails, don't block (allow auth to proceed)
return False
# Initialize store and domain blocker
# Both are created once at import time; domain_blocker is the instance that
# other modules import from this file.
_store = BlockedEmailDomainStore()
domain_blocker = DomainBlocker(store=_store)

View File

@@ -13,7 +13,6 @@ from server.auth.auth_error import (
ExpiredError,
NoCredentialsError,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
@@ -24,6 +23,8 @@ from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from storage.user_authorization import UserAuthorizationType
from storage.user_authorization_store import UserAuthorizationStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from openhands.integrations.provider import (
@@ -326,14 +327,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
email = access_token_payload['email']
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)
raise AuthError(
'Access denied: Your email domain is not allowed to access this service'
)
# Check if email is blacklisted (whitelist takes precedence)
if email:
auth_type = await UserAuthorizationStore.get_authorization_type(email, None)
if auth_type == UserAuthorizationType.BLACKLIST:
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)
raise AuthError(
'Access denied: Your email domain is not allowed to access this service'
)
logger.debug('saas_user_auth_from_signed_token:return')

View File

View File

@@ -0,0 +1,98 @@
import logging
from dataclasses import dataclass
from typing import AsyncGenerator
from fastapi import Request
from pydantic import Field
from server.auth.email_validation import extract_base_email
from server.auth.token_manager import KeycloakUserInfo, TokenManager
from server.auth.user.user_authorizer import (
UserAuthorizationResponse,
UserAuthorizer,
UserAuthorizerInjector,
)
from storage.user_authorization import UserAuthorizationType
from storage.user_authorization_store import UserAuthorizationStore
from openhands.app_server.services.injector import InjectorState
logger = logging.getLogger(__name__)
token_manager = TokenManager()
@dataclass
class DefaultUserAuthorizer(UserAuthorizer):
    """Authorizer backed by the user_authorizations whitelist/blacklist rules.

    When ``prevent_duplicates`` is set, a duplicate-base-email check runs
    before the rule lookup.
    """

    prevent_duplicates: bool

    async def authorize_user(
        self, user_info: KeycloakUserInfo
    ) -> UserAuthorizationResponse:
        """Decide whether the given Keycloak user may proceed.

        Checks run in this order: email present, optional duplicate check,
        base email extractable, then the stored rules. A blacklist match
        rejects the user; a whitelist match or no match accepts them.
        Any exception is logged and reported as a failure without detail.
        """
        user_id = user_info.sub
        email = user_info.email
        provider_type = user_info.identity_provider
        log_context = {'user_id': user_id, 'email': email}

        try:
            if not email:
                logger.warning(f'No email provided for user_id: {user_id}')
                return UserAuthorizationResponse(
                    success=False, error_detail='missing_email'
                )

            # The duplicate lookup is only awaited when the option is enabled.
            if self.prevent_duplicates and await token_manager.check_duplicate_base_email(
                email, user_id
            ):
                logger.warning(
                    f'Blocked signup attempt for email {email} - duplicate base email found',
                    extra=log_context,
                )
                return UserAuthorizationResponse(
                    success=False, error_detail='duplicate_email'
                )

            # Check authorization rules (whitelist takes precedence over blacklist)
            base_email = extract_base_email(email)
            if base_email is None:
                return UserAuthorizationResponse(
                    success=False, error_detail='invalid_email'
                )

            auth_type = await UserAuthorizationStore.get_authorization_type(
                base_email, provider_type
            )

            if auth_type == UserAuthorizationType.BLACKLIST:
                logger.warning(
                    f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
                )
                return UserAuthorizationResponse(success=False, error_detail='blocked')

            if auth_type == UserAuthorizationType.WHITELIST:
                logger.debug(
                    f'User {email} matched whitelist rule',
                    extra=log_context,
                )

            # Whitelisted, or no rule matched at all.
            return UserAuthorizationResponse(success=True)
        except Exception:
            logger.exception('error authorizing user', extra={'user_id': user_id})
            return UserAuthorizationResponse(success=False)
class DefaultUserAuthorizerInjector(UserAuthorizerInjector):
    """Injector that supplies a ``DefaultUserAuthorizer``."""

    prevent_duplicates: bool = Field(
        default=True,
        description='Whether duplicate emails (containing +) are filtered',
    )

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[UserAuthorizer, None]:
        """Yield one authorizer configured with this injector's setting.

        ``state`` and ``request`` are accepted but not used.
        """
        authorizer = DefaultUserAuthorizer(
            prevent_duplicates=self.prevent_duplicates,
        )
        yield authorizer

View File

@@ -0,0 +1,48 @@
import logging
from abc import ABC, abstractmethod
from fastapi import Depends
from pydantic import BaseModel
from server.auth.token_manager import KeycloakUserInfo
from openhands.agent_server.env_parser import from_env
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
logger = logging.getLogger(__name__)
class UserAuthorizationResponse(BaseModel):
    """Outcome of a ``UserAuthorizer.authorize_user`` call."""

    # True when the user may proceed.
    success: bool
    # Optional short reason code for a failure; defaults to None.
    error_detail: str | None = None
class UserAuthorizer(ABC):
    """Class determining whether a user may be authorized."""

    @abstractmethod
    async def authorize_user(
        self, user_info: KeycloakUserInfo
    ) -> UserAuthorizationResponse:
        """Determine whether the info given is permitted.

        Args:
            user_info: The Keycloak user info to evaluate.

        Returns:
            A UserAuthorizationResponse describing the decision.
        """
class UserAuthorizerInjector(DiscriminatedUnionMixin, Injector[UserAuthorizer], ABC):
    """Abstract base for injectors that supply a ``UserAuthorizer``.

    Adds no members of its own; it only combines its base classes.
    """

    pass
def depends_user_authorizer():
    """Build the FastAPI dependency that supplies a ``UserAuthorizer``.

    The injector is loaded via ``from_env`` using the ``OH_USER_AUTHORIZER``
    key. If that fails for any reason, ``DefaultUserAuthorizerInjector`` is
    used instead and the reason is written to the module logger.

    Returns:
        ``Depends(injector.depends)`` for use as a route parameter default.
    """
    # Imported inside the function: default_user_authorizer imports names
    # from this module, so a top-level import here would be circular.
    from server.auth.user.default_user_authorizer import (
        DefaultUserAuthorizerInjector,
    )

    try:
        injector: UserAuthorizerInjector = from_env(
            UserAuthorizerInjector, 'OH_USER_AUTHORIZER'
        )
    except Exception as ex:
        # Report the reason through logging rather than print(), so it follows
        # the configured handlers and is tied to the fallback message.
        logger.info('Using default UserAuthorizer: %s', ex)
        injector = DefaultUserAuthorizerInjector()

    return Depends(injector.depends)

View File

@@ -4,14 +4,13 @@ import uuid
import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional, cast
from urllib.parse import quote
from urllib.parse import quote, urlencode
from uuid import UUID as parse_uuid
import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import SecretStr
from server.auth.auth_utils import user_verifier
from server.auth.constants import (
KEYCLOAK_CLIENT_ID,
KEYCLOAK_REALM_NAME,
@@ -19,11 +18,14 @@ from server.auth.constants import (
RECAPTCHA_SITE_KEY,
ROLE_CHECK_ENABLED,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.recaptcha_service import recaptcha_service
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from server.auth.user.user_authorizer import (
UserAuthorizer,
depends_user_authorizer,
)
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
@@ -40,6 +42,7 @@ from storage.database import a_session_maker
from storage.user import User
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import ProviderType, TokenResponse
@@ -157,11 +160,16 @@ async def keycloak_callback(
state: Optional[str] = None,
error: Optional[str] = None,
error_description: Optional[str] = None,
user_authorizer: UserAuthorizer = depends_user_authorizer(),
):
# Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
if not redirect_url:
redirect_url = str(request.base_url)
if redirect_url is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Missing state in request params',
)
if not code:
# check if this is a forward from the account linking page
@@ -170,36 +178,40 @@ async def keycloak_callback(
and error_description == 'authentication_expired'
):
return RedirectResponse(redirect_url, status_code=302)
return JSONResponse(
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'Missing code in request params'},
detail='Missing code in request params',
)
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
redirect_uri = f'{scheme}://{request.url.netloc}{request.url.path}'
logger.debug(f'code: {code}, redirect_uri: {redirect_uri}')
web_url = get_global_config().web_url
if not web_url:
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
web_url = f'{scheme}://{request.url.netloc}'
redirect_uri = web_url + request.url.path
(
keycloak_access_token,
keycloak_refresh_token,
) = await token_manager.get_keycloak_tokens(code, redirect_uri)
if not keycloak_access_token or not keycloak_refresh_token:
return JSONResponse(
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'Problem retrieving Keycloak tokens'},
detail='Problem retrieving Keycloak tokens',
)
user_info = await token_manager.get_user_info(keycloak_access_token)
logger.debug(f'user_info: {user_info}')
if ROLE_CHECK_ENABLED and user_info.roles is None:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Missing required role'},
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail='Missing required role'
)
if user_info.preferred_username is None:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'Missing user ID or username in response'},
authorization = await user_authorizer.authorize_user(user_info)
if not authorization.success:
# Return unauthorized
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=authorization.error_detail,
)
email = user_info.email
@@ -214,12 +226,10 @@ async def keycloak_callback(
await UserStore.backfill_user_email(user_id, user_info_dict)
if not user:
logger.error(f'Failed to authenticate user {user_info.preferred_username}')
return JSONResponse(
logger.error(f'Failed to authenticate user {user_info.email}')
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': f'Failed to authenticate user {user_info.preferred_username}'
},
detail=f'Failed to authenticate user {user_info.email}',
)
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
@@ -234,7 +244,7 @@ async def keycloak_callback(
'email': email,
},
)
error_url = f'{request.base_url}login?recaptcha_blocked=true'
error_url = f'{web_url}/login?recaptcha_blocked=true'
return RedirectResponse(error_url, status_code=302)
user_ip = request.client.host if request.client else 'unknown'
@@ -265,65 +275,13 @@ async def keycloak_callback(
},
)
# Redirect to home with error parameter
error_url = f'{request.base_url}login?recaptcha_blocked=true'
error_url = f'{web_url}/login?recaptcha_blocked=true'
return RedirectResponse(error_url, status_code=302)
except Exception as e:
logger.exception(f'reCAPTCHA verification error at callback: {e}')
# Fail open - continue with login if reCAPTCHA service unavailable
# Check if email domain is blocked
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)
# Disable the Keycloak account
await token_manager.disable_keycloak_user(user_id, email)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': 'Access denied: Your email domain is not allowed to access this service'
},
)
# Check for duplicate email with + modifier
if email:
try:
has_duplicate = await token_manager.check_duplicate_base_email(
email, user_id
)
if has_duplicate:
logger.warning(
f'Blocked signup attempt for email {email} - duplicate base email found',
extra={'user_id': user_id, 'email': email},
)
# Delete the Keycloak user that was automatically created during OAuth
# This prevents orphaned accounts in Keycloak
# The delete_keycloak_user method already handles all errors internally
deletion_success = await token_manager.delete_keycloak_user(user_id)
if deletion_success:
logger.info(
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
)
else:
logger.warning(
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
f'User may need to be manually cleaned up.'
)
# Redirect to home page with query parameter indicating the issue
home_url = f'{request.base_url}/login?duplicated_email=true'
return RedirectResponse(home_url, status_code=302)
except Exception as e:
# Log error but allow signup to proceed (fail open)
logger.error(
f'Error checking duplicate email for {email}: {e}',
extra={'user_id': user_id, 'email': email},
)
# Check email verification status
email_verified = user_info.email_verified or False
if not email_verified:
@@ -358,6 +316,7 @@ async def keycloak_callback(
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
if rate_limited:
verification_redirect_url = f'{verification_redirect_url}&rate_limited=true'
# Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
@@ -379,13 +338,6 @@ async def keycloak_callback(
ProviderType(idp), user_id, keycloak_access_token
)
username = user_info.preferred_username
if user_verifier.is_active() and not user_verifier.is_user_allowed(username):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authorized via waitlist'},
)
valid_offline_token = (
await token_manager.validate_offline_token(user_id=user_info.sub)
if idp_type != 'saml'
@@ -431,13 +383,19 @@ async def keycloak_callback(
)
if not valid_offline_token:
param_str = urlencode(
{
'client_id': KEYCLOAK_CLIENT_ID,
'response_type': 'code',
'kc_idp_hint': idp,
'redirect_uri': f'{web_url}/oauth/keycloak/offline/callback',
'scope': 'openid email profile offline_access',
'state': state,
}
)
redirect_url = (
f'{KEYCLOAK_SERVER_URL_EXT}/realms/{KEYCLOAK_REALM_NAME}/protocol/openid-connect/auth'
f'?client_id={KEYCLOAK_CLIENT_ID}&response_type=code'
f'&kc_idp_hint={idp}'
f'&redirect_uri={scheme}%3A%2F%2F{request.url.netloc}%2Foauth%2Fkeycloak%2Foffline%2Fcallback'
f'&scope=openid%20email%20profile%20offline_access'
f'&state={state}'
f'?{param_str}'
)
has_accepted_tos = user.accepted_tos is not None
@@ -532,7 +490,7 @@ async def keycloak_callback(
response=response,
keycloak_access_token=keycloak_access_token,
keycloak_refresh_token=keycloak_refresh_token,
secure=True if scheme == 'https' else False,
secure=True if redirect_url.startswith('https') else False,
accepted_tos=has_accepted_tos,
)

View File

@@ -1,30 +0,0 @@
from datetime import UTC, datetime
from sqlalchemy import Column, DateTime, Identity, Integer, String
from storage.base import Base
class BlockedEmailDomain(Base):  # type: ignore
    """Stores blocked email domain patterns.

    Supports blocking:
    - Exact domains: 'example.com' blocks 'user@example.com'
    - Subdomains: 'example.com' blocks 'user@subdomain.example.com'
    - TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
    """

    __tablename__ = 'blocked_email_domains'

    # Identity-generated integer primary key.
    id = Column(Integer, Identity(), primary_key=True)
    # The blocked domain or suffix; required and unique.
    domain = Column(String, nullable=False, unique=True)
    # Set on the Python side to the current UTC time when the row is created.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        nullable=False,
    )
    # Set on creation and refreshed to the current UTC time on every update.
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        nullable=False,
    )

View File

@@ -1,43 +0,0 @@
from dataclasses import dataclass
from sqlalchemy import text
from storage.database import a_session_maker
@dataclass
class BlockedEmailDomainStore:
    """Read-only lookups against the blocked_email_domains table."""

    async def is_domain_blocked(self, domain: str) -> bool:
        """Report whether ``domain`` matches any stored blocked pattern.

        Matching happens entirely in SQL and ignores case:
        - stored values starting with '.' (e.g. '.us') match when ``domain``
          ends with them;
        - other stored values (e.g. 'example.com') match when ``domain``
          equals them or ends with '.' followed by them.

        Args:
            domain: The domain part of an email (e.g. 'example.com' or
                'subdomain.example.com').

        Returns:
            True if the domain is blocked, False otherwise.
        """
        # One EXISTS query covers both the suffix and the full-domain cases.
        blocked_lookup = text("""
            SELECT EXISTS(
                SELECT 1
                FROM blocked_email_domains
                WHERE
                    -- TLD pattern (e.g., '.us') - check if domain ends with it (case-insensitive)
                    (LOWER(domain) LIKE '.%' AND LOWER(:domain) LIKE '%' || LOWER(domain)) OR
                    -- Full domain pattern (e.g., 'example.com')
                    -- Block exact match or subdomains (case-insensitive)
                    (LOWER(domain) NOT LIKE '.%' AND (
                        LOWER(:domain) = LOWER(domain) OR
                        LOWER(:domain) LIKE '%.' || LOWER(domain)
                    ))
            )
        """)
        bind_params = {'domain': domain}

        async with a_session_maker() as session:
            outcome = await session.execute(blocked_lookup, bind_params)
            exists_flag = outcome.scalar()
            return bool(exists_flag)

View File

@@ -0,0 +1,45 @@
"""User authorization model for managing email/provider based access control."""
from datetime import UTC, datetime
from enum import Enum
from sqlalchemy import Column, DateTime, Identity, Integer, String
from storage.base import Base
class UserAuthorizationType(str, Enum):
    """Type of user authorization rule.

    Members also subclass ``str``, so each compares equal to its plain
    string value.
    """

    # Rule that explicitly allows a match.
    WHITELIST = 'whitelist'
    # Rule that explicitly blocks a match.
    BLACKLIST = 'blacklist'
class UserAuthorization(Base):  # type: ignore
    """Stores user authorization rules based on email patterns and provider types.

    Supports:
    - Email pattern matching using SQL LIKE (e.g., '%@openhands.dev')
    - Provider type filtering (e.g., 'github', 'gitlab')
    - Whitelist/Blacklist rules

    When email_pattern is NULL, the rule matches all emails.
    When provider_type is NULL, the rule matches all providers.
    """

    __tablename__ = 'user_authorizations'

    # Identity-generated integer primary key.
    id = Column(Integer, Identity(), primary_key=True)
    # SQL LIKE pattern for the email; NULL matches every email.
    email_pattern = Column(String, nullable=True)
    # Identity provider name; NULL matches every provider.
    provider_type = Column(String, nullable=True)
    # Required rule kind, stored as a plain string
    # (presumably a UserAuthorizationType value - confirm against writers).
    type = Column(String, nullable=False)
    # Set on the Python side to the current UTC time when the row is created.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        nullable=False,
    )
    # Set on creation and refreshed to the current UTC time on every update.
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        nullable=False,
    )

View File

@@ -0,0 +1,203 @@
"""Store class for managing user authorizations."""
from typing import Optional
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.database import a_session_maker
from storage.user_authorization import UserAuthorization, UserAuthorizationType
class UserAuthorizationStore:
    """Query, create and delete rows of the user_authorizations table.

    Each public method accepts an optional session. When one is given it is
    used as-is and never committed here; otherwise a session is opened from
    ``a_session_maker`` and writes are committed before returning.
    """

    @staticmethod
    async def _get_matching_authorizations(
        email: str,
        provider_type: str | None,
        session: AsyncSession,
    ) -> list[UserAuthorization]:
        """Fetch every rule that applies to the email/provider pair.

        A rule applies when both hold:
        - its email_pattern is NULL, or LOWER(email) LIKE LOWER(email_pattern)
        - its provider_type is NULL, or equals ``provider_type``

        Args:
            email: The user's email address
            provider_type: The identity provider type (e.g., 'github', 'gitlab')
            session: Database session

        Returns:
            List of matching UserAuthorization objects
        """
        pattern_column = UserAuthorization.email_pattern
        provider_column = UserAuthorization.provider_type

        matches_email = or_(
            pattern_column.is_(None),
            func.lower(email).like(func.lower(pattern_column)),
        )
        matches_provider = or_(
            provider_column.is_(None),
            provider_column == provider_type,
        )

        statement = select(UserAuthorization).where(matches_email, matches_provider)
        rows = await session.execute(statement)
        return list(rows.scalars().all())

    @staticmethod
    async def get_matching_authorizations(
        email: str,
        provider_type: str | None,
        session: Optional[AsyncSession] = None,
    ) -> list[UserAuthorization]:
        """Return all rules matching the email and provider.

        Args:
            email: The user's email address
            provider_type: The identity provider type (e.g., 'github', 'gitlab')
            session: Optional database session

        Returns:
            List of matching UserAuthorization objects
        """
        lookup = UserAuthorizationStore._get_matching_authorizations
        if session is None:
            async with a_session_maker() as new_session:
                return await lookup(email, provider_type, new_session)
        return await lookup(email, provider_type, session)

    @staticmethod
    async def get_authorization_type(
        email: str,
        provider_type: str | None,
        session: Optional[AsyncSession] = None,
    ) -> UserAuthorizationType | None:
        """Resolve the effective rule type for the email and provider.

        Whitelist rules take precedence over blacklist rules.

        Args:
            email: The user's email address
            provider_type: The identity provider type (e.g., 'github', 'gitlab')
            session: Optional database session

        Returns:
            UserAuthorizationType.WHITELIST if a whitelist rule matches,
            UserAuthorizationType.BLACKLIST if a blacklist rule matches (and no whitelist),
            None if no rules match
        """
        rules = await UserAuthorizationStore.get_matching_authorizations(
            email, provider_type, session
        )

        # Candidates are tried in precedence order; the first hit wins.
        for candidate in (
            UserAuthorizationType.WHITELIST,
            UserAuthorizationType.BLACKLIST,
        ):
            if any(rule.type == candidate.value for rule in rules):
                return candidate
        return None

    @staticmethod
    async def _create_authorization(
        email_pattern: str | None,
        provider_type: str | None,
        auth_type: UserAuthorizationType,
        session: AsyncSession,
    ) -> UserAuthorization:
        """Add one rule to ``session``, flush it and refresh it; no commit."""
        record = UserAuthorization(
            email_pattern=email_pattern,
            provider_type=provider_type,
            type=auth_type.value,
        )
        session.add(record)
        await session.flush()
        await session.refresh(record)
        return record

    @staticmethod
    async def create_authorization(
        email_pattern: str | None,
        provider_type: str | None,
        auth_type: UserAuthorizationType,
        session: Optional[AsyncSession] = None,
    ) -> UserAuthorization:
        """Create a new user authorization rule.

        Args:
            email_pattern: SQL LIKE pattern for email matching (e.g., '%@openhands.dev')
            provider_type: Provider type to match (e.g., 'github'), or None for all
            auth_type: WHITELIST or BLACKLIST
            session: Optional database session

        Returns:
            The created UserAuthorization object
        """
        create = UserAuthorizationStore._create_authorization
        if session is None:
            async with a_session_maker() as new_session:
                created = await create(
                    email_pattern, provider_type, auth_type, new_session
                )
                await new_session.commit()
                return created
        return await create(email_pattern, provider_type, auth_type, session)

    @staticmethod
    async def _delete_authorization(
        authorization_id: int,
        session: AsyncSession,
    ) -> bool:
        """Delete the rule with the given ID via ``session``; no commit."""
        lookup = select(UserAuthorization).where(
            UserAuthorization.id == authorization_id
        )
        found = (await session.execute(lookup)).scalars().first()
        if not found:
            return False
        await session.delete(found)
        return True

    @staticmethod
    async def delete_authorization(
        authorization_id: int,
        session: Optional[AsyncSession] = None,
    ) -> bool:
        """Delete an authorization rule by ID.

        Args:
            authorization_id: The ID of the authorization to delete
            session: Optional database session

        Returns:
            True if deleted, False if not found
        """
        remove = UserAuthorizationStore._delete_authorization
        if session is None:
            async with a_session_maker() as new_session:
                deleted = await remove(authorization_id, new_session)
                # Commit only when a row was actually removed.
                if deleted:
                    await new_session.commit()
                return deleted
        return await remove(authorization_id, session)

View File

@@ -0,0 +1,635 @@
"""Unit tests for UserAuthorizationStore using SQLite in-memory database."""
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.user_authorization import UserAuthorization, UserAuthorizationType
from storage.user_authorization_store import UserAuthorizationStore
@pytest.fixture
async def async_engine():
    """Provide an async in-memory SQLite engine and dispose of it on teardown.

    The engine uses ``StaticPool`` with ``check_same_thread`` disabled so the
    sessions created during a test share the same in-memory database.

    Yields:
        The configured async engine.
    """
    engine = create_async_engine(
        'sqlite+aiosqlite:///:memory:',
        poolclass=StaticPool,
        connect_args={'check_same_thread': False},
    )
    try:
        yield engine
    finally:
        # Close the pooled connection so it does not outlive the test.
        await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
    """Create the schema on the test engine and return a session factory for it.

    All tables registered on ``Base.metadata`` are created before the factory
    is handed to the test.
    """
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    return async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
class TestGetMatchingAuthorizations:
    """Behavioural tests for UserAuthorizationStore.get_matching_authorizations."""

    # Patch target: the store module's session factory, replaced by the fixture.
    _MAKER_PATH = 'storage.user_authorization_store.a_session_maker'

    @pytest.mark.asyncio
    async def test_no_authorizations_returns_empty_list(self, async_session_maker):
        """With no rules stored, the lookup yields an empty list."""
        with patch(self._MAKER_PATH, async_session_maker):
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='test@example.com',
                provider_type='github',
            )
            assert matches == []

    @pytest.mark.asyncio
    async def test_exact_email_pattern_match(self, async_session_maker):
        """A pattern without wildcards matches the identical email."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Whitelist exactly one address
            await UserAuthorizationStore.create_authorization(
                email_pattern='test@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='test@example.com',
                provider_type='github',
            )
            assert len(matches) == 1
            assert matches[0].email_pattern == 'test@example.com'

    @pytest.mark.asyncio
    async def test_domain_suffix_pattern_match(self, async_session_maker):
        """A '%@domain' pattern matches every user of that domain only."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Whitelist the whole domain
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            expectations = (
                ('user@example.com', 1),  # same domain
                ('another.user@example.com', 1),  # other user, same domain
                ('user@other.com', 0),  # different domain
            )
            for address, expected_count in expectations:
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email=address,
                    provider_type='github',
                )
                assert len(matches) == expected_count

    @pytest.mark.asyncio
    async def test_null_email_pattern_matches_all_emails(self, async_session_maker):
        """A rule with no email pattern applies to any email of its provider."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Rule keyed on provider only
            await UserAuthorizationStore.create_authorization(
                email_pattern=None,
                provider_type='github',
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            # github is covered by the rule, gitlab is not
            for provider, expected_count in (('github', 1), ('gitlab', 0)):
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email='any@email.com',
                    provider_type=provider,
                )
                assert len(matches) == expected_count

    @pytest.mark.asyncio
    async def test_null_provider_type_matches_all_providers(self, async_session_maker):
        """A rule with no provider type applies regardless of provider."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Rule keyed on email pattern only
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@blocked.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            # Every provider, including none at all, must hit the rule
            for provider in ('github', 'gitlab', 'bitbucket', None):
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email='user@blocked.com',
                    provider_type=provider,
                )
                assert len(matches) == 1

    @pytest.mark.asyncio
    async def test_provider_type_filter(self, async_session_maker):
        """Rules tied to a provider are only returned for that provider."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Same pattern, opposite verdicts per provider
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type='github',
                auth_type=UserAuthorizationType.WHITELIST,
            )
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type='gitlab',
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            expectations = (
                ('github', UserAuthorizationType.WHITELIST),
                ('gitlab', UserAuthorizationType.BLACKLIST),
            )
            for provider, expected_type in expectations:
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email='user@example.com',
                    provider_type=provider,
                )
                assert len(matches) == 1
                assert matches[0].type == expected_type.value

    @pytest.mark.asyncio
    async def test_case_insensitive_email_matching(self, async_session_maker):
        """Pattern and email match even when their letter case differs."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@Example.COM',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            # Case differs in both the local part and the domain
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='USER@example.com',
                provider_type='github',
            )
            assert len(matches) == 1

    @pytest.mark.asyncio
    async def test_multiple_matching_rules(self, async_session_maker):
        """Every rule that matches is included in the result."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Two independent rules that both apply to the queried user
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            await UserAuthorizationStore.create_authorization(
                email_pattern=None,  # Matches all emails
                provider_type='github',
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user@example.com',
                provider_type='github',
            )
            assert len(matches) == 2

    @pytest.mark.asyncio
    async def test_with_provided_session(self, async_session_maker):
        """The lookup runs inside a caller-supplied session when one is given."""
        async with async_session_maker() as session:
            # Insert the rule directly through the session
            rule = UserAuthorization(
                email_pattern='%@test.com',
                provider_type=None,
                type=UserAuthorizationType.WHITELIST.value,
            )
            session.add(rule)
            await session.flush()
            # The flushed rule must be visible to a query on the same session
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user@test.com',
                provider_type='github',
                session=session,
            )
            assert len(matches) == 1
class TestGetAuthorizationType:
    """Behavioural tests for UserAuthorizationStore.get_authorization_type."""

    # Patch target: the store module's session factory, replaced by the fixture.
    _MAKER_PATH = 'storage.user_authorization_store.a_session_maker'

    @pytest.mark.asyncio
    async def test_returns_whitelist_when_whitelist_match_exists(
        self, async_session_maker
    ):
        """A matching whitelist rule resolves to WHITELIST."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@allowed.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='user@allowed.com',
                provider_type='github',
            )
            assert verdict == UserAuthorizationType.WHITELIST

    @pytest.mark.asyncio
    async def test_returns_blacklist_when_blacklist_match_exists(
        self, async_session_maker
    ):
        """A matching blacklist rule resolves to BLACKLIST."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@blocked.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='user@blocked.com',
                provider_type='github',
            )
            assert verdict == UserAuthorizationType.BLACKLIST

    @pytest.mark.asyncio
    async def test_returns_none_when_no_rules_exist(self, async_session_maker):
        """With an empty rule table the result is None."""
        with patch(self._MAKER_PATH, async_session_maker):
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='user@example.com',
                provider_type='github',
            )
            assert verdict is None

    @pytest.mark.asyncio
    async def test_returns_none_when_only_non_matching_rules_exist(
        self, async_session_maker
    ):
        """Rules for other domains leave the result at None."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@other.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='user@example.com',
                provider_type='github',
            )
            assert verdict is None

    @pytest.mark.asyncio
    async def test_whitelist_takes_precedence_over_blacklist(self, async_session_maker):
        """When both kinds of rule match, WHITELIST wins."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Blacklist for every provider, whitelist for github specifically
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type='github',
                auth_type=UserAuthorizationType.WHITELIST,
            )
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='user@example.com',
                provider_type='github',
            )
            assert verdict == UserAuthorizationType.WHITELIST

    @pytest.mark.asyncio
    async def test_returns_blacklist_for_domain_block(self, async_session_maker):
        """A blacklisted domain resolves to BLACKLIST for its users."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@disposable-email.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            verdict = await UserAuthorizationStore.get_authorization_type(
                email='spammer@disposable-email.com',
                provider_type='github',
            )
            assert verdict == UserAuthorizationType.BLACKLIST
class TestCreateAuthorization:
    """Behavioural tests for UserAuthorizationStore.create_authorization."""

    # Patch target: the store module's session factory, replaced by the fixture.
    _MAKER_PATH = 'storage.user_authorization_store.a_session_maker'

    @pytest.mark.asyncio
    async def test_creates_whitelist_authorization(self, async_session_maker):
        """A whitelist rule is stored with all of its fields populated."""
        with patch(self._MAKER_PATH, async_session_maker):
            created = await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type='github',
                auth_type=UserAuthorizationType.WHITELIST,
            )
            assert created.id is not None
            assert created.email_pattern == '%@example.com'
            assert created.provider_type == 'github'
            assert created.type == UserAuthorizationType.WHITELIST.value
            assert created.created_at is not None
            assert created.updated_at is not None

    @pytest.mark.asyncio
    async def test_creates_blacklist_authorization(self, async_session_maker):
        """A blacklist rule without a provider is stored as given."""
        with patch(self._MAKER_PATH, async_session_maker):
            created = await UserAuthorizationStore.create_authorization(
                email_pattern='%@blocked.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            assert created.id is not None
            assert created.email_pattern == '%@blocked.com'
            assert created.provider_type is None
            assert created.type == UserAuthorizationType.BLACKLIST.value

    @pytest.mark.asyncio
    async def test_creates_with_null_email_pattern(self, async_session_maker):
        """A rule may be created with no email pattern at all."""
        with patch(self._MAKER_PATH, async_session_maker):
            created = await UserAuthorizationStore.create_authorization(
                email_pattern=None,
                provider_type='github',
                auth_type=UserAuthorizationType.WHITELIST,
            )
            assert created.email_pattern is None
            assert created.provider_type == 'github'

    @pytest.mark.asyncio
    async def test_creates_with_provided_session(self, async_session_maker):
        """Creation through a caller-supplied session is visible in that session."""
        async with async_session_maker() as session:
            created = await UserAuthorizationStore.create_authorization(
                email_pattern='%@test.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
                session=session,
            )
            assert created.id is not None
            # The new rule must be found by a query on the same session
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user@test.com',
                provider_type='github',
                session=session,
            )
            assert len(matches) == 1
class TestDeleteAuthorization:
    """Behavioural tests for UserAuthorizationStore.delete_authorization."""

    # Patch target: the store module's session factory, replaced by the fixture.
    _MAKER_PATH = 'storage.user_authorization_store.a_session_maker'

    @pytest.mark.asyncio
    async def test_deletes_existing_authorization(self, async_session_maker):
        """Deleting a stored rule returns True and removes it from lookups."""
        with patch(self._MAKER_PATH, async_session_maker):
            # Seed one rule
            created = await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            # Remove it by ID
            was_deleted = await UserAuthorizationStore.delete_authorization(created.id)
            assert was_deleted is True
            # It must no longer match
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user@example.com',
                provider_type='github',
            )
            assert len(matches) == 0

    @pytest.mark.asyncio
    async def test_returns_false_for_nonexistent_authorization(
        self, async_session_maker
    ):
        """Deleting an unknown ID reports False."""
        with patch(self._MAKER_PATH, async_session_maker):
            was_deleted = await UserAuthorizationStore.delete_authorization(99999)
            assert was_deleted is False

    @pytest.mark.asyncio
    async def test_deletes_with_provided_session(self, async_session_maker):
        """Deletion through a caller-supplied session takes effect in that session."""
        async with async_session_maker() as session:
            # Seed one rule through the session
            created = await UserAuthorizationStore.create_authorization(
                email_pattern='%@test.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
                session=session,
            )
            rule_id = created.id
            # Push the insert to the database before deleting
            await session.flush()
            # Remove the rule using the same session
            was_deleted = await UserAuthorizationStore.delete_authorization(
                rule_id, session=session
            )
            assert was_deleted is True
            # Push the delete to the database
            await session.flush()
            # The rule must no longer match
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user@test.com',
                provider_type='github',
                session=session,
            )
            assert len(matches) == 0
class TestPatternMatchingEdgeCases:
    """Edge-case tests for how stored patterns are matched against emails."""

    # Patch target: the store module's session factory, replaced by the fixture.
    _MAKER_PATH = 'storage.user_authorization_store.a_session_maker'

    @pytest.mark.asyncio
    async def test_wildcard_prefix_pattern(self, async_session_maker):
        """A trailing '%' (e.g., admin%) matches emails starting with the prefix."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='admin%',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            expectations = (
                ('admin@example.com', 1),  # starts with the prefix
                ('administrator@example.com', 1),  # longer local part
                ('user@admin.com', 0),  # prefix appears only in the domain
            )
            for address, expected_count in expectations:
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email=address,
                    provider_type='github',
                )
                assert len(matches) == expected_count

    @pytest.mark.asyncio
    async def test_single_character_wildcard(self, async_session_maker):
        """An underscore in the pattern stands for exactly one character."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='user_@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            expectations = (
                ('user1@example.com', 1),  # one character fills the slot
                ('user12@example.com', 0),  # two characters do not
            )
            for address, expected_count in expectations:
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email=address,
                    provider_type='github',
                )
                assert len(matches) == expected_count

    @pytest.mark.asyncio
    async def test_email_with_plus_sign(self, async_session_maker):
        """Plus-aliased addresses are matched by a domain pattern."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.WHITELIST,
            )
            matches = await UserAuthorizationStore.get_matching_authorizations(
                email='user+alias@example.com',
                provider_type='github',
            )
            assert len(matches) == 1

    @pytest.mark.asyncio
    async def test_subdomain_email(self, async_session_maker):
        """A '%@domain' pattern does not extend to subdomains of that domain."""
        with patch(self._MAKER_PATH, async_session_maker):
            await UserAuthorizationStore.create_authorization(
                email_pattern='%@example.com',
                provider_type=None,
                auth_type=UserAuthorizationType.BLACKLIST,
            )
            expectations = (
                ('user@example.com', 1),  # exact domain
                ('user@sub.example.com', 0),  # subdomain must NOT match
            )
            for address, expected_count in expectations:
                matches = await UserAuthorizationStore.get_matching_authorizations(
                    email=address,
                    provider_type='github',
                )
                assert len(matches) == expected_count

File diff suppressed because it is too large Load Diff

View File

@@ -1,429 +0,0 @@
"""Unit tests for DomainBlocker class."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from server.auth.domain_blocker import DomainBlocker
@pytest.fixture
def mock_store():
    """Provide a MagicMock store whose ``is_domain_blocked`` is an AsyncMock."""
    fake_store = MagicMock()
    fake_store.is_domain_blocked = AsyncMock()
    return fake_store
@pytest.fixture
def domain_blocker(mock_store):
    """Provide a DomainBlocker wired to the mocked store."""
    blocker = DomainBlocker(store=mock_store)
    return blocker
@pytest.mark.parametrize(
    'email,expected_domain',
    [
        ('user@example.com', 'example.com'),
        ('test@colsch.us', 'colsch.us'),
        ('user.name@other-domain.com', 'other-domain.com'),
        # Upper-case input is normalised to lower case
        ('USER@EXAMPLE.COM', 'example.com'),
        ('user@EXAMPLE.COM', 'example.com'),
        # Surrounding whitespace is ignored
        ('  user@example.com  ', 'example.com'),
    ],
)
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
    """_extract_domain returns the normalised domain part of a valid email."""
    extracted = domain_blocker._extract_domain(email)

    assert extracted == expected_domain
@pytest.mark.parametrize(
    'email,expected',
    [
        (None, None),
        ('', None),
        ('invalid-email', None),
        # Nothing after the @ sign
        ('user@', None),
        ('no-at-sign', None),
    ],
)
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
    """_extract_domain yields None when no domain can be extracted."""
    extracted = domain_blocker._extract_domain(email)

    assert extracted == expected
@pytest.mark.asyncio
async def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
    """A None email is reported as not blocked without consulting the store."""
    # Store would say "blocked" if it were asked
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked(None)

    assert blocked is False
    mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
    """An empty email is reported as not blocked without consulting the store."""
    # Store would say "blocked" if it were asked
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('')

    assert blocked is False
    mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
    """A malformed email is reported as not blocked without consulting the store."""
    # Store would say "blocked" if it were asked
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('invalid-email')

    assert blocked is False
    mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
    """The result is False when the store says the domain is not blocked."""
    mock_store.is_domain_blocked.return_value = False

    blocked = await domain_blocker.is_domain_blocked('user@example.com')

    assert blocked is False
    mock_store.is_domain_blocked.assert_called_once_with('example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
    """The result is True when the store says the domain is blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@colsch.us')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
    """An upper-case domain is lower-cased before the store is queried."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@COLSCH.US')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
    """Whitespace around the email is dropped before the store is queried."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('  user@colsch.us  ')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
    """Each queried email is checked against the store independently."""
    # The store blocks exactly these two domains
    mock_store.is_domain_blocked = AsyncMock(
        side_effect=lambda domain: domain in ['other-domain.com', 'blocked.org']
    )

    first = await domain_blocker.is_domain_blocked('user@other-domain.com')
    second = await domain_blocker.is_domain_blocked('user@blocked.org')
    third = await domain_blocker.is_domain_blocked('user@allowed.com')

    assert first is True
    assert second is True
    assert third is False
    assert mock_store.is_domain_blocked.call_count == 3
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
    domain_blocker, mock_store
):
    """A domain under a blocked TLD is reported as blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@company.us')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
    domain_blocker, mock_store
):
    """A subdomain under a blocked TLD is reported as blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@subdomain.company.us')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
    domain_blocker, mock_store
):
    """A domain under a different TLD is reported as not blocked."""
    mock_store.is_domain_blocked.return_value = False

    blocked = await domain_blocker.is_domain_blocked('user@company.com')

    assert blocked is False
    mock_store.is_domain_blocked.assert_called_once_with('company.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_case_insensitive(
    domain_blocker, mock_store
):
    """An upper-case domain is lower-cased before the TLD check reaches the store."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@COMPANY.US')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
    domain_blocker, mock_store
):
    """Multi-level TLDs such as .co.uk are handled like any other suffix."""
    # The store blocks anything ending in .co.uk
    mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')

    direct = await domain_blocker.is_domain_blocked('user@example.co.uk')
    nested = await domain_blocker.is_domain_blocked('user@api.example.co.uk')
    unrelated = await domain_blocker.is_domain_blocked('user@example.uk')

    assert direct is True
    assert nested is True
    assert unrelated is False
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
    domain_blocker, mock_store
):
    """An email on exactly the blocked domain is reported as blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@example.com')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_subdomain(
    domain_blocker, mock_store
):
    """An email on a subdomain of the blocked domain is reported as blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@subdomain.example.com')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
    domain_blocker, mock_store
):
    """An email several subdomain levels deep is reported as blocked."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked('user@api.v2.example.com')

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
    domain_blocker, mock_store
):
    """A domain that merely contains the blocked name is not blocked."""
    mock_store.is_domain_blocked.return_value = False

    blocked = await domain_blocker.is_domain_blocked('user@notexample.com')

    assert blocked is False
    mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
    domain_blocker, mock_store
):
    """The same name under a different TLD is not blocked."""
    mock_store.is_domain_blocked.return_value = False

    blocked = await domain_blocker.is_domain_blocked('user@example.org')

    assert blocked is False
    mock_store.is_domain_blocked.assert_called_once_with('example.org')
@pytest.mark.asyncio
async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
    domain_blocker, mock_store
):
    """Blocking a subdomain covers it and anything nested below, not its parent."""
    # The store blocks any domain containing api.example.com
    mock_store.is_domain_blocked.side_effect = (
        lambda domain: 'api.example.com' in domain
    )

    exact = await domain_blocker.is_domain_blocked('user@api.example.com')
    nested = await domain_blocker.is_domain_blocked('user@v1.api.example.com')
    parent = await domain_blocker.is_domain_blocked('user@example.com')

    assert exact is True
    assert nested is True
    assert parent is False
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
    """Hyphenated domains and their subdomains are passed through to the store."""
    mock_store.is_domain_blocked.return_value = True

    exact = await domain_blocker.is_domain_blocked('user@my-company.com')
    nested = await domain_blocker.is_domain_blocked('user@api.my-company.com')

    assert exact is True
    assert nested is True
    assert mock_store.is_domain_blocked.call_count == 2
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
    """Domains containing digits and their subdomains are passed through to the store."""
    mock_store.is_domain_blocked.return_value = True

    exact = await domain_blocker.is_domain_blocked('user@test123.com')
    nested = await domain_blocker.is_domain_blocked('user@api.test123.com')

    assert exact is True
    assert nested is True
    assert mock_store.is_domain_blocked.call_count == 2
@pytest.mark.asyncio
async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
    """A deeply nested subdomain chain is forwarded to the store unchanged."""
    mock_store.is_domain_blocked.return_value = True

    blocked = await domain_blocker.is_domain_blocked(
        'user@level4.level3.level2.level1.example.com'
    )

    assert blocked is True
    mock_store.is_domain_blocked.assert_called_once_with(
        'level4.level3.level2.level1.example.com'
    )
@pytest.mark.asyncio
async def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
    """A store failure results in False rather than a propagated exception."""
    # Make the store lookup raise
    mock_store.is_domain_blocked.side_effect = Exception('Database connection error')

    blocked = await domain_blocker.is_domain_blocked('user@example.com')

    assert blocked is False
    mock_store.is_domain_blocked.assert_called_once_with('example.com')

View File

@@ -18,6 +18,7 @@ from server.auth.saas_user_auth import (
saas_user_auth_from_cookie,
saas_user_auth_from_signed_token,
)
from storage.user_authorization import UserAuthorizationType
from openhands.integrations.provider import ProviderToken, ProviderType
@@ -493,14 +494,20 @@ async def test_saas_user_auth_from_signed_token(mock_config):
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
result = await saas_user_auth_from_signed_token(signed_token)
# Mock UserAuthorizationStore to avoid database access
with patch(
'server.auth.saas_user_auth.UserAuthorizationStore'
) as mock_user_auth_store:
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
assert result.access_token.get_secret_value() == access_token
assert result.refresh_token.get_secret_value() == 'test_refresh_token'
assert result.email == 'test@example.com'
assert result.email_verified is True
result = await saas_user_auth_from_signed_token(signed_token)
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
assert result.access_token.get_secret_value() == access_token
assert result.refresh_token.get_secret_value() == 'test_refresh_token'
assert result.email == 'test@example.com'
assert result.email_verified is True
def test_get_api_key_from_header_with_authorization_header():
@@ -701,15 +708,21 @@ async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=True)
with patch(
'server.auth.saas_user_auth.UserAuthorizationStore'
) as mock_user_auth_store:
mock_user_auth_store.get_authorization_type = AsyncMock(
return_value=UserAuthorizationType.BLACKLIST
)
# Act & Assert
with pytest.raises(AuthError) as exc_info:
await saas_user_auth_from_signed_token(signed_token)
assert 'email domain is not allowed' in str(exc_info.value)
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
mock_user_auth_store.get_authorization_type.assert_called_once_with(
'user@colsch.us', None
)
@pytest.mark.asyncio
@@ -730,8 +743,10 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
with patch(
'server.auth.saas_user_auth.UserAuthorizationStore'
) as mock_user_auth_store:
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
# Act
result = await saas_user_auth_from_signed_token(signed_token)
@@ -740,8 +755,8 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
assert result.email == 'user@example.com'
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
'user@example.com'
mock_user_auth_store.get_authorization_type.assert_called_once_with(
'user@example.com', None
)
@@ -763,8 +778,10 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
with patch(
'server.auth.saas_user_auth.UserAuthorizationStore'
) as mock_user_auth_store:
mock_user_auth_store.get_authorization_type = AsyncMock(return_value=None)
# Act
result = await saas_user_auth_from_signed_token(signed_token)
@@ -772,4 +789,6 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
mock_user_auth_store.get_authorization_type.assert_called_once_with(
'user@colsch.us', None
)