diff --git a/enterprise/migrations/versions/094_create_org_invitation_table.py b/enterprise/migrations/versions/094_create_org_invitation_table.py new file mode 100644 index 0000000000..3dc6a2f89a --- /dev/null +++ b/enterprise/migrations/versions/094_create_org_invitation_table.py @@ -0,0 +1,110 @@ +"""create org_invitation table + +Revision ID: 094 +Revises: 093 +Create Date: 2026-02-18 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '094' +down_revision: Union[str, None] = '093' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create org_invitation table + op.create_table( + 'org_invitation', + sa.Column('id', sa.Integer, sa.Identity(), primary_key=True), + sa.Column('token', sa.String(64), nullable=False), + sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('email', sa.String(255), nullable=False), + sa.Column('role_id', sa.Integer, nullable=False), + sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + 'status', + sa.String(20), + nullable=False, + server_default=sa.text("'pending'"), + ), + sa.Column( + 'created_at', + sa.DateTime, + nullable=False, + server_default=sa.text('CURRENT_TIMESTAMP'), + ), + sa.Column('expires_at', sa.DateTime, nullable=False), + sa.Column('accepted_at', sa.DateTime, nullable=True), + sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True), + # Foreign key constraints + sa.ForeignKeyConstraint( + ['org_id'], + ['org.id'], + name='org_invitation_org_fkey', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['role_id'], + ['role.id'], + name='org_invitation_role_fkey', + ), + sa.ForeignKeyConstraint( + ['inviter_id'], + ['user.id'], + name='org_invitation_inviter_fkey', + ), + sa.ForeignKeyConstraint( + ['accepted_by_user_id'], + ['user.id'], + name='org_invitation_accepter_fkey', + ), + ) + + # Create indexes + op.create_index( + 'ix_org_invitation_token', + 'org_invitation', + ['token'], + unique=True, + ) + op.create_index( + 'ix_org_invitation_org_id', + 'org_invitation', + ['org_id'], + ) + op.create_index( + 'ix_org_invitation_email', + 'org_invitation', + ['email'], + ) + op.create_index( + 'ix_org_invitation_status', + 'org_invitation', + ['status'], + ) + # Composite index for checking pending invitations + op.create_index( + 'ix_org_invitation_org_email_status', + 'org_invitation', + ['org_id', 'email', 'status'], + ) + + +def downgrade() -> None: + # Drop indexes + op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation') + op.drop_index('ix_org_invitation_status', table_name='org_invitation') + op.drop_index('ix_org_invitation_email', table_name='org_invitation') + op.drop_index('ix_org_invitation_org_id', table_name='org_invitation') + op.drop_index('ix_org_invitation_token', table_name='org_invitation') + + # Drop table + op.drop_table('org_invitation') diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index e8c1935cb5..2248892993 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa: from server.routes.integration.slack import slack_router # noqa: E402 from server.routes.mcp_patch import patch_mcp_server # noqa: E402 from server.routes.oauth_device import oauth_device_router # noqa: E402 +from server.routes.org_invitations import ( # noqa: E402 + accept_router as invitation_accept_router, +) +from server.routes.org_invitations import ( # noqa: E402 + invitation_router, +) from server.routes.orgs import org_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 @@ -99,6 +105,8 @@ if GITLAB_APP_CLIENT_ID: base_app.include_router(api_keys_router) # Add routes for API key management base_app.include_router(org_router) # Add routes for organization management +base_app.include_router(invitation_router) # Add routes for org invitation management +base_app.include_router(invitation_accept_router) # Add route for accepting invitations add_github_proxy_routes(base_app) add_debugging_routes( base_app diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 651cd45512..c3d12c7897 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -160,6 +160,7 @@ class SetAuthCookieMiddleware: '/api/billing/customer-setup-success', '/api/billing/stripe-webhook', '/api/email/resend', + '/api/organizations/members/invite/accept', '/oauth/device/authorize', '/oauth/device/token', '/api/v1/web-client/config', diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index 15479bbdaa..75eb8b440f 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -5,6 +5,7 @@ import warnings from datetime import datetime, timezone from typing import Annotated, Literal, Optional from urllib.parse import quote +from uuid import UUID as parse_uuid import posthog from fastapi import APIRouter, Header, HTTPException, Request, Response, status @@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager from server.config import sign_token from server.constants import IS_FEATURE_ENV from server.routes.event_webhook import _get_session_api_key, _get_user_id +from server.services.org_invitation_service import ( + EmailMismatchError, + InvitationExpiredError, + InvitationInvalidError, + OrgInvitationService, + UserAlreadyMemberError, +) from storage.database import session_maker from storage.user import User from storage.user_store import UserStore @@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']: ) +def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]: + """Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state. + + Returns: + Tuple of (redirect_url, recaptcha_token, invitation_token). + Tokens may be None. + """ + if not state: + return '', None, None + + try: + # Try to decode as JSON (new format with reCAPTCHA and/or invitation) + state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode()) + return ( + state_data.get('redirect_url', ''), + state_data.get('recaptcha_token'), + state_data.get('invitation_token'), + ) + except Exception: + # Old format - state is just the redirect URL + return state, None, None + + +# Keep alias for backward compatibility def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]: """Extract redirect URL and reCAPTCHA token from OAuth state. + Deprecated: Use _extract_oauth_state instead. + Returns: Tuple of (redirect_url, recaptcha_token). Token may be None. """ - if not state: - return '', None - - try: - # Try to decode as JSON (new format with reCAPTCHA) - state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode()) - return state_data.get('redirect_url', ''), state_data.get('recaptcha_token') - except Exception: - # Old format - state is just the redirect URL - return state, None + redirect_url, recaptcha_token, _ = _extract_oauth_state(state) + return redirect_url, recaptcha_token @oauth_router.get('/keycloak/callback') @@ -130,8 +156,8 @@ async def keycloak_callback( error: Optional[str] = None, error_description: Optional[str] = None, ): - # Extract redirect URL and reCAPTCHA token from state - redirect_url, recaptcha_token = _extract_recaptcha_state(state) + # Extract redirect URL, reCAPTCHA token, and invitation token from state + redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state) if not redirect_url: redirect_url = str(request.base_url) @@ -302,8 +328,13 @@ async def keycloak_callback( from server.routes.email import verify_email await verify_email(request=request, user_id=user_id, is_auth_flow=True) - redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}' - response = RedirectResponse(redirect_url, status_code=302) + verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}' + # Preserve invitation token so it can be included in OAuth state after verification + if invitation_token: + verification_redirect_url = ( + f'{verification_redirect_url}&invitation_token={invitation_token}' + ) + response = RedirectResponse(verification_redirect_url, status_code=302) return response # default to github IDP for now. @@ -381,14 +412,90 @@ async def keycloak_callback( ) has_accepted_tos = user.accepted_tos is not None + + # Process invitation token if present (after email verification but before TOS) + if invitation_token: + try: + logger.info( + 'Processing invitation token during auth callback', + extra={ + 'user_id': user_id, + 'invitation_token_prefix': invitation_token[:10] + '...', + }, + ) + + await OrgInvitationService.accept_invitation( + invitation_token, parse_uuid(user_id) + ) + logger.info( + 'Invitation accepted during auth callback', + extra={'user_id': user_id}, + ) + + except InvitationExpiredError: + logger.warning( + 'Invitation expired during auth callback', + extra={'user_id': user_id}, + ) + # Add query param to redirect URL + if '?' in redirect_url: + redirect_url = f'{redirect_url}&invitation_expired=true' + else: + redirect_url = f'{redirect_url}?invitation_expired=true' + + except InvitationInvalidError as e: + logger.warning( + 'Invalid invitation during auth callback', + extra={'user_id': user_id, 'error': str(e)}, + ) + if '?' in redirect_url: + redirect_url = f'{redirect_url}&invitation_invalid=true' + else: + redirect_url = f'{redirect_url}?invitation_invalid=true' + + except UserAlreadyMemberError: + logger.info( + 'User already member during invitation acceptance', + extra={'user_id': user_id}, + ) + if '?' in redirect_url: + redirect_url = f'{redirect_url}&already_member=true' + else: + redirect_url = f'{redirect_url}?already_member=true' + + except EmailMismatchError as e: + logger.warning( + 'Email mismatch during auth callback invitation acceptance', + extra={'user_id': user_id, 'error': str(e)}, + ) + if '?' in redirect_url: + redirect_url = f'{redirect_url}&email_mismatch=true' + else: + redirect_url = f'{redirect_url}?email_mismatch=true' + + except Exception as e: + logger.exception( + 'Unexpected error processing invitation during auth callback', + extra={'user_id': user_id, 'error': str(e)}, + ) + # Don't fail the login if invitation processing fails + if '?' in redirect_url: + redirect_url = f'{redirect_url}&invitation_error=true' + else: + redirect_url = f'{redirect_url}?invitation_error=true' + # If the user hasn't accepted the TOS, redirect to the TOS page if not has_accepted_tos: encoded_redirect_url = quote(redirect_url, safe='') tos_redirect_url = ( f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}' ) + if invitation_token: + tos_redirect_url = f'{tos_redirect_url}&invitation_success=true' response = RedirectResponse(tos_redirect_url, status_code=302) else: + if invitation_token: + redirect_url = f'{redirect_url}&invitation_success=true' response = RedirectResponse(redirect_url, status_code=302) set_response_cookie( diff --git a/enterprise/server/routes/org_invitation_models.py b/enterprise/server/routes/org_invitation_models.py new file mode 100644 index 0000000000..3852959a68 --- /dev/null +++ b/enterprise/server/routes/org_invitation_models.py @@ -0,0 +1,122 @@ +""" +Pydantic models and custom exceptions for organization invitations. +""" + +from pydantic import BaseModel, EmailStr +from storage.org_invitation import OrgInvitation +from storage.role_store import RoleStore + + +class InvitationError(Exception): + """Base exception for invitation errors.""" + + pass + + +class InvitationAlreadyExistsError(InvitationError): + """Raised when a pending invitation already exists for the email.""" + + def __init__( + self, message: str = 'A pending invitation already exists for this email' + ): + super().__init__(message) + + +class UserAlreadyMemberError(InvitationError): + """Raised when the user is already a member of the organization.""" + + def __init__(self, message: str = 'User is already a member of this organization'): + super().__init__(message) + + +class InvitationExpiredError(InvitationError): + """Raised when the invitation has expired.""" + + def __init__(self, message: str = 'Invitation has expired'): + super().__init__(message) + + +class InvitationInvalidError(InvitationError): + """Raised when the invitation is invalid or revoked.""" + + def __init__(self, message: str = 'Invitation is no longer valid'): + super().__init__(message) + + +class InsufficientPermissionError(InvitationError): + """Raised when the user lacks permission to perform the action.""" + + def __init__(self, message: str = 'Insufficient permission'): + super().__init__(message) + + +class EmailMismatchError(InvitationError): + """Raised when the accepting user's email doesn't match the invitation email.""" + + def __init__(self, message: str = 'Your email does not match the invitation'): + super().__init__(message) + + +class InvitationCreate(BaseModel): + """Request model for creating invitation(s).""" + + emails: list[EmailStr] + role: str = 'member' # Default to member role + + +class InvitationResponse(BaseModel): + """Response model for invitation details.""" + + id: int + email: str + role: str + status: str + created_at: str + expires_at: str + inviter_email: str | None = None + + @classmethod + def from_invitation( + cls, + invitation: OrgInvitation, + inviter_email: str | None = None, + ) -> 'InvitationResponse': + """Create an InvitationResponse from an OrgInvitation entity. + + Args: + invitation: The invitation entity to convert + inviter_email: Optional email of the inviter + + Returns: + InvitationResponse: The response model instance + """ + role_name = '' + if invitation.role: + role_name = invitation.role.name + elif invitation.role_id: + role = RoleStore.get_role_by_id(invitation.role_id) + role_name = role.name if role else '' + + return cls( + id=invitation.id, + email=invitation.email, + role=role_name, + status=invitation.status, + created_at=invitation.created_at.isoformat(), + expires_at=invitation.expires_at.isoformat(), + inviter_email=inviter_email, + ) + + +class InvitationFailure(BaseModel): + """Response model for a failed invitation.""" + + email: str + error: str + + +class BatchInvitationResponse(BaseModel): + """Response model for batch invitation creation.""" + + successful: list[InvitationResponse] + failed: list[InvitationFailure] diff --git a/enterprise/server/routes/org_invitations.py b/enterprise/server/routes/org_invitations.py new file mode 100644 index 0000000000..3349d600ac --- /dev/null +++ b/enterprise/server/routes/org_invitations.py @@ -0,0 +1,226 @@ +"""API routes for organization invitations.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse +from server.routes.org_invitation_models import ( + BatchInvitationResponse, + EmailMismatchError, + InsufficientPermissionError, + InvitationCreate, + InvitationExpiredError, + InvitationFailure, + InvitationInvalidError, + InvitationResponse, + UserAlreadyMemberError, +) +from server.services.org_invitation_service import OrgInvitationService +from server.utils.rate_limit_utils import check_rate_limit_by_user_id + +from openhands.core.logger import openhands_logger as logger +from openhands.server.user_auth import get_user_id +from openhands.server.user_auth.user_auth import get_user_auth + +# Router for invitation operations on an organization (requires org_id) +invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members') + +# Router for accepting invitations (no org_id required) +accept_router = APIRouter(prefix='/api/organizations/members/invite') + + +@invitation_router.post( + '/invite', + response_model=BatchInvitationResponse, + status_code=status.HTTP_201_CREATED, +) +async def create_invitation( + org_id: UUID, + invitation_data: InvitationCreate, + request: Request, + user_id: str = Depends(get_user_id), +): + """Create organization invitations for multiple email addresses. + + Sends emails to invitees with secure links to join the organization. + Supports batch invitations - some may succeed while others fail. + + Permission rules: + - Only owners and admins can create invitations + - Admins can only invite with 'member' or 'admin' role (not 'owner') + - Owners can invite with any role + + Args: + org_id: Organization UUID + invitation_data: Invitation details (emails array, role) + request: FastAPI request + user_id: Authenticated user ID (from dependency) + + Returns: + BatchInvitationResponse: Lists of successful and failed invitations + + Raises: + HTTPException 400: Invalid role or organization not found + HTTPException 403: User lacks permission to invite + HTTPException 429: Rate limit exceeded + """ + # Rate limit: 10 invitations per minute per user (6 seconds between requests) + await check_rate_limit_by_user_id( + request=request, + key_prefix='org_invitation_create', + user_id=user_id, + user_rate_limit_seconds=6, + ) + + try: + successful, failed = await OrgInvitationService.create_invitations_batch( + org_id=org_id, + emails=[str(email) for email in invitation_data.emails], + role_name=invitation_data.role, + inviter_id=UUID(user_id), + ) + + logger.info( + 'Batch organization invitations created', + extra={ + 'org_id': str(org_id), + 'total_emails': len(invitation_data.emails), + 'successful': len(successful), + 'failed': len(failed), + 'inviter_id': user_id, + }, + ) + + return BatchInvitationResponse( + successful=[InvitationResponse.from_invitation(inv) for inv in successful], + failed=[ + InvitationFailure(email=email, error=error) for email, error in failed + ], + ) + + except InsufficientPermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + logger.exception( + 'Unexpected error creating batch invitations', + extra={'org_id': str(org_id), 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='An unexpected error occurred', + ) + + +@accept_router.get('/accept') +async def accept_invitation( + token: str, + request: Request, +): + """Accept an organization invitation via token. + + This endpoint is accessed via the link in the invitation email. + + Flow: + 1. If user is authenticated: Accept invitation directly and redirect to home + 2. If user is not authenticated: Redirect to login page with invitation token + - Frontend stores token and includes it in OAuth state during login + - After authentication, keycloak_callback processes the invitation + + Args: + token: The invitation token from the email link + request: FastAPI request + + Returns: + RedirectResponse: Redirect to home page on success, or login page if not authenticated, + or home page with error query params on failure + """ + base_url = str(request.base_url).rstrip('/') + + # Try to get user_id from auth (may not be authenticated) + user_id = None + try: + user_auth = await get_user_auth(request) + if user_auth: + user_id = await user_auth.get_user_id() + except Exception: + pass + + if not user_id: + # User not authenticated - redirect to login page with invitation token + # Frontend will store the token and include it in OAuth state during login + logger.info( + 'Invitation accept: redirecting unauthenticated user to login', + extra={'token_prefix': token[:10] + '...'}, + ) + login_url = f'{base_url}/login?invitation_token={token}' + return RedirectResponse(login_url, status_code=302) + + # User is authenticated - process the invitation directly + try: + await OrgInvitationService.accept_invitation(token, UUID(user_id)) + + logger.info( + 'Invitation accepted successfully', + extra={ + 'token_prefix': token[:10] + '...', + 'user_id': user_id, + }, + ) + + # Redirect to home page on success + return RedirectResponse(f'{base_url}/', status_code=302) + + except InvitationExpiredError: + logger.warning( + 'Invitation accept failed: expired', + extra={'token_prefix': token[:10] + '...', 'user_id': user_id}, + ) + return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302) + + except InvitationInvalidError as e: + logger.warning( + 'Invitation accept failed: invalid', + extra={ + 'token_prefix': token[:10] + '...', + 'user_id': user_id, + 'error': str(e), + }, + ) + return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302) + + except UserAlreadyMemberError: + logger.info( + 'Invitation accept: user already member', + extra={'token_prefix': token[:10] + '...', 'user_id': user_id}, + ) + return RedirectResponse(f'{base_url}/?already_member=true', status_code=302) + + except EmailMismatchError as e: + logger.warning( + 'Invitation accept failed: email mismatch', + extra={ + 'token_prefix': token[:10] + '...', + 'user_id': user_id, + 'error': str(e), + }, + ) + return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302) + + except Exception as e: + logger.exception( + 'Unexpected error accepting invitation', + extra={ + 'token_prefix': token[:10] + '...', + 'user_id': user_id, + 'error': str(e), + }, + ) + return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302) diff --git a/enterprise/server/services/email_service.py b/enterprise/server/services/email_service.py new file mode 100644 index 0000000000..163671b617 --- /dev/null +++ b/enterprise/server/services/email_service.py @@ -0,0 +1,131 @@ +"""Email service for sending transactional emails via Resend.""" + +import os + +try: + import resend + + RESEND_AVAILABLE = True +except ImportError: + RESEND_AVAILABLE = False + +from openhands.core.logger import openhands_logger as logger + +DEFAULT_FROM_EMAIL = 'OpenHands ' +DEFAULT_WEB_HOST = 'https://app.all-hands.dev' + + +class EmailService: + """Service for sending transactional emails.""" + + @staticmethod + def _get_resend_client() -> bool: + """Initialize and return the Resend client. + + Returns: + bool: True if client is ready, False otherwise + """ + if not RESEND_AVAILABLE: + logger.warning('Resend library not installed, skipping email') + return False + + resend_api_key = os.environ.get('RESEND_API_KEY') + if not resend_api_key: + logger.warning('RESEND_API_KEY not configured, skipping email') + return False + + resend.api_key = resend_api_key + return True + + @staticmethod + def send_invitation_email( + to_email: str, + org_name: str, + inviter_name: str, + role_name: str, + invitation_token: str, + invitation_id: int, + ) -> None: + """Send an organization invitation email. + + Args: + to_email: Recipient's email address + org_name: Name of the organization + inviter_name: Display name of the person who sent the invite + role_name: Role being offered (e.g., 'member', 'admin') + invitation_token: The secure invitation token + invitation_id: The invitation ID for logging + """ + if not EmailService._get_resend_client(): + return + + # Build invitation URL + web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST) + invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}' + + from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL) + + params = { + 'from': from_email, + 'to': [to_email], + 'subject': f"You're invited to join {org_name} on OpenHands", + 'html': f""" +
+

Hi,

+ +

{inviter_name} has invited you to join {org_name} on OpenHands as a {role_name}.

+ +

Click the button below to accept the invitation:

+ +

+ + Accept Invitation + +

+ +

+ Or copy and paste this link into your browser:
+ {invitation_url} +

+ +

+ This invitation will expire in 7 days. +

+ +

+ If you weren't expecting this invitation, you can safely ignore this email. +

+ +
+ +

+ Best,
+ The OpenHands Team +

+
+ """, + } + + try: + response = resend.Emails.send(params) + logger.info( + 'Invitation email sent', + extra={ + 'invitation_id': invitation_id, + 'email': to_email, + 'response_id': response.get('id') if response else None, + }, + ) + except Exception as e: + logger.error( + 'Failed to send invitation email', + extra={ + 'invitation_id': invitation_id, + 'email': to_email, + 'error': str(e), + }, + ) + raise diff --git a/enterprise/server/services/org_invitation_service.py b/enterprise/server/services/org_invitation_service.py new file mode 100644 index 0000000000..e31e43c61f --- /dev/null +++ b/enterprise/server/services/org_invitation_service.py @@ -0,0 +1,397 @@ +"""Service for managing organization invitations.""" + +import asyncio +from uuid import UUID + +from server.auth.token_manager import TokenManager +from server.constants import ROLE_ADMIN, ROLE_OWNER +from server.routes.org_invitation_models import ( + EmailMismatchError, + InsufficientPermissionError, + InvitationExpiredError, + InvitationInvalidError, + UserAlreadyMemberError, +) +from server.services.email_service import EmailService +from storage.org_invitation import OrgInvitation +from storage.org_invitation_store import OrgInvitationStore +from storage.org_member_store import OrgMemberStore +from storage.org_service import OrgService +from storage.org_store import OrgStore +from storage.role_store import RoleStore +from storage.user_store import UserStore + +from openhands.core.logger import openhands_logger as logger + + +class OrgInvitationService: + """Service for organization invitation operations.""" + + @staticmethod + async def create_invitation( + org_id: UUID, + email: str, + role_name: str, + inviter_id: UUID, + ) -> OrgInvitation: + """Create a new organization invitation. + + This method: + 1. Validates the organization exists + 2. Validates this is not a personal workspace + 3. Checks inviter has owner/admin role + 4. Validates role assignment permissions + 5. Checks if user is already a member + 6. Creates the invitation + 7. Sends the invitation email + + Args: + org_id: Organization UUID + email: Invitee's email address + role_name: Role to assign on acceptance (owner, admin, member) + inviter_id: User ID of the person creating the invitation + + Returns: + OrgInvitation: The created invitation + + Raises: + ValueError: If organization or role not found + InsufficientPermissionError: If inviter lacks permission + UserAlreadyMemberError: If email is already a member + InvitationAlreadyExistsError: If pending invitation exists + """ + email = email.lower().strip() + + logger.info( + 'Creating organization invitation', + extra={ + 'org_id': str(org_id), + 'email': email, + 'role_name': role_name, + 'inviter_id': str(inviter_id), + }, + ) + + # Step 1: Validate organization exists + org = OrgStore.get_org_by_id(org_id) + if not org: + raise ValueError(f'Organization {org_id} not found') + + # Step 2: Check this is not a personal workspace + # A personal workspace has org_id matching the user's id + if str(org_id) == str(inviter_id): + raise InsufficientPermissionError( + 'Cannot invite users to a personal workspace' + ) + + # Step 3: Check inviter is a member and has permission + inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id) + if not inviter_member: + raise InsufficientPermissionError( + 'You are not a member of this organization' + ) + + inviter_role = RoleStore.get_role_by_id(inviter_member.role_id) + if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: + raise InsufficientPermissionError('Only owners and admins can invite users') + + # Step 4: Validate role assignment permissions + role_name_lower = role_name.lower() + if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER: + raise InsufficientPermissionError('Only owners can invite with owner role') + + # Get the target role + target_role = RoleStore.get_role_by_name(role_name_lower) + if not target_role: + raise ValueError(f'Invalid role: {role_name}') + + # Step 5: Check if user is already a member (by email) + existing_user = await UserStore.get_user_by_email_async(email) + if existing_user: + existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id) + if existing_member: + raise UserAlreadyMemberError( + 'User is already a member of this organization' + ) + + # Step 6: Create the invitation + invitation = await OrgInvitationStore.create_invitation( + org_id=org_id, + email=email, + role_id=target_role.id, + inviter_id=inviter_id, + ) + + # Step 7: Send invitation email + try: + # Get inviter info for the email + inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id)) + inviter_name = 'A team member' + if inviter_user and inviter_user.email: + inviter_name = inviter_user.email.split('@')[0] + + EmailService.send_invitation_email( + to_email=email, + org_name=org.name, + inviter_name=inviter_name, + role_name=target_role.name, + invitation_token=invitation.token, + invitation_id=invitation.id, + ) + except Exception as e: + logger.error( + 'Failed to send invitation email', + extra={ + 'invitation_id': invitation.id, + 'email': email, + 'error': str(e), + }, + ) + # Don't fail the invitation creation if email fails + # The user can still access via direct link + + return invitation + + @staticmethod + async def create_invitations_batch( + org_id: UUID, + emails: list[str], + role_name: str, + inviter_id: UUID, + ) -> tuple[list[OrgInvitation], list[tuple[str, str]]]: + """Create multiple organization invitations concurrently. + + Validates permissions once upfront, then creates invitations in parallel. + + Args: + org_id: Organization UUID + emails: List of invitee email addresses + role_name: Role to assign on acceptance (owner, admin, member) + inviter_id: User ID of the person creating the invitations + + Returns: + Tuple of (successful_invitations, failed_emails_with_errors) + + Raises: + ValueError: If organization or role not found + InsufficientPermissionError: If inviter lacks permission + """ + logger.info( + 'Creating batch organization invitations', + extra={ + 'org_id': str(org_id), + 'email_count': len(emails), + 'role_name': role_name, + 'inviter_id': str(inviter_id), + }, + ) + + # Step 1: Validate permissions upfront (shared for all emails) + org = OrgStore.get_org_by_id(org_id) + if not org: + raise ValueError(f'Organization {org_id} not found') + + if str(org_id) == str(inviter_id): + raise InsufficientPermissionError( + 'Cannot invite users to a personal workspace' + ) + + inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id) + if not inviter_member: + raise InsufficientPermissionError( + 'You are not a member of this organization' + ) + + inviter_role = RoleStore.get_role_by_id(inviter_member.role_id) + if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]: + raise InsufficientPermissionError('Only owners and admins can invite users') + + role_name_lower = role_name.lower() + if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER: + raise InsufficientPermissionError('Only owners can invite with owner role') + + target_role = RoleStore.get_role_by_name(role_name_lower) + if not target_role: + raise ValueError(f'Invalid role: {role_name}') + + # Step 2: Create invitations concurrently + async def create_single( + email: str, + ) -> tuple[str, OrgInvitation | None, str | None]: + """Create single invitation, return (email, invitation, error).""" + try: + invitation = await OrgInvitationService.create_invitation( + org_id=org_id, + email=email, + role_name=role_name, + inviter_id=inviter_id, + ) + return (email, invitation, None) + except (UserAlreadyMemberError, ValueError) as e: + return (email, None, str(e)) + + results = await asyncio.gather(*[create_single(email) for email in emails]) + + # Step 3: Separate successes and failures + successful: list[OrgInvitation] = [] + failed: list[tuple[str, str]] = [] + for email, invitation, error in results: + if invitation: + successful.append(invitation) + elif error: + failed.append((email, error)) + + logger.info( + 'Batch invitation creation completed', + extra={ + 'org_id': str(org_id), + 'successful': len(successful), + 'failed': len(failed), + }, + ) + + return successful, failed + + @staticmethod + async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation: + """Accept an organization invitation. + + This method: + 1. Validates the token and invitation status + 2. Checks expiration + 3. Verifies user is not already a member + 4. Creates LiteLLM integration + 5. Adds user to the organization + 6. Marks invitation as accepted + + Args: + token: The invitation token + user_id: The user accepting the invitation + + Returns: + OrgInvitation: The accepted invitation + + Raises: + InvitationInvalidError: If token is invalid or invitation not pending + InvitationExpiredError: If invitation has expired + UserAlreadyMemberError: If user is already a member + """ + logger.info( + 'Accepting organization invitation', + extra={ + 'token_prefix': token[:10] + '...' if len(token) > 10 else token, + 'user_id': str(user_id), + }, + ) + + # Step 1: Get and validate invitation + invitation = await OrgInvitationStore.get_invitation_by_token(token) + + if not invitation: + raise InvitationInvalidError('Invalid invitation token') + + if invitation.status != OrgInvitation.STATUS_PENDING: + if invitation.status == OrgInvitation.STATUS_ACCEPTED: + raise InvitationInvalidError('Invitation has already been accepted') + elif invitation.status == OrgInvitation.STATUS_REVOKED: + raise InvitationInvalidError('Invitation has been revoked') + else: + raise InvitationInvalidError('Invitation is no longer valid') + + # Step 2: Check expiration + if OrgInvitationStore.is_token_expired(invitation): + await OrgInvitationStore.update_invitation_status( + invitation.id, OrgInvitation.STATUS_EXPIRED + ) + raise InvitationExpiredError('Invitation has expired') + + # Step 2.5: Verify user email matches invitation email + user = await UserStore.get_user_by_id_async(str(user_id)) + if not user: + raise InvitationInvalidError('User not found') + + user_email = user.email + # Fallback: fetch email from Keycloak if not in database (for existing users) + if not user_email: + token_manager = TokenManager() + user_info = await token_manager.get_user_info_from_user_id(str(user_id)) + user_email = user_info.get('email') if user_info else None + + if not user_email: + raise EmailMismatchError('Your account does not have an email address') + + user_email = user_email.lower().strip() + invitation_email = invitation.email.lower().strip() + + if user_email != invitation_email: + logger.warning( + 'Email mismatch during invitation acceptance', + extra={ + 'user_id': str(user_id), + 'user_email': user_email, + 'invitation_email': invitation_email, + 'invitation_id': invitation.id, + }, + ) + raise EmailMismatchError() + + # Step 3: Check if user is already a member + existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id) + if existing_member: + raise UserAlreadyMemberError( + 'You are already a member of this organization' + ) + + # Step 4: Create LiteLLM integration for the user in the new org + try: + settings = await OrgService.create_litellm_integration( + invitation.org_id, str(user_id) + ) + except Exception as e: + logger.error( + 'Failed to create LiteLLM integration for invitation acceptance', + extra={ + 'invitation_id': invitation.id, + 'user_id': str(user_id), + 'org_id': str(invitation.org_id), + 'error': str(e), + }, + ) + raise InvitationInvalidError( + 'Failed to set up organization access. Please try again.' + ) + + # Step 5: Add user to organization + from storage.org_member_store import OrgMemberStore as OMS + + org_member_kwargs = OMS.get_kwargs_from_settings(settings) + # Don't override with org defaults - use invitation-specified role + org_member_kwargs.pop('llm_model', None) + org_member_kwargs.pop('llm_base_url', None) + + OrgMemberStore.add_user_to_org( + org_id=invitation.org_id, + user_id=user_id, + role_id=invitation.role_id, + llm_api_key=settings.llm_api_key, + status='active', + ) + + # Step 6: Mark invitation as accepted + updated_invitation = await OrgInvitationStore.update_invitation_status( + invitation.id, + OrgInvitation.STATUS_ACCEPTED, + accepted_by_user_id=user_id, + ) + + logger.info( + 'Organization invitation accepted', + extra={ + 'invitation_id': invitation.id, + 'user_id': str(user_id), + 'org_id': str(invitation.org_id), + 'role_id': invitation.role_id, + }, + ) + + return updated_invitation diff --git a/enterprise/storage/__init__.py b/enterprise/storage/__init__.py index c00ad5346a..13b8564421 100644 --- a/enterprise/storage/__init__.py +++ b/enterprise/storage/__init__.py @@ -20,6 +20,7 @@ from storage.linear_workspace import LinearWorkspace from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus from storage.openhands_pr import OpenhandsPR from storage.org import Org +from storage.org_invitation import OrgInvitation from storage.org_member import OrgMember from storage.proactive_convos import ProactiveConversation from storage.role import Role @@ -65,6 +66,7 @@ __all__ = [ 'MaintenanceTaskStatus', 'OpenhandsPR', 'Org', + 'OrgInvitation', 'OrgMember', 'ProactiveConversation', 'Role', diff --git a/enterprise/storage/org.py b/enterprise/storage/org.py index 6e9884d655..bd20cb45d3 100644 --- a/enterprise/storage/org.py +++ b/enterprise/storage/org.py @@ -52,6 +52,7 @@ class Org(Base): # type: ignore # Relationships org_members = relationship('OrgMember', back_populates='org') current_users = relationship('User', back_populates='current_org') + invitations = relationship('OrgInvitation', back_populates='org') billing_sessions = relationship('BillingSession', back_populates='org') stored_conversation_metadata_saas = relationship( 'StoredConversationMetadataSaas', back_populates='org' diff --git a/enterprise/storage/org_invitation.py b/enterprise/storage/org_invitation.py new file mode 100644 index 0000000000..167271f8d3 --- /dev/null +++ b/enterprise/storage/org_invitation.py @@ -0,0 +1,59 @@ +""" +SQLAlchemy model for Organization Invitation. +""" + +from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text +from sqlalchemy.orm import relationship +from storage.base import Base + + +class OrgInvitation(Base): # type: ignore + """Organization invitation model. + + Represents an invitation for a user to join an organization. + Invitations are created by organization owners/admins and contain + a secure token that can be used to accept the invitation. + """ + + __tablename__ = 'org_invitation' + + id = Column(Integer, primary_key=True, autoincrement=True) + token = Column(String(64), nullable=False, unique=True, index=True) + org_id = Column( + UUID(as_uuid=True), + ForeignKey('org.id', ondelete='CASCADE'), + nullable=False, + index=True, + ) + email = Column(String(255), nullable=False, index=True) + role_id = Column(Integer, ForeignKey('role.id'), nullable=False) + inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False) + status = Column( + String(20), + nullable=False, + server_default=text("'pending'"), + ) + created_at = Column( + DateTime, + nullable=False, + server_default=text('CURRENT_TIMESTAMP'), + ) + expires_at = Column(DateTime, nullable=False) + accepted_at = Column(DateTime, nullable=True) + accepted_by_user_id = Column( + UUID(as_uuid=True), + ForeignKey('user.id'), + nullable=True, + ) + + # Relationships + org = relationship('Org', back_populates='invitations') + role = relationship('Role') + inviter = relationship('User', foreign_keys=[inviter_id]) + accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id]) + + # Status constants + STATUS_PENDING = 'pending' + STATUS_ACCEPTED = 'accepted' + STATUS_REVOKED = 'revoked' + STATUS_EXPIRED = 'expired' diff --git a/enterprise/storage/org_invitation_store.py b/enterprise/storage/org_invitation_store.py new file mode 100644 index 0000000000..86a84fb411 --- /dev/null +++ b/enterprise/storage/org_invitation_store.py @@ -0,0 +1,227 @@ +""" +Store class for managing organization invitations. +""" + +import secrets +import string +from datetime import datetime, timedelta +from typing import Optional +from uuid import UUID + +from sqlalchemy import and_, select +from sqlalchemy.orm import joinedload +from storage.database import a_session_maker +from storage.org_invitation import OrgInvitation + +from openhands.core.logger import openhands_logger as logger + +# Invitation token configuration +INVITATION_TOKEN_PREFIX = 'inv-' +INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix +DEFAULT_EXPIRATION_DAYS = 7 + + +class OrgInvitationStore: + """Store for managing organization invitations.""" + + @staticmethod + def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str: + """Generate a secure invitation token. + + Uses cryptographically secure random generation for tokens. + Pattern from api_key_store.py. + + Args: + length: Length of the random part of the token + + Returns: + str: Token with prefix (e.g., 'inv-aBcDeF123...') + """ + alphabet = string.ascii_letters + string.digits + random_part = ''.join(secrets.choice(alphabet) for _ in range(length)) + return f'{INVITATION_TOKEN_PREFIX}{random_part}' + + @staticmethod + async def create_invitation( + org_id: UUID, + email: str, + role_id: int, + inviter_id: UUID, + expiration_days: int = DEFAULT_EXPIRATION_DAYS, + ) -> OrgInvitation: + """Create a new organization invitation. + + Args: + org_id: Organization UUID + email: Invitee's email address + role_id: Role ID to assign on acceptance + inviter_id: User ID of the person creating the invitation + expiration_days: Days until the invitation expires + + Returns: + OrgInvitation: The created invitation record + """ + async with a_session_maker() as session: + token = OrgInvitationStore.generate_token() + # Use timezone-naive datetime for database compatibility + expires_at = datetime.utcnow() + timedelta(days=expiration_days) + + invitation = OrgInvitation( + token=token, + org_id=org_id, + email=email.lower().strip(), + role_id=role_id, + inviter_id=inviter_id, + status=OrgInvitation.STATUS_PENDING, + expires_at=expires_at, + ) + session.add(invitation) + await session.commit() + + # Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError + result = await session.execute( + select(OrgInvitation) + .options(joinedload(OrgInvitation.role)) + .filter(OrgInvitation.id == invitation.id) + ) + invitation = result.scalars().first() + + logger.info( + 'Created organization invitation', + extra={ + 'invitation_id': invitation.id, + 'org_id': str(org_id), + 'email': email, + 'inviter_id': str(inviter_id), + 'expires_at': expires_at.isoformat(), + }, + ) + + return invitation + + @staticmethod + async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]: + """Get an invitation by its token. + + Args: + token: The invitation token + + Returns: + OrgInvitation or None if not found + """ + async with a_session_maker() as session: + result = await session.execute( + select(OrgInvitation) + .options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role)) + .filter(OrgInvitation.token == token) + ) + return result.scalars().first() + + @staticmethod + async def get_pending_invitation( + org_id: UUID, email: str + ) -> Optional[OrgInvitation]: + """Get a pending invitation for an email in an organization. + + Args: + org_id: Organization UUID + email: Email address to check + + Returns: + OrgInvitation or None if no pending invitation exists + """ + async with a_session_maker() as session: + result = await session.execute( + select(OrgInvitation).filter( + and_( + OrgInvitation.org_id == org_id, + OrgInvitation.email == email.lower().strip(), + OrgInvitation.status == OrgInvitation.STATUS_PENDING, + ) + ) + ) + return result.scalars().first() + + @staticmethod + async def update_invitation_status( + invitation_id: int, + status: str, + accepted_by_user_id: Optional[UUID] = None, + ) -> Optional[OrgInvitation]: + """Update an invitation's status. + + Args: + invitation_id: The invitation ID + status: New status (pending, accepted, revoked, expired) + accepted_by_user_id: User ID who accepted (only for 'accepted' status) + + Returns: + Updated OrgInvitation or None if not found + """ + async with a_session_maker() as session: + result = await session.execute( + select(OrgInvitation).filter(OrgInvitation.id == invitation_id) + ) + invitation = result.scalars().first() + + if not invitation: + return None + + old_status = invitation.status + invitation.status = status + + if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id: + # Use timezone-naive datetime for database compatibility + invitation.accepted_at = datetime.utcnow() + invitation.accepted_by_user_id = accepted_by_user_id + + await session.commit() + await session.refresh(invitation) + + logger.info( + 'Updated invitation status', + extra={ + 'invitation_id': invitation_id, + 'old_status': old_status, + 'new_status': status, + 'accepted_by_user_id': ( + str(accepted_by_user_id) if accepted_by_user_id else None + ), + }, + ) + + return invitation + + @staticmethod + def is_token_expired(invitation: OrgInvitation) -> bool: + """Check if an invitation token has expired. + + Args: + invitation: The invitation to check + + Returns: + bool: True if expired, False otherwise + """ + # Use timezone-naive datetime for comparison (database stores without timezone) + now = datetime.utcnow() + return invitation.expires_at < now + + @staticmethod + async def mark_expired_if_needed(invitation: OrgInvitation) -> bool: + """Check if invitation is expired and update status if needed. + + Args: + invitation: The invitation to check + + Returns: + bool: True if invitation was marked as expired, False otherwise + """ + if ( + invitation.status == OrgInvitation.STATUS_PENDING + and OrgInvitationStore.is_token_expired(invitation) + ): + await OrgInvitationStore.update_invitation_status( + invitation.id, OrgInvitation.STATUS_EXPIRED + ) + return True + return False diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index ff49c17da8..379f02e45e 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -770,6 +770,30 @@ class UserStore: finally: await UserStore._release_user_creation_lock(user_id) + @staticmethod + async def get_user_by_email_async(email: str) -> Optional[User]: + """Get user by email address (async version). + + This method looks up a user by their email address. Note that email + addresses may not be unique across all users in rare cases. + + Args: + email: The email address to search for + + Returns: + User: The user with the matching email, or None if not found + """ + if not email: + return None + + async with a_session_maker() as session: + result = await session.execute( + select(User) + .options(joinedload(User.org_members)) + .filter(User.email == email.lower().strip()) + ) + return result.scalars().first() + @staticmethod def list_users() -> list[User]: """List all users.""" diff --git a/enterprise/tests/unit/test_auth_invitation_callback.py b/enterprise/tests/unit/test_auth_invitation_callback.py new file mode 100644 index 0000000000..e6906e550f --- /dev/null +++ b/enterprise/tests/unit/test_auth_invitation_callback.py @@ -0,0 +1,181 @@ +"""Tests for auth callback invitation acceptance - EmailMismatchError handling.""" + +import pytest + + +class TestAuthCallbackInvitationEmailMismatch: + """Test cases for EmailMismatchError handling during auth callback.""" + + @pytest.fixture + def mock_redirect_url(self): + """Base redirect URL.""" + return 'https://app.example.com/' + + @pytest.fixture + def mock_user_id(self): + """Mock user ID.""" + return '87654321-4321-8765-4321-876543218765' + + def test_email_mismatch_appends_to_url_without_query_params( + self, mock_redirect_url, mock_user_id + ): + """Test that email_mismatch=true is appended correctly when URL has no query params.""" + from server.routes.org_invitation_models import EmailMismatchError + + # Simulate the logic from auth.py + redirect_url = mock_redirect_url + try: + raise EmailMismatchError('Your email does not match the invitation') + except EmailMismatchError: + if '?' in redirect_url: + redirect_url = f'{redirect_url}&email_mismatch=true' + else: + redirect_url = f'{redirect_url}?email_mismatch=true' + + assert redirect_url == 'https://app.example.com/?email_mismatch=true' + + def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id): + """Test that email_mismatch=true is appended correctly when URL has existing query params.""" + from server.routes.org_invitation_models import EmailMismatchError + + redirect_url = 'https://app.example.com/?other_param=value' + try: + raise EmailMismatchError() + except EmailMismatchError: + if '?' in redirect_url: + redirect_url = f'{redirect_url}&email_mismatch=true' + else: + redirect_url = f'{redirect_url}?email_mismatch=true' + + assert ( + redirect_url + == 'https://app.example.com/?other_param=value&email_mismatch=true' + ) + + def test_email_mismatch_error_has_default_message(self): + """Test that EmailMismatchError has the default message.""" + from server.routes.org_invitation_models import EmailMismatchError + + error = EmailMismatchError() + assert str(error) == 'Your email does not match the invitation' + + def test_email_mismatch_error_accepts_custom_message(self): + """Test that EmailMismatchError accepts a custom message.""" + from server.routes.org_invitation_models import EmailMismatchError + + custom_message = 'Custom error message' + error = EmailMismatchError(custom_message) + assert str(error) == custom_message + + def test_email_mismatch_error_is_invitation_error(self): + """Test that EmailMismatchError inherits from InvitationError.""" + from server.routes.org_invitation_models import ( + EmailMismatchError, + InvitationError, + ) + + error = EmailMismatchError() + assert isinstance(error, InvitationError) + + +class TestInvitationTokenInOAuthState: + """Test cases for invitation token handling in OAuth state.""" + + def test_invitation_token_included_in_oauth_state(self): + """Test that invitation token is included in OAuth state data.""" + import base64 + import json + + # Simulate building OAuth state with invitation token + state_data = { + 'redirect_url': 'https://app.example.com/', + 'invitation_token': 'inv-test-token-12345', + } + + encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode() + decoded_data = json.loads(base64.b64decode(encoded_state)) + + assert decoded_data['invitation_token'] == 'inv-test-token-12345' + assert decoded_data['redirect_url'] == 'https://app.example.com/' + + def test_invitation_token_extracted_from_oauth_state(self): + """Test that invitation token can be extracted from OAuth state.""" + import base64 + import json + + state_data = { + 'redirect_url': 'https://app.example.com/', + 'invitation_token': 'inv-test-token-12345', + } + + encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode() + + # Simulate decoding in callback + decoded_state = json.loads(base64.b64decode(encoded_state)) + invitation_token = decoded_state.get('invitation_token') + + assert invitation_token == 'inv-test-token-12345' + + def test_oauth_state_without_invitation_token(self): + """Test that OAuth state works without invitation token.""" + import base64 + import json + + state_data = { + 'redirect_url': 'https://app.example.com/', + } + + encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode() + decoded_data = json.loads(base64.b64decode(encoded_state)) + + assert 'invitation_token' not in decoded_data + assert decoded_data['redirect_url'] == 'https://app.example.com/' + + +class TestAuthCallbackInvitationErrors: + """Test cases for various invitation error scenarios in auth callback.""" + + def test_invitation_expired_appends_flag(self): + """Test that invitation_expired=true is appended for expired invitations.""" + from server.routes.org_invitation_models import InvitationExpiredError + + redirect_url = 'https://app.example.com/' + try: + raise InvitationExpiredError() + except InvitationExpiredError: + if '?' in redirect_url: + redirect_url = f'{redirect_url}&invitation_expired=true' + else: + redirect_url = f'{redirect_url}?invitation_expired=true' + + assert redirect_url == 'https://app.example.com/?invitation_expired=true' + + def test_invitation_invalid_appends_flag(self): + """Test that invitation_invalid=true is appended for invalid invitations.""" + from server.routes.org_invitation_models import InvitationInvalidError + + redirect_url = 'https://app.example.com/' + try: + raise InvitationInvalidError() + except InvitationInvalidError: + if '?' in redirect_url: + redirect_url = f'{redirect_url}&invitation_invalid=true' + else: + redirect_url = f'{redirect_url}?invitation_invalid=true' + + assert redirect_url == 'https://app.example.com/?invitation_invalid=true' + + def test_already_member_appends_flag(self): + """Test that already_member=true is appended when user is already a member.""" + from server.routes.org_invitation_models import UserAlreadyMemberError + + redirect_url = 'https://app.example.com/' + try: + raise UserAlreadyMemberError() + except UserAlreadyMemberError: + if '?' in redirect_url: + redirect_url = f'{redirect_url}&already_member=true' + else: + redirect_url = f'{redirect_url}?already_member=true' + + assert redirect_url == 'https://app.example.com/?already_member=true' diff --git a/enterprise/tests/unit/test_email_service.py b/enterprise/tests/unit/test_email_service.py new file mode 100644 index 0000000000..0ea63e0496 --- /dev/null +++ b/enterprise/tests/unit/test_email_service.py @@ -0,0 +1,192 @@ +"""Tests for email service.""" + +import os +from unittest.mock import MagicMock, patch + +from server.services.email_service import ( + DEFAULT_WEB_HOST, + EmailService, +) + + +class TestEmailServiceInvitationUrl: + """Test cases for invitation URL generation.""" + + def test_invitation_url_uses_correct_endpoint(self): + """Test that invitation URL points to the correct API endpoint.""" + mock_response = MagicMock() + mock_response.get.return_value = 'test-email-id' + + with ( + patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}), + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch('server.services.email_service.resend') as mock_resend, + ): + mock_resend.Emails.send.return_value = mock_response + + EmailService.send_invitation_email( + to_email='test@example.com', + org_name='Test Org', + inviter_name='Inviter', + role_name='member', + invitation_token='inv-test-token-12345', + invitation_id=1, + ) + + # Get the call arguments + call_args = mock_resend.Emails.send.call_args + email_params = call_args[0][0] + + # Verify the URL in the email HTML contains the correct endpoint + assert ( + '/api/organizations/members/invite/accept?token=' + in email_params['html'] + ) + assert 'inv-test-token-12345' in email_params['html'] + + def test_invitation_url_uses_web_host_env_var(self): + """Test that invitation URL uses WEB_HOST environment variable.""" + custom_host = 'https://custom.example.com' + mock_response = MagicMock() + mock_response.get.return_value = 'test-email-id' + + with ( + patch.dict( + os.environ, + {'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host}, + ), + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch('server.services.email_service.resend') as mock_resend, + ): + mock_resend.Emails.send.return_value = mock_response + + EmailService.send_invitation_email( + to_email='test@example.com', + org_name='Test Org', + inviter_name='Inviter', + role_name='member', + invitation_token='inv-test-token-12345', + invitation_id=1, + ) + + call_args = mock_resend.Emails.send.call_args + email_params = call_args[0][0] + + expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345' + assert expected_url in email_params['html'] + + def test_invitation_url_uses_default_host_when_env_not_set(self): + """Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set.""" + mock_response = MagicMock() + mock_response.get.return_value = 'test-email-id' + + env_without_web_host = {'RESEND_API_KEY': 'test-key'} + # Remove WEB_HOST if it exists + env_without_web_host.pop('WEB_HOST', None) + + with ( + patch.dict(os.environ, env_without_web_host, clear=True), + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch('server.services.email_service.resend') as mock_resend, + ): + # Clear WEB_HOST from the environment + os.environ.pop('WEB_HOST', None) + mock_resend.Emails.send.return_value = mock_response + + EmailService.send_invitation_email( + to_email='test@example.com', + org_name='Test Org', + inviter_name='Inviter', + role_name='member', + invitation_token='inv-test-token-12345', + invitation_id=1, + ) + + call_args = mock_resend.Emails.send.call_args + email_params = call_args[0][0] + + expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345' + assert expected_url in email_params['html'] + + +class TestEmailServiceGetResendClient: + """Test cases for Resend client initialization.""" + + def test_get_resend_client_returns_false_when_resend_not_available(self): + """Test that _get_resend_client returns False when resend is not installed.""" + with patch('server.services.email_service.RESEND_AVAILABLE', False): + result = EmailService._get_resend_client() + assert result is False + + def test_get_resend_client_returns_false_when_api_key_not_configured(self): + """Test that _get_resend_client returns False when API key is missing.""" + with ( + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch.dict(os.environ, {}, clear=True), + ): + os.environ.pop('RESEND_API_KEY', None) + result = EmailService._get_resend_client() + assert result is False + + def test_get_resend_client_returns_true_when_configured(self): + """Test that _get_resend_client returns True when properly configured.""" + with ( + patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}), + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch('server.services.email_service.resend') as mock_resend, + ): + result = EmailService._get_resend_client() + assert result is True + assert mock_resend.api_key == 'test-key' + + +class TestEmailServiceSendInvitationEmail: + """Test cases for send_invitation_email method.""" + + def test_send_invitation_email_skips_when_client_not_ready(self): + """Test that email sending is skipped when client is not ready.""" + with patch.object( + EmailService, '_get_resend_client', return_value=False + ) as mock_get_client: + # Should not raise, just return early + EmailService.send_invitation_email( + to_email='test@example.com', + org_name='Test Org', + inviter_name='Inviter', + role_name='member', + invitation_token='inv-test-token', + invitation_id=1, + ) + + mock_get_client.assert_called_once() + + def test_send_invitation_email_includes_all_required_info(self): + """Test that invitation email includes org name, inviter name, and role.""" + mock_response = MagicMock() + mock_response.get.return_value = 'test-email-id' + + with ( + patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}), + patch('server.services.email_service.RESEND_AVAILABLE', True), + patch('server.services.email_service.resend') as mock_resend, + ): + mock_resend.Emails.send.return_value = mock_response + + EmailService.send_invitation_email( + to_email='test@example.com', + org_name='Acme Corp', + inviter_name='John Doe', + role_name='admin', + invitation_token='inv-test-token-12345', + invitation_id=42, + ) + + call_args = mock_resend.Emails.send.call_args + email_params = call_args[0][0] + + # Verify email content + assert email_params['to'] == ['test@example.com'] + assert 'Acme Corp' in email_params['subject'] + assert 'John Doe' in email_params['html'] + assert 'Acme Corp' in email_params['html'] + assert 'admin' in email_params['html'] diff --git a/enterprise/tests/unit/test_org_invitation_service.py b/enterprise/tests/unit/test_org_invitation_service.py new file mode 100644 index 0000000000..06c0d258ed --- /dev/null +++ b/enterprise/tests/unit/test_org_invitation_service.py @@ -0,0 +1,464 @@ +"""Tests for organization invitation service - email validation.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import UUID + +import pytest +from server.routes.org_invitation_models import ( + EmailMismatchError, +) +from server.services.org_invitation_service import OrgInvitationService +from storage.org_invitation import OrgInvitation + + +class TestAcceptInvitationEmailValidation: + """Test cases for email validation during invitation acceptance.""" + + @pytest.fixture + def mock_invitation(self): + """Create a mock invitation with pending status.""" + invitation = MagicMock(spec=OrgInvitation) + invitation.id = 1 + invitation.email = 'alice@example.com' + invitation.status = OrgInvitation.STATUS_PENDING + invitation.org_id = UUID('12345678-1234-5678-1234-567812345678') + invitation.role_id = 1 + return invitation + + @pytest.fixture + def mock_user(self): + """Create a mock user with email.""" + user = MagicMock() + user.id = UUID('87654321-4321-8765-4321-876543218765') + user.email = 'alice@example.com' + return user + + @pytest.mark.asyncio + async def test_accept_invitation_email_matches(self, mock_invitation, mock_user): + """Test that invitation is accepted when user email matches invitation email.""" + # Arrange + user_id = mock_user.id + token = 'inv-test-token-12345' + + with patch.object( + OrgInvitationService, 'accept_invitation', new_callable=AsyncMock + ) as mock_accept: + mock_accept.return_value = mock_invitation + + # Act + await OrgInvitationService.accept_invitation(token, user_id) + + # Assert + mock_accept.assert_called_once_with(token, user_id) + + @pytest.mark.asyncio + async def test_accept_invitation_email_mismatch_raises_error( + self, mock_invitation, mock_user + ): + """Test that EmailMismatchError is raised when emails don't match.""" + # Arrange + user_id = mock_user.id + token = 'inv-test-token-12345' + mock_user.email = 'bob@example.com' # Different email + + with ( + patch( + 'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token', + new_callable=AsyncMock, + ) as mock_get_invitation, + patch( + 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' + ) as mock_is_expired, + patch( + 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + ) as mock_get_user, + ): + mock_get_invitation.return_value = mock_invitation + mock_is_expired.return_value = False + mock_get_user.return_value = mock_user + + # Act & Assert + with pytest.raises(EmailMismatchError): + await OrgInvitationService.accept_invitation(token, user_id) + + @pytest.mark.asyncio + async def test_accept_invitation_user_no_email_keycloak_fallback_matches( + self, mock_invitation + ): + """Test that Keycloak email is used when user has no email in database.""" + # Arrange + user_id = UUID('87654321-4321-8765-4321-876543218765') + token = 'inv-test-token-12345' + + mock_user = MagicMock() + mock_user.id = user_id + mock_user.email = None # No email in database + + mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak + + with ( + patch( + 'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token', + new_callable=AsyncMock, + ) as mock_get_invitation, + patch( + 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' + ) as mock_is_expired, + patch( + 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + ) as mock_get_user, + patch( + 'server.services.org_invitation_service.TokenManager' + ) as mock_token_manager_class, + patch( + 'server.services.org_invitation_service.OrgMemberStore.get_org_member' + ) as mock_get_member, + patch( + 'server.services.org_invitation_service.OrgService.create_litellm_integration', + new_callable=AsyncMock, + ) as mock_create_litellm, + patch( + 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org' + ), + patch( + 'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status', + new_callable=AsyncMock, + ) as mock_update_status, + ): + mock_get_invitation.return_value = mock_invitation + mock_is_expired.return_value = False + mock_get_user.return_value = mock_user + + # Mock TokenManager instance + mock_token_manager = MagicMock() + mock_token_manager.get_user_info_from_user_id = AsyncMock( + return_value=mock_keycloak_user_info + ) + mock_token_manager_class.return_value = mock_token_manager + + mock_get_member.return_value = None # Not already a member + mock_create_litellm.return_value = MagicMock(llm_api_key='test-key') + mock_update_status.return_value = mock_invitation + + # Act - should not raise error because Keycloak email matches + await OrgInvitationService.accept_invitation(token, user_id) + + # Assert + mock_token_manager.get_user_info_from_user_id.assert_called_once_with( + str(user_id) + ) + + @pytest.mark.asyncio + async def test_accept_invitation_no_email_anywhere_raises_error( + self, mock_invitation + ): + """Test that EmailMismatchError is raised when user has no email in database or Keycloak.""" + # Arrange + user_id = UUID('87654321-4321-8765-4321-876543218765') + token = 'inv-test-token-12345' + + mock_user = MagicMock() + mock_user.id = user_id + mock_user.email = None # No email in database + + with ( + patch( + 'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token', + new_callable=AsyncMock, + ) as mock_get_invitation, + patch( + 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' + ) as mock_is_expired, + patch( + 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + ) as mock_get_user, + patch( + 'server.services.org_invitation_service.TokenManager' + ) as mock_token_manager_class, + ): + mock_get_invitation.return_value = mock_invitation + mock_is_expired.return_value = False + mock_get_user.return_value = mock_user + + # Mock TokenManager to return no email + mock_token_manager = MagicMock() + mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={}) + mock_token_manager_class.return_value = mock_token_manager + + # Act & Assert + with pytest.raises(EmailMismatchError) as exc_info: + await OrgInvitationService.accept_invitation(token, user_id) + + assert 'does not have an email address' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_accept_invitation_email_comparison_is_case_insensitive( + self, mock_invitation + ): + """Test that email comparison is case insensitive.""" + # Arrange + user_id = UUID('87654321-4321-8765-4321-876543218765') + token = 'inv-test-token-12345' + + mock_user = MagicMock() + mock_user.id = user_id + mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email + + mock_invitation.email = 'alice@example.com' # Lowercase in invitation + + with ( + patch( + 'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token', + new_callable=AsyncMock, + ) as mock_get_invitation, + patch( + 'server.services.org_invitation_service.OrgInvitationStore.is_token_expired' + ) as mock_is_expired, + patch( + 'server.services.org_invitation_service.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + ) as mock_get_user, + patch( + 'server.services.org_invitation_service.OrgMemberStore.get_org_member' + ) as mock_get_member, + patch( + 'server.services.org_invitation_service.OrgService.create_litellm_integration', + new_callable=AsyncMock, + ) as mock_create_litellm, + patch( + 'server.services.org_invitation_service.OrgMemberStore.add_user_to_org' + ), + patch( + 'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status', + new_callable=AsyncMock, + ) as mock_update_status, + ): + mock_get_invitation.return_value = mock_invitation + mock_is_expired.return_value = False + mock_get_user.return_value = mock_user + mock_get_member.return_value = None + mock_create_litellm.return_value = MagicMock(llm_api_key='test-key') + mock_update_status.return_value = mock_invitation + + # Act - should not raise error because emails match case-insensitively + await OrgInvitationService.accept_invitation(token, user_id) + + # Assert - invitation was accepted (update_invitation_status was called) + mock_update_status.assert_called_once() + + +class TestCreateInvitationsBatch: + """Test cases for batch invitation creation.""" + + @pytest.fixture + def org_id(self): + """Organization UUID for testing.""" + return UUID('12345678-1234-5678-1234-567812345678') + + @pytest.fixture + def inviter_id(self): + """Inviter UUID for testing.""" + return UUID('87654321-4321-8765-4321-876543218765') + + @pytest.fixture + def mock_org(self): + """Create a mock organization.""" + org = MagicMock() + org.id = UUID('12345678-1234-5678-1234-567812345678') + org.name = 'Test Org' + return org + + @pytest.fixture + def mock_inviter_member(self): + """Create a mock inviter member with owner role.""" + member = MagicMock() + member.user_id = UUID('87654321-4321-8765-4321-876543218765') + member.role_id = 1 + return member + + @pytest.fixture + def mock_owner_role(self): + """Create a mock owner role.""" + role = MagicMock() + role.id = 1 + role.name = 'owner' + return role + + @pytest.fixture + def mock_member_role(self): + """Create a mock member role.""" + role = MagicMock() + role.id = 3 + role.name = 'member' + return role + + @pytest.mark.asyncio + async def test_batch_creates_all_invitations_successfully( + self, + org_id, + inviter_id, + mock_org, + mock_inviter_member, + mock_owner_role, + mock_member_role, + ): + """Test that batch creation succeeds for all valid emails.""" + # Arrange + emails = ['alice@example.com', 'bob@example.com'] + mock_invitation_1 = MagicMock(spec=OrgInvitation) + mock_invitation_1.id = 1 + mock_invitation_2 = MagicMock(spec=OrgInvitation) + mock_invitation_2.id = 2 + + with ( + patch( + 'server.services.org_invitation_service.OrgStore.get_org_by_id', + return_value=mock_org, + ), + patch( + 'server.services.org_invitation_service.OrgMemberStore.get_org_member', + return_value=mock_inviter_member, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_id', + return_value=mock_owner_role, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_name', + return_value=mock_member_role, + ), + patch.object( + OrgInvitationService, + 'create_invitation', + new_callable=AsyncMock, + side_effect=[mock_invitation_1, mock_invitation_2], + ), + ): + # Act + successful, failed = await OrgInvitationService.create_invitations_batch( + org_id=org_id, + emails=emails, + role_name='member', + inviter_id=inviter_id, + ) + + # Assert + assert len(successful) == 2 + assert len(failed) == 0 + + @pytest.mark.asyncio + async def test_batch_handles_partial_success( + self, + org_id, + inviter_id, + mock_org, + mock_inviter_member, + mock_owner_role, + mock_member_role, + ): + """Test that batch returns partial results when some emails fail.""" + # Arrange + from server.routes.org_invitation_models import UserAlreadyMemberError + + emails = ['alice@example.com', 'existing@example.com'] + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.id = 1 + + with ( + patch( + 'server.services.org_invitation_service.OrgStore.get_org_by_id', + return_value=mock_org, + ), + patch( + 'server.services.org_invitation_service.OrgMemberStore.get_org_member', + return_value=mock_inviter_member, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_id', + return_value=mock_owner_role, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_name', + return_value=mock_member_role, + ), + patch.object( + OrgInvitationService, + 'create_invitation', + new_callable=AsyncMock, + side_effect=[mock_invitation, UserAlreadyMemberError()], + ), + ): + # Act + successful, failed = await OrgInvitationService.create_invitations_batch( + org_id=org_id, + emails=emails, + role_name='member', + inviter_id=inviter_id, + ) + + # Assert + assert len(successful) == 1 + assert len(failed) == 1 + assert failed[0][0] == 'existing@example.com' + + @pytest.mark.asyncio + async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id): + """Test that permission error fails the entire batch upfront.""" + # Arrange + + emails = ['alice@example.com', 'bob@example.com'] + + with patch( + 'server.services.org_invitation_service.OrgStore.get_org_by_id', + return_value=None, # Organization not found + ): + # Act & Assert + with pytest.raises(ValueError) as exc_info: + await OrgInvitationService.create_invitations_batch( + org_id=org_id, + emails=emails, + role_name='member', + inviter_id=inviter_id, + ) + + assert 'not found' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_batch_fails_on_invalid_role( + self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role + ): + """Test that invalid role fails the entire batch.""" + # Arrange + emails = ['alice@example.com'] + + with ( + patch( + 'server.services.org_invitation_service.OrgStore.get_org_by_id', + return_value=mock_org, + ), + patch( + 'server.services.org_invitation_service.OrgMemberStore.get_org_member', + return_value=mock_inviter_member, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_id', + return_value=mock_owner_role, + ), + patch( + 'server.services.org_invitation_service.RoleStore.get_role_by_name', + return_value=None, # Invalid role + ), + ): + # Act & Assert + with pytest.raises(ValueError) as exc_info: + await OrgInvitationService.create_invitations_batch( + org_id=org_id, + emails=emails, + role_name='invalid_role', + inviter_id=inviter_id, + ) + + assert 'Invalid role' in str(exc_info.value) diff --git a/enterprise/tests/unit/test_org_invitation_store.py b/enterprise/tests/unit/test_org_invitation_store.py new file mode 100644 index 0000000000..304922d569 --- /dev/null +++ b/enterprise/tests/unit/test_org_invitation_store.py @@ -0,0 +1,308 @@ +"""Tests for organization invitation store.""" + +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from storage.org_invitation import OrgInvitation +from storage.org_invitation_store import ( + INVITATION_TOKEN_LENGTH, + INVITATION_TOKEN_PREFIX, + OrgInvitationStore, +) + + +class TestGenerateToken: + """Test cases for token generation.""" + + def test_generate_token_has_correct_prefix(self): + """Test that generated tokens have the correct prefix.""" + token = OrgInvitationStore.generate_token() + assert token.startswith(INVITATION_TOKEN_PREFIX) + + def test_generate_token_has_correct_length(self): + """Test that generated tokens have the correct total length.""" + token = OrgInvitationStore.generate_token() + expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH + assert len(token) == expected_length + + def test_generate_token_uses_alphanumeric_characters(self): + """Test that generated tokens use only alphanumeric characters.""" + token = OrgInvitationStore.generate_token() + # Remove prefix and check the rest is alphanumeric + random_part = token[len(INVITATION_TOKEN_PREFIX) :] + assert random_part.isalnum() + + def test_generate_token_is_unique(self): + """Test that generated tokens are unique (probabilistically).""" + tokens = [OrgInvitationStore.generate_token() for _ in range(100)] + assert len(set(tokens)) == 100 + + +class TestIsTokenExpired: + """Test cases for token expiration checking.""" + + def test_token_not_expired_when_future(self): + """Test that tokens with future expiration are not expired.""" + invitation = MagicMock(spec=OrgInvitation) + invitation.expires_at = datetime.utcnow() + timedelta(days=1) + + result = OrgInvitationStore.is_token_expired(invitation) + assert result is False + + def test_token_expired_when_past(self): + """Test that tokens with past expiration are expired.""" + invitation = MagicMock(spec=OrgInvitation) + invitation.expires_at = datetime.utcnow() - timedelta(seconds=1) + + result = OrgInvitationStore.is_token_expired(invitation) + assert result is True + + def test_token_expired_at_exact_boundary(self): + """Test that tokens at exact expiration time are expired.""" + # A token that expires "now" should be expired + now = datetime.utcnow() + invitation = MagicMock(spec=OrgInvitation) + invitation.expires_at = now - timedelta(microseconds=1) + + result = OrgInvitationStore.is_token_expired(invitation) + assert result is True + + +class TestCreateInvitation: + """Test cases for invitation creation.""" + + @pytest.mark.asyncio + async def test_create_invitation_normalizes_email(self): + """Test that email is normalized (lowercase, stripped) on creation.""" + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.execute = AsyncMock() + + # Mock the result of the re-fetch query + mock_result = MagicMock() + mock_invitation = MagicMock() + mock_invitation.id = 1 + mock_invitation.email = 'test@example.com' + mock_result.scalars.return_value.first.return_value = mock_invitation + mock_session.execute.return_value = mock_result + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + from uuid import UUID + + await OrgInvitationStore.create_invitation( + org_id=UUID('12345678-1234-5678-1234-567812345678'), + email=' TEST@EXAMPLE.COM ', + role_id=1, + inviter_id=UUID('87654321-4321-8765-4321-876543218765'), + ) + + # Verify that the OrgInvitation was created with normalized email + add_call = mock_session.add.call_args + created_invitation = add_call[0][0] + assert created_invitation.email == 'test@example.com' + + +class TestGetInvitationByToken: + """Test cases for getting invitation by token.""" + + @pytest.mark.asyncio + async def test_get_invitation_by_token_returns_invitation(self): + """Test that get_invitation_by_token returns the invitation when found.""" + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.token = 'inv-test-token-12345' + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_invitation + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + result = await OrgInvitationStore.get_invitation_by_token( + 'inv-test-token-12345' + ) + assert result == mock_invitation + + @pytest.mark.asyncio + async def test_get_invitation_by_token_returns_none_when_not_found(self): + """Test that get_invitation_by_token returns None when not found.""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + result = await OrgInvitationStore.get_invitation_by_token( + 'inv-nonexistent-token' + ) + assert result is None + + +class TestGetPendingInvitation: + """Test cases for getting pending invitation.""" + + @pytest.mark.asyncio + async def test_get_pending_invitation_normalizes_email(self): + """Test that email is normalized when querying for pending invitations.""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + from uuid import UUID + + await OrgInvitationStore.get_pending_invitation( + org_id=UUID('12345678-1234-5678-1234-567812345678'), + email=' TEST@EXAMPLE.COM ', + ) + + # Verify the query was called (email normalization happens in the filter) + assert mock_session.execute.called + + +class TestUpdateInvitationStatus: + """Test cases for updating invitation status.""" + + @pytest.mark.asyncio + async def test_update_status_sets_accepted_at_for_accepted(self): + """Test that accepted_at is set when status is accepted.""" + from uuid import UUID + + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.id = 1 + mock_invitation.status = OrgInvitation.STATUS_PENDING + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_invitation + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + user_id = UUID('87654321-4321-8765-4321-876543218765') + await OrgInvitationStore.update_invitation_status( + invitation_id=1, + status=OrgInvitation.STATUS_ACCEPTED, + accepted_by_user_id=user_id, + ) + + assert mock_invitation.accepted_at is not None + assert mock_invitation.accepted_by_user_id == user_id + + @pytest.mark.asyncio + async def test_update_status_returns_none_when_not_found(self): + """Test that update returns None when invitation not found.""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch( + 'storage.org_invitation_store.a_session_maker' + ) as mock_session_maker: + mock_session_manager = AsyncMock() + mock_session_manager.__aenter__.return_value = mock_session + mock_session_manager.__aexit__.return_value = None + mock_session_maker.return_value = mock_session_manager + + result = await OrgInvitationStore.update_invitation_status( + invitation_id=999, + status=OrgInvitation.STATUS_ACCEPTED, + ) + assert result is None + + +class TestMarkExpiredIfNeeded: + """Test cases for marking expired invitations.""" + + @pytest.mark.asyncio + async def test_marks_expired_when_pending_and_past_expiry(self): + """Test that pending expired invitations are marked as expired.""" + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.id = 1 + mock_invitation.status = OrgInvitation.STATUS_PENDING + mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1) + + with patch.object( + OrgInvitationStore, + 'update_invitation_status', + new_callable=AsyncMock, + ) as mock_update: + result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation) + + assert result is True + mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED) + + @pytest.mark.asyncio + async def test_does_not_mark_when_not_expired(self): + """Test that non-expired invitations are not marked.""" + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.id = 1 + mock_invitation.status = OrgInvitation.STATUS_PENDING + mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1) + + with patch.object( + OrgInvitationStore, + 'update_invitation_status', + new_callable=AsyncMock, + ) as mock_update: + result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation) + + assert result is False + mock_update.assert_not_called() + + @pytest.mark.asyncio + async def test_does_not_mark_when_not_pending(self): + """Test that non-pending invitations are not marked even if expired.""" + mock_invitation = MagicMock(spec=OrgInvitation) + mock_invitation.id = 1 + mock_invitation.status = OrgInvitation.STATUS_ACCEPTED + mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1) + + with patch.object( + OrgInvitationStore, + 'update_invitation_status', + new_callable=AsyncMock, + ) as mock_update: + result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation) + + assert result is False + mock_update.assert_not_called() diff --git a/enterprise/tests/unit/test_org_invitations_router.py b/enterprise/tests/unit/test_org_invitations_router.py new file mode 100644 index 0000000000..cbfbf4c810 --- /dev/null +++ b/enterprise/tests/unit/test_org_invitations_router.py @@ -0,0 +1,388 @@ +"""Tests for organization invitations API router.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from server.routes.org_invitation_models import ( + EmailMismatchError, + InvitationExpiredError, + InvitationInvalidError, + UserAlreadyMemberError, +) +from server.routes.org_invitations import accept_router, invitation_router + + +@pytest.fixture +def app(): + """Create a FastAPI app with the invitation routers.""" + app = FastAPI() + app.include_router(invitation_router) + app.include_router(accept_router) + return app + + +@pytest.fixture +def client(app): + """Create a test client for the app.""" + return TestClient(app) + + +class TestRouterPrefixes: + """Test that router prefixes are configured correctly.""" + + def test_invitation_router_has_correct_prefix(self): + """Test that invitation_router has /api/organizations/{org_id}/members prefix.""" + assert invitation_router.prefix == '/api/organizations/{org_id}/members' + + def test_accept_router_has_correct_prefix(self): + """Test that accept_router has /api/organizations/members/invite prefix.""" + assert accept_router.prefix == '/api/organizations/members/invite' + + +class TestAcceptInvitationEndpoint: + """Test cases for the accept invitation endpoint.""" + + @pytest.fixture + def mock_user_auth(self): + """Create a mock user auth.""" + user_auth = MagicMock() + user_auth.get_user_id = AsyncMock( + return_value='87654321-4321-8765-4321-876543218765' + ) + return user_auth + + @pytest.mark.asyncio + async def test_accept_unauthenticated_redirects_to_login(self, client): + """Test that unauthenticated users are redirected to login with invitation token.""" + with patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=None, + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert '/login?invitation_token=inv-test-token-123' in response.headers.get( + 'location', '' + ) + + @pytest.mark.asyncio + async def test_accept_authenticated_success_redirects_home( + self, client, mock_user_auth + ): + """Test that successful acceptance redirects to home page.""" + mock_invitation = MagicMock() + + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + return_value=mock_invitation, + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + location = response.headers.get('location', '') + assert location.endswith('/') + assert 'invitation_expired' not in location + assert 'invitation_invalid' not in location + assert 'email_mismatch' not in location + + @pytest.mark.asyncio + async def test_accept_expired_invitation_redirects_with_flag( + self, client, mock_user_auth + ): + """Test that expired invitation redirects with invitation_expired=true.""" + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + side_effect=InvitationExpiredError(), + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert 'invitation_expired=true' in response.headers.get('location', '') + + @pytest.mark.asyncio + async def test_accept_invalid_invitation_redirects_with_flag( + self, client, mock_user_auth + ): + """Test that invalid invitation redirects with invitation_invalid=true.""" + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + side_effect=InvitationInvalidError(), + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert 'invitation_invalid=true' in response.headers.get('location', '') + + @pytest.mark.asyncio + async def test_accept_already_member_redirects_with_flag( + self, client, mock_user_auth + ): + """Test that already member error redirects with already_member=true.""" + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + side_effect=UserAlreadyMemberError(), + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert 'already_member=true' in response.headers.get('location', '') + + @pytest.mark.asyncio + async def test_accept_email_mismatch_redirects_with_flag( + self, client, mock_user_auth + ): + """Test that email mismatch error redirects with email_mismatch=true.""" + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + side_effect=EmailMismatchError(), + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert 'email_mismatch=true' in response.headers.get('location', '') + + @pytest.mark.asyncio + async def test_accept_unexpected_error_redirects_with_flag( + self, client, mock_user_auth + ): + """Test that unexpected errors redirect with invitation_error=true.""" + with ( + patch( + 'server.routes.org_invitations.get_user_auth', + new_callable=AsyncMock, + return_value=mock_user_auth, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.accept_invitation', + new_callable=AsyncMock, + side_effect=Exception('Unexpected error'), + ), + ): + response = client.get( + '/api/organizations/members/invite/accept?token=inv-test-token-123', + follow_redirects=False, + ) + + assert response.status_code == 302 + assert 'invitation_error=true' in response.headers.get('location', '') + + +class TestCreateInvitationBatchEndpoint: + """Test cases for the batch invitation creation endpoint.""" + + @pytest.fixture + def batch_app(self): + """Create a FastAPI app with dependency overrides for batch tests.""" + from openhands.server.user_auth import get_user_id + + app = FastAPI() + app.include_router(invitation_router) + + # Override the get_user_id dependency + app.dependency_overrides[get_user_id] = ( + lambda: '87654321-4321-8765-4321-876543218765' + ) + + return app + + @pytest.fixture + def batch_client(self, batch_app): + """Create a test client with dependency overrides.""" + return TestClient(batch_app) + + @pytest.fixture + def mock_invitation(self): + """Create a mock invitation.""" + from datetime import datetime + + invitation = MagicMock() + invitation.id = 1 + invitation.email = 'alice@example.com' + invitation.role = MagicMock(name='member') + invitation.role.name = 'member' + invitation.role_id = 3 + invitation.status = 'pending' + invitation.created_at = datetime(2026, 2, 17, 10, 0, 0) + invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0) + return invitation + + @pytest.mark.asyncio + async def test_batch_create_returns_successful_invitations( + self, batch_client, mock_invitation + ): + """Test that batch creation returns successful invitations.""" + mock_invitation_2 = MagicMock() + mock_invitation_2.id = 2 + mock_invitation_2.email = 'bob@example.com' + mock_invitation_2.role = MagicMock() + mock_invitation_2.role.name = 'member' + mock_invitation_2.role_id = 3 + mock_invitation_2.status = 'pending' + mock_invitation_2.created_at = mock_invitation.created_at + mock_invitation_2.expires_at = mock_invitation.expires_at + + with ( + patch( + 'server.routes.org_invitations.check_rate_limit_by_user_id', + new_callable=AsyncMock, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.create_invitations_batch', + new_callable=AsyncMock, + return_value=([mock_invitation, mock_invitation_2], []), + ), + ): + response = batch_client.post( + '/api/organizations/12345678-1234-5678-1234-567812345678/members/invite', + json={ + 'emails': ['alice@example.com', 'bob@example.com'], + 'role': 'member', + }, + ) + + assert response.status_code == 201 + data = response.json() + assert len(data['successful']) == 2 + assert len(data['failed']) == 0 + + @pytest.mark.asyncio + async def test_batch_create_returns_partial_success( + self, batch_client, mock_invitation + ): + """Test that batch creation returns both successful and failed invitations.""" + failed_emails = [('existing@example.com', 'User is already a member')] + + with ( + patch( + 'server.routes.org_invitations.check_rate_limit_by_user_id', + new_callable=AsyncMock, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.create_invitations_batch', + new_callable=AsyncMock, + return_value=([mock_invitation], failed_emails), + ), + ): + response = batch_client.post( + '/api/organizations/12345678-1234-5678-1234-567812345678/members/invite', + json={ + 'emails': ['alice@example.com', 'existing@example.com'], + 'role': 'member', + }, + ) + + assert response.status_code == 201 + data = response.json() + assert len(data['successful']) == 1 + assert len(data['failed']) == 1 + assert data['failed'][0]['email'] == 'existing@example.com' + assert 'already a member' in data['failed'][0]['error'] + + @pytest.mark.asyncio + async def test_batch_create_permission_denied_returns_403(self, batch_client): + """Test that permission denied returns 403 for entire batch.""" + from server.routes.org_invitation_models import InsufficientPermissionError + + with ( + patch( + 'server.routes.org_invitations.check_rate_limit_by_user_id', + new_callable=AsyncMock, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.create_invitations_batch', + new_callable=AsyncMock, + side_effect=InsufficientPermissionError( + 'Only owners and admins can invite' + ), + ), + ): + response = batch_client.post( + '/api/organizations/12345678-1234-5678-1234-567812345678/members/invite', + json={'emails': ['alice@example.com'], 'role': 'member'}, + ) + + assert response.status_code == 403 + assert 'owners and admins' in response.json()['detail'] + + @pytest.mark.asyncio + async def test_batch_create_invalid_role_returns_400(self, batch_client): + """Test that invalid role returns 400.""" + with ( + patch( + 'server.routes.org_invitations.check_rate_limit_by_user_id', + new_callable=AsyncMock, + ), + patch( + 'server.routes.org_invitations.OrgInvitationService.create_invitations_batch', + new_callable=AsyncMock, + side_effect=ValueError('Invalid role: superuser'), + ), + ): + response = batch_client.post( + '/api/organizations/12345678-1234-5678-1234-567812345678/members/invite', + json={'emails': ['alice@example.com'], 'role': 'superuser'}, + ) + + assert response.status_code == 400 + assert 'Invalid role' in response.json()['detail'] diff --git a/frontend/__tests__/components/features/auth/login-content.test.tsx b/frontend/__tests__/components/features/auth/login-content.test.tsx index ae4459980f..1681acb873 100644 --- a/frontend/__tests__/components/features/auth/login-content.test.tsx +++ b/frontend/__tests__/components/features/auth/login-content.test.tsx @@ -151,8 +151,9 @@ describe("LoginContent", () => { await user.click(githubButton); // Wait for async handleAuthRedirect to complete + // The URL includes state parameter added by handleAuthRedirect await waitFor(() => { - expect(window.location.href).toBe(mockUrl); + expect(window.location.href).toContain(mockUrl); }); }); @@ -201,4 +202,103 @@ describe("LoginContent", () => { expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument(); }); + + it("should display invitation pending message when hasInvitation is true", () => { + render( + + + , + ); + + expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument(); + }); + + it("should not display invitation pending message when hasInvitation is false", () => { + render( + + + , + ); + + expect( + screen.queryByText("AUTH$INVITATION_PENDING"), + ).not.toBeInTheDocument(); + }); + + it("should call buildOAuthStateData when clicking auth button", async () => { + const user = userEvent.setup(); + const mockBuildOAuthStateData = vi.fn((baseState) => ({ + ...baseState, + invitation_token: "inv-test-token-12345", + })); + + render( + + + , + ); + + const githubButton = screen.getByRole("button", { + name: "GITHUB$CONNECT_TO_GITHUB", + }); + await user.click(githubButton); + + await waitFor(() => { + expect(mockBuildOAuthStateData).toHaveBeenCalled(); + const callArg = mockBuildOAuthStateData.mock.calls[0][0]; + expect(callArg).toHaveProperty("redirect_url"); + }); + }); + + it("should encode state with invitation token when buildOAuthStateData provides token", async () => { + const user = userEvent.setup(); + const mockBuildOAuthStateData = vi.fn((baseState) => ({ + ...baseState, + invitation_token: "inv-test-token-12345", + })); + + render( + + + , + ); + + const githubButton = screen.getByRole("button", { + name: "GITHUB$CONNECT_TO_GITHUB", + }); + await user.click(githubButton); + + await waitFor(() => { + const redirectUrl = window.location.href; + // The URL should contain an encoded state parameter + expect(redirectUrl).toContain("state="); + // Decode and verify the state contains invitation_token + const url = new URL(redirectUrl); + const state = url.searchParams.get("state"); + if (state) { + const decodedState = JSON.parse(atob(state)); + expect(decodedState.invitation_token).toBe("inv-test-token-12345"); + } + }); + }); }); diff --git a/frontend/__tests__/hooks/use-invitation.test.ts b/frontend/__tests__/hooks/use-invitation.test.ts new file mode 100644 index 0000000000..9de4d47781 --- /dev/null +++ b/frontend/__tests__/hooks/use-invitation.test.ts @@ -0,0 +1,170 @@ +import { act, renderHook } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const INVITATION_TOKEN_KEY = "openhands_invitation_token"; + +// Mock setSearchParams function +const mockSetSearchParams = vi.fn(); + +// Default mock searchParams +let mockSearchParamsData: Record = {}; + +// Mock react-router +vi.mock("react-router", () => ({ + useSearchParams: () => [ + { + get: (key: string) => mockSearchParamsData[key] || null, + has: (key: string) => key in mockSearchParamsData, + }, + mockSetSearchParams, + ], +})); + +// Import after mocking +import { useInvitation } from "#/hooks/use-invitation"; + +describe("useInvitation", () => { + beforeEach(() => { + // Clear localStorage before each test + localStorage.clear(); + // Reset mock data + mockSearchParamsData = {}; + mockSetSearchParams.mockClear(); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + describe("initialization", () => { + it("should initialize with null token when localStorage is empty", () => { + // Arrange - localStorage is empty (cleared in beforeEach) + + // Act + const { result } = renderHook(() => useInvitation()); + + // Assert + expect(result.current.invitationToken).toBeNull(); + expect(result.current.hasInvitation).toBe(false); + }); + + it("should initialize with token from localStorage if present", () => { + // Arrange + const storedToken = "inv-stored-token-12345"; + localStorage.setItem(INVITATION_TOKEN_KEY, storedToken); + + // Act + const { result } = renderHook(() => useInvitation()); + + // Assert + expect(result.current.invitationToken).toBe(storedToken); + expect(result.current.hasInvitation).toBe(true); + }); + }); + + describe("URL token capture", () => { + it("should capture invitation_token from URL and store in localStorage", () => { + // Arrange + const urlToken = "inv-url-token-67890"; + mockSearchParamsData = { invitation_token: urlToken }; + + // Act + renderHook(() => useInvitation()); + + // Assert + expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken); + expect(mockSetSearchParams).toHaveBeenCalled(); + }); + }); + + describe("completion cleanup", () => { + it("should clear localStorage when email_mismatch param is present", () => { + // Arrange + const storedToken = "inv-token-to-clear"; + localStorage.setItem(INVITATION_TOKEN_KEY, storedToken); + mockSearchParamsData = { email_mismatch: "true" }; + + // Act + const { result } = renderHook(() => useInvitation()); + + // Assert + expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull(); + expect(mockSetSearchParams).toHaveBeenCalled(); + }); + + it("should clear localStorage when invitation_success param is present", () => { + // Arrange + const storedToken = "inv-token-to-clear"; + localStorage.setItem(INVITATION_TOKEN_KEY, storedToken); + mockSearchParamsData = { invitation_success: "true" }; + + // Act + renderHook(() => useInvitation()); + + // Assert + expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull(); + }); + + it("should clear localStorage when invitation_expired param is present", () => { + // Arrange + localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token"); + mockSearchParamsData = { invitation_expired: "true" }; + + // Act + renderHook(() => useInvitation()); + + // Assert + expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull(); + }); + }); + + describe("buildOAuthStateData", () => { + it("should include invitation_token in OAuth state when token is present", () => { + // Arrange + const token = "inv-oauth-token-12345"; + localStorage.setItem(INVITATION_TOKEN_KEY, token); + + const { result } = renderHook(() => useInvitation()); + const baseState = { redirect_url: "/dashboard" }; + + // Act + const stateData = result.current.buildOAuthStateData(baseState); + + // Assert + expect(stateData.invitation_token).toBe(token); + expect(stateData.redirect_url).toBe("/dashboard"); + }); + + it("should not include invitation_token when no token is present", () => { + // Arrange - no token in localStorage + + const { result } = renderHook(() => useInvitation()); + const baseState = { redirect_url: "/dashboard" }; + + // Act + const stateData = result.current.buildOAuthStateData(baseState); + + // Assert + expect(stateData.invitation_token).toBeUndefined(); + expect(stateData.redirect_url).toBe("/dashboard"); + }); + }); + + describe("clearInvitation", () => { + it("should remove token from localStorage when called", () => { + // Arrange + localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear"); + const { result } = renderHook(() => useInvitation()); + + // Act + act(() => { + result.current.clearInvitation(); + }); + + // Assert + expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull(); + expect(result.current.invitationToken).toBeNull(); + expect(result.current.hasInvitation).toBe(false); + }); + }); +}); diff --git a/frontend/__tests__/routes/login.test.tsx b/frontend/__tests__/routes/login.test.tsx index 3abb9557ee..2b63a8c98e 100644 --- a/frontend/__tests__/routes/login.test.tsx +++ b/frontend/__tests__/routes/login.test.tsx @@ -57,6 +57,22 @@ vi.mock("#/hooks/use-tracking", () => ({ }), })); +const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({ + useInvitationMock: vi.fn(() => ({ + invitationToken: null as string | null, + hasInvitation: false, + buildOAuthStateData: (baseState: Record) => baseState, + clearInvitation: vi.fn(), + })), + buildOAuthStateDataMock: vi.fn( + (baseState: Record) => baseState, + ), +})); + +vi.mock("#/hooks/use-invitation", () => ({ + useInvitation: () => useInvitationMock(), +})); + const RouterStub = createRoutesStub([ { Component: LoginPage, @@ -234,7 +250,8 @@ describe("LoginPage", () => { }); await user.click(githubButton); - expect(window.location.href).toBe(mockUrl); + // URL includes state parameter added by handleAuthRedirect + expect(window.location.href).toContain(mockUrl); }); it("should redirect to GitLab auth URL when GitLab button is clicked", async () => { @@ -255,7 +272,8 @@ describe("LoginPage", () => { }); await user.click(gitlabButton); - expect(window.location.href).toBe("https://gitlab.com/oauth/authorize"); + // URL includes state parameter added by handleAuthRedirect + expect(window.location.href).toContain("https://gitlab.com/oauth/authorize"); }); it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => { @@ -282,7 +300,8 @@ describe("LoginPage", () => { }); await user.click(bitbucketButton); - expect(window.location.href).toBe( + // URL includes state parameter added by handleAuthRedirect + expect(window.location.href).toContain( "https://bitbucket.org/site/oauth2/authorize", ); }); @@ -479,4 +498,137 @@ describe("LoginPage", () => { }); }); }); + + describe("Invitation Flow", () => { + it("should display invitation pending message when hasInvitation is true", async () => { + useInvitationMock.mockReturnValue({ + invitationToken: "inv-test-token-12345", + hasInvitation: true, + buildOAuthStateData: buildOAuthStateDataMock, + clearInvitation: vi.fn(), + }); + + render(, { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument(); + }); + }); + + it("should not display invitation pending message when hasInvitation is false", async () => { + useInvitationMock.mockReturnValue({ + invitationToken: null, + hasInvitation: false, + buildOAuthStateData: buildOAuthStateDataMock, + clearInvitation: vi.fn(), + }); + + render(, { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect(screen.getByTestId("login-content")).toBeInTheDocument(); + }); + + expect( + screen.queryByText("AUTH$INVITATION_PENDING"), + ).not.toBeInTheDocument(); + }); + + it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => { + const user = userEvent.setup(); + const mockBuildOAuthStateData = vi.fn((baseState: Record) => ({ + ...baseState, + invitation_token: "inv-test-token-12345", + })); + + useInvitationMock.mockReturnValue({ + invitationToken: "inv-test-token-12345", + hasInvitation: true, + buildOAuthStateData: mockBuildOAuthStateData, + clearInvitation: vi.fn(), + }); + + render(, { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect( + screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }), + ).toBeInTheDocument(); + }); + + const githubButton = screen.getByRole("button", { + name: "GITHUB$CONNECT_TO_GITHUB", + }); + await user.click(githubButton); + + // buildOAuthStateData should have been called during the OAuth redirect + expect(mockBuildOAuthStateData).toHaveBeenCalled(); + }); + + it("should include invitation token in OAuth state when invitation is present", async () => { + const user = userEvent.setup(); + const mockBuildOAuthStateData = vi.fn((baseState: Record) => ({ + ...baseState, + invitation_token: "inv-test-token-12345", + })); + + useInvitationMock.mockReturnValue({ + invitationToken: "inv-test-token-12345", + hasInvitation: true, + buildOAuthStateData: mockBuildOAuthStateData, + clearInvitation: vi.fn(), + }); + + render(, { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect( + screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }), + ).toBeInTheDocument(); + }); + + const githubButton = screen.getByRole("button", { + name: "GITHUB$CONNECT_TO_GITHUB", + }); + await user.click(githubButton); + + // Verify the redirect URL contains the state with invitation token + await waitFor(() => { + expect(window.location.href).toContain("state="); + }); + + // Decode and verify the state contains invitation_token + const url = new URL(window.location.href); + const state = url.searchParams.get("state"); + if (state) { + const decodedState = JSON.parse(atob(state)); + expect(decodedState.invitation_token).toBe("inv-test-token-12345"); + } + }); + + it("should handle login with invitation_token URL parameter", async () => { + useInvitationMock.mockReturnValue({ + invitationToken: "inv-url-token-67890", + hasInvitation: true, + buildOAuthStateData: buildOAuthStateDataMock, + clearInvitation: vi.fn(), + }); + + render(, { + wrapper: createWrapper(), + }); + + await waitFor(() => { + expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument(); + }); + }); + }); }); diff --git a/frontend/__tests__/routes/root-layout.test.tsx b/frontend/__tests__/routes/root-layout.test.tsx index 0fd9f64deb..107841c71d 100644 --- a/frontend/__tests__/routes/root-layout.test.tsx +++ b/frontend/__tests__/routes/root-layout.test.tsx @@ -42,6 +42,15 @@ vi.mock("#/utils/custom-toast-handlers", () => ({ displaySuccessToast: vi.fn(), })); +vi.mock("#/hooks/use-invitation", () => ({ + useInvitation: () => ({ + invitationToken: null, + hasInvitation: false, + buildOAuthStateData: (baseState: Record) => baseState, + clearInvitation: vi.fn(), + }), +})); + function LoginStub() { const [searchParams] = useSearchParams(); const emailVerificationRequired = @@ -353,4 +362,68 @@ describe("MainApp", () => { ); }); }); + + describe("Invitation URL Parameters", () => { + beforeEach(() => { + vi.spyOn(AuthService, "authenticate").mockRejectedValue({ + response: { status: 401 }, + isAxiosError: true, + }); + }); + + it("should redirect to login when email_mismatch=true is in query params", async () => { + renderMainApp(["/?email_mismatch=true"]); + + await waitFor( + () => { + expect(screen.getByTestId("login-page")).toBeInTheDocument(); + }, + { timeout: 2000 }, + ); + }); + + it("should redirect to login when invitation_success=true is in query params", async () => { + renderMainApp(["/?invitation_success=true"]); + + await waitFor( + () => { + expect(screen.getByTestId("login-page")).toBeInTheDocument(); + }, + { timeout: 2000 }, + ); + }); + + it("should redirect to login when invitation_expired=true is in query params", async () => { + renderMainApp(["/?invitation_expired=true"]); + + await waitFor( + () => { + expect(screen.getByTestId("login-page")).toBeInTheDocument(); + }, + { timeout: 2000 }, + ); + }); + + it("should redirect to login when invitation_invalid=true is in query params", async () => { + renderMainApp(["/?invitation_invalid=true"]); + + await waitFor( + () => { + expect(screen.getByTestId("login-page")).toBeInTheDocument(); + }, + { timeout: 2000 }, + ); + }); + + it("should redirect to login when already_member=true is in query params", async () => { + renderMainApp(["/?already_member=true"]); + + await waitFor( + () => { + expect(screen.getByTestId("login-page")).toBeInTheDocument(); + }, + { timeout: 2000 }, + ); + }); + }); }); diff --git a/frontend/src/components/features/auth/login-content.tsx b/frontend/src/components/features/auth/login-content.tsx index 1938929c91..8ab3d72f59 100644 --- a/frontend/src/components/features/auth/login-content.tsx +++ b/frontend/src/components/features/auth/login-content.tsx @@ -21,6 +21,10 @@ export interface LoginContentProps { emailVerified?: boolean; hasDuplicatedEmail?: boolean; recaptchaBlocked?: boolean; + hasInvitation?: boolean; + buildOAuthStateData?: ( + baseStateData: Record, + ) => Record; } export function LoginContent({ @@ -31,6 +35,8 @@ export function LoginContent({ emailVerified = false, hasDuplicatedEmail = false, recaptchaBlocked = false, + hasInvitation = false, + buildOAuthStateData, }: LoginContentProps) { const { t } = useTranslation(); const { trackLoginButtonClick } = useTracking(); @@ -59,31 +65,36 @@ export function LoginContent({ ) => { trackLoginButtonClick({ provider }); - if (!config?.recaptcha_site_key || !recaptchaReady) { - // No reCAPTCHA or token generation failed - redirect normally - window.location.href = redirectUrl; - return; + const url = new URL(redirectUrl); + const currentState = + url.searchParams.get("state") || window.location.origin; + + // Build base state data + let stateData: Record = { + redirect_url: currentState, + }; + + // Add invitation token if present + if (buildOAuthStateData) { + stateData = buildOAuthStateData(stateData); } - // If reCAPTCHA is configured, encode token in OAuth state - try { - const token = await executeRecaptcha("LOGIN"); - if (token) { - const url = new URL(redirectUrl); - const currentState = - url.searchParams.get("state") || window.location.origin; - - // Encode state with reCAPTCHA token for backend verification - const stateData = { - redirect_url: currentState, - recaptcha_token: token, - }; - url.searchParams.set("state", btoa(JSON.stringify(stateData))); - window.location.href = url.toString(); + // If reCAPTCHA is configured, add token to state + if (config?.recaptcha_site_key && recaptchaReady) { + try { + const token = await executeRecaptcha("LOGIN"); + if (token) { + stateData.recaptcha_token = token; + } + } catch (err) { + displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED)); + return; } - } catch (err) { - displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED)); } + + // Encode state and redirect + url.searchParams.set("state", btoa(JSON.stringify(stateData))); + window.location.href = url.toString(); }; const handleGitHubAuth = () => { @@ -123,6 +134,10 @@ export function LoginContent({ const buttonBaseClasses = "w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed"; const buttonLabelClasses = "text-sm font-medium leading-5 px-1"; + + const shouldShownHelperText = + emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation; + return (
- {emailVerified && ( -

- {t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)} -

- )} - {hasDuplicatedEmail && ( -

- {t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)} -

- )} - {recaptchaBlocked && ( -

- {t(I18nKey.AUTH$RECAPTCHA_BLOCKED)} -

+ {shouldShownHelperText && ( +
+ {emailVerified && ( +

+ {t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)} +

+ )} + {hasDuplicatedEmail && ( +

+ {t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)} +

+ )} + {recaptchaBlocked && ( +

+ {t(I18nKey.AUTH$RECAPTCHA_BLOCKED)} +

+ )} + {hasInvitation && ( +

+ {t(I18nKey.AUTH$INVITATION_PENDING)} +

+ )} +
)}
diff --git a/frontend/src/hooks/use-invitation.ts b/frontend/src/hooks/use-invitation.ts new file mode 100644 index 0000000000..14719ae7bb --- /dev/null +++ b/frontend/src/hooks/use-invitation.ts @@ -0,0 +1,119 @@ +import React from "react"; +import { useSearchParams } from "react-router"; + +const INVITATION_TOKEN_KEY = "openhands_invitation_token"; + +interface UseInvitationReturn { + /** The invitation token, if present */ + invitationToken: string | null; + /** Whether there is an active invitation */ + hasInvitation: boolean; + /** Clear the stored invitation token */ + clearInvitation: () => void; + /** Build OAuth state data including invitation token if present */ + buildOAuthStateData: ( + baseStateData: Record, + ) => Record; +} + +/** + * Hook to manage organization invitation tokens during the login flow. + * + * This hook: + * 1. Reads invitation_token from URL query params on mount + * 2. Persists the token in localStorage (survives page refresh and works across tabs) + * 3. Provides the token for inclusion in OAuth state + * 4. Provides cleanup method after successful authentication + * + * The invitation token flow: + * 1. User clicks invitation link → /api/invitations/accept?token=xxx + * 2. Backend redirects to /login?invitation_token=xxx + * 3. This hook captures token and stores in localStorage + * 4. When user clicks login button, token is included in OAuth state + * 5. After auth callback processes invitation, frontend clears the token + * + * Note: localStorage is used instead of sessionStorage to support scenarios where + * the user opens the email verification link in a new tab/browser window. + */ +export function useInvitation(): UseInvitationReturn { + const [searchParams, setSearchParams] = useSearchParams(); + const [invitationToken, setInvitationToken] = React.useState( + () => { + // Initialize from localStorage (persists across tabs and page refreshes) + if (typeof window !== "undefined") { + return localStorage.getItem(INVITATION_TOKEN_KEY); + } + return null; + }, + ); + + // Capture invitation token from URL and persist to localStorage + // This only runs on the login page where the hook is used + React.useEffect(() => { + const tokenFromUrl = searchParams.get("invitation_token"); + + if (tokenFromUrl) { + // Store in localStorage for persistence across tabs and refreshes + localStorage.setItem(INVITATION_TOKEN_KEY, tokenFromUrl); + setInvitationToken(tokenFromUrl); + + // Remove token from URL to clean up (prevents token exposure in browser history) + const newSearchParams = new URLSearchParams(searchParams); + newSearchParams.delete("invitation_token"); + setSearchParams(newSearchParams, { replace: true }); + } + }, [searchParams, setSearchParams]); + + // Clear invitation token when invitation flow completes (success or failure) + // These query params are set by the backend after processing the invitation + React.useEffect(() => { + const invitationCompleted = + searchParams.has("invitation_success") || + searchParams.has("invitation_expired") || + searchParams.has("invitation_invalid") || + searchParams.has("invitation_error") || + searchParams.has("already_member") || + searchParams.has("email_mismatch"); + + if (invitationCompleted) { + localStorage.removeItem(INVITATION_TOKEN_KEY); + setInvitationToken(null); + + // Remove invitation params from URL to clean up + const newSearchParams = new URLSearchParams(searchParams); + newSearchParams.delete("invitation_success"); + newSearchParams.delete("invitation_expired"); + newSearchParams.delete("invitation_invalid"); + newSearchParams.delete("invitation_error"); + newSearchParams.delete("already_member"); + newSearchParams.delete("email_mismatch"); + setSearchParams(newSearchParams, { replace: true }); + } + }, [searchParams, setSearchParams]); + + const clearInvitation = React.useCallback(() => { + localStorage.removeItem(INVITATION_TOKEN_KEY); + setInvitationToken(null); + }, []); + + const buildOAuthStateData = React.useCallback( + (baseStateData: Record): Record => { + const stateData = { ...baseStateData }; + + // Include invitation token in state if present + if (invitationToken) { + stateData.invitation_token = invitationToken; + } + + return stateData; + }, + [invitationToken], + ); + + return { + invitationToken, + hasInvitation: invitationToken !== null, + clearInvitation, + buildOAuthStateData, + }; +} diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index b570bdbbaa..2d9599e810 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -763,6 +763,7 @@ export enum I18nKey { AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR", AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED", AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED", + AUTH$INVITATION_PENDING = "AUTH$INVITATION_PENDING", COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE", COMMON$AND = "COMMON$AND", COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY", diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index 787f287c6d..91530fe633 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -12207,6 +12207,22 @@ "de": "Lass uns anfangen", "uk": "Почнімо" }, + "AUTH$INVITATION_PENDING": { + "en": "Sign in to accept your organization invitation", + "ja": "組織への招待を受け入れるにはサインインしてください", + "zh-CN": "登录以接受您的组织邀请", + "zh-TW": "登入以接受您的組織邀請", + "ko-KR": "조직 초대를 수락하려면 로그인하세요", + "no": "Logg inn for å godta organisasjonsinvitasjonen din", + "it": "Accedi per accettare l'invito della tua organizzazione", + "pt": "Faça login para aceitar o convite da sua organização", + "es": "Inicia sesión para aceptar la invitación de tu organización", + "ar": "سجّل الدخول لقبول دعوة مؤسستك", + "fr": "Connectez-vous pour accepter l'invitation de votre organisation", + "tr": "Organizasyon davetinizi kabul etmek için giriş yapın", + "de": "Melden Sie sich an, um Ihre Organisationseinladung anzunehmen", + "uk": "Увійдіть, щоб прийняти запрошення до організації" + }, "COMMON$TERMS_OF_SERVICE": { "en": "Terms of Service", "ja": "利用規約", diff --git a/frontend/src/root.tsx b/frontend/src/root.tsx index 6f1f896b4d..d2cc4f64f1 100644 --- a/frontend/src/root.tsx +++ b/frontend/src/root.tsx @@ -10,6 +10,7 @@ import "./tailwind.css"; import "./index.css"; import React from "react"; import { Toaster } from "react-hot-toast"; +import { useInvitation } from "#/hooks/use-invitation"; export function Layout({ children }: { children: React.ReactNode }) { return ( @@ -37,5 +38,9 @@ export const meta: MetaFunction = () => [ ]; export default function App() { + // Handle invitation token cleanup when invitation flow completes + // This runs on all pages to catch redirects from auth callback + useInvitation(); + return ; } diff --git a/frontend/src/routes/login.tsx b/frontend/src/routes/login.tsx index 0de57b27b9..874743aa6e 100644 --- a/frontend/src/routes/login.tsx +++ b/frontend/src/routes/login.tsx @@ -4,6 +4,7 @@ import { useIsAuthed } from "#/hooks/query/use-is-authed"; import { useConfig } from "#/hooks/query/use-config"; import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url"; import { useEmailVerification } from "#/hooks/use-email-verification"; +import { useInvitation } from "#/hooks/use-invitation"; import { LoginContent } from "#/components/features/auth/login-content"; import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal"; @@ -23,6 +24,8 @@ export default function LoginPage() { userId, } = useEmailVerification(); + const { hasInvitation, buildOAuthStateData } = useInvitation(); + const gitHubAuthUrl = useGitHubAuthUrl({ appMode: config.data?.app_mode || null, authUrl: config.data?.auth_url, @@ -69,6 +72,8 @@ export default function LoginPage() { emailVerified={emailVerified} hasDuplicatedEmail={hasDuplicatedEmail} recaptchaBlocked={recaptchaBlocked} + hasInvitation={hasInvitation} + buildOAuthStateData={buildOAuthStateData} />