feat: add user invitation logic (#12883)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Hiep Le
2026-02-18 13:24:19 +07:00
committed by GitHub
parent b18568da0b
commit 4d6f66ca28
28 changed files with 3666 additions and 53 deletions

View File

@@ -0,0 +1,110 @@
"""create org_invitation table
Revision ID: 094
Revises: 093
Create Date: 2026-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '094'
down_revision: Union[str, None] = '093'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create org_invitation table
op.create_table(
'org_invitation',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('token', sa.String(64), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
'status',
sa.String(20),
nullable=False,
server_default=sa.text("'pending'"),
),
sa.Column(
'created_at',
sa.DateTime,
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
# Foreign key constraints
sa.ForeignKeyConstraint(
['org_id'],
['org.id'],
name='org_invitation_org_fkey',
ondelete='CASCADE',
),
sa.ForeignKeyConstraint(
['role_id'],
['role.id'],
name='org_invitation_role_fkey',
),
sa.ForeignKeyConstraint(
['inviter_id'],
['user.id'],
name='org_invitation_inviter_fkey',
),
sa.ForeignKeyConstraint(
['accepted_by_user_id'],
['user.id'],
name='org_invitation_accepter_fkey',
),
)
# Create indexes
op.create_index(
'ix_org_invitation_token',
'org_invitation',
['token'],
unique=True,
)
op.create_index(
'ix_org_invitation_org_id',
'org_invitation',
['org_id'],
)
op.create_index(
'ix_org_invitation_email',
'org_invitation',
['email'],
)
op.create_index(
'ix_org_invitation_status',
'org_invitation',
['status'],
)
# Composite index for checking pending invitations
op.create_index(
'ix_org_invitation_org_email_status',
'org_invitation',
['org_id', 'email', 'status'],
)
def downgrade() -> None:
# Drop indexes
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
# Drop table
op.drop_table('org_invitation')

View File

@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.org_invitations import ( # noqa: E402
accept_router as invitation_accept_router,
)
from server.routes.org_invitations import ( # noqa: E402
invitation_router,
)
from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
@@ -99,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app

View File

@@ -160,6 +160,7 @@ class SetAuthCookieMiddleware:
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/api/organizations/members/invite/accept',
'/oauth/device/authorize',
'/oauth/device/token',
'/api/v1/web-client/config',

View File

@@ -5,6 +5,7 @@ import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
from urllib.parse import quote
from uuid import UUID as parse_uuid
import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
)
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
Returns:
Tuple of (redirect_url, recaptcha_token, invitation_token).
Tokens may be None.
"""
if not state:
return '', None, None
try:
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return (
state_data.get('redirect_url', ''),
state_data.get('recaptcha_token'),
state_data.get('invitation_token'),
)
except Exception:
# Old format - state is just the redirect URL
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None.
"""
if not state:
return '', None
try:
# Try to decode as JSON (new format with reCAPTCHA)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
except Exception:
# Old format - state is just the redirect URL
return state, None
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return redirect_url, recaptcha_token
@oauth_router.get('/keycloak/callback')
@@ -130,8 +156,8 @@ async def keycloak_callback(
error: Optional[str] = None,
error_description: Optional[str] = None,
):
# Extract redirect URL and reCAPTCHA token from state
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
# Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
if not redirect_url:
redirect_url = str(request.base_url)
@@ -302,8 +328,13 @@ async def keycloak_callback(
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
response = RedirectResponse(redirect_url, status_code=302)
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
# Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
f'{verification_redirect_url}&invitation_token={invitation_token}'
)
response = RedirectResponse(verification_redirect_url, status_code=302)
return response
# default to github IDP for now.
@@ -381,14 +412,90 @@ async def keycloak_callback(
)
has_accepted_tos = user.accepted_tos is not None
# Process invitation token if present (after email verification but before TOS)
if invitation_token:
try:
logger.info(
'Processing invitation token during auth callback',
extra={
'user_id': user_id,
'invitation_token_prefix': invitation_token[:10] + '...',
},
)
await OrgInvitationService.accept_invitation(
invitation_token, parse_uuid(user_id)
)
logger.info(
'Invitation accepted during auth callback',
extra={'user_id': user_id},
)
except InvitationExpiredError:
logger.warning(
'Invitation expired during auth callback',
extra={'user_id': user_id},
)
# Add query param to redirect URL
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
except InvitationInvalidError as e:
logger.warning(
'Invalid invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
except UserAlreadyMemberError:
logger.info(
'User already member during invitation acceptance',
extra={'user_id': user_id},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
except EmailMismatchError as e:
logger.warning(
'Email mismatch during auth callback invitation acceptance',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
except Exception as e:
logger.exception(
'Unexpected error processing invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
# Don't fail the login if invitation processing fails
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_error=true'
else:
redirect_url = f'{redirect_url}?invitation_error=true'
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
)
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
else:
if invitation_token:
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie(

View File

@@ -0,0 +1,122 @@
"""
Pydantic models and custom exceptions for organization invitations.
"""
from pydantic import BaseModel, EmailStr
from storage.org_invitation import OrgInvitation
from storage.role_store import RoleStore
class InvitationError(Exception):
"""Base exception for invitation errors."""
pass
class InvitationAlreadyExistsError(InvitationError):
"""Raised when a pending invitation already exists for the email."""
def __init__(
self, message: str = 'A pending invitation already exists for this email'
):
super().__init__(message)
class UserAlreadyMemberError(InvitationError):
"""Raised when the user is already a member of the organization."""
def __init__(self, message: str = 'User is already a member of this organization'):
super().__init__(message)
class InvitationExpiredError(InvitationError):
"""Raised when the invitation has expired."""
def __init__(self, message: str = 'Invitation has expired'):
super().__init__(message)
class InvitationInvalidError(InvitationError):
"""Raised when the invitation is invalid or revoked."""
def __init__(self, message: str = 'Invitation is no longer valid'):
super().__init__(message)
class InsufficientPermissionError(InvitationError):
"""Raised when the user lacks permission to perform the action."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class EmailMismatchError(InvitationError):
"""Raised when the accepting user's email doesn't match the invitation email."""
def __init__(self, message: str = 'Your email does not match the invitation'):
super().__init__(message)
class InvitationCreate(BaseModel):
"""Request model for creating invitation(s)."""
emails: list[EmailStr]
role: str = 'member' # Default to member role
class InvitationResponse(BaseModel):
"""Response model for invitation details."""
id: int
email: str
role: str
status: str
created_at: str
expires_at: str
inviter_email: str | None = None
@classmethod
def from_invitation(
cls,
invitation: OrgInvitation,
inviter_email: str | None = None,
) -> 'InvitationResponse':
"""Create an InvitationResponse from an OrgInvitation entity.
Args:
invitation: The invitation entity to convert
inviter_email: Optional email of the inviter
Returns:
InvitationResponse: The response model instance
"""
role_name = ''
if invitation.role:
role_name = invitation.role.name
elif invitation.role_id:
role = RoleStore.get_role_by_id(invitation.role_id)
role_name = role.name if role else ''
return cls(
id=invitation.id,
email=invitation.email,
role=role_name,
status=invitation.status,
created_at=invitation.created_at.isoformat(),
expires_at=invitation.expires_at.isoformat(),
inviter_email=inviter_email,
)
class InvitationFailure(BaseModel):
"""Response model for a failed invitation."""
email: str
error: str
class BatchInvitationResponse(BaseModel):
"""Response model for batch invitation creation."""
successful: list[InvitationResponse]
failed: list[InvitationFailure]

View File

@@ -0,0 +1,226 @@
"""API routes for organization invitations."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from server.routes.org_invitation_models import (
BatchInvitationResponse,
EmailMismatchError,
InsufficientPermissionError,
InvitationCreate,
InvitationExpiredError,
InvitationFailure,
InvitationInvalidError,
InvitationResponse,
UserAlreadyMemberError,
)
from server.services.org_invitation_service import OrgInvitationService
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
# Router for accepting invitations (no org_id required)
accept_router = APIRouter(prefix='/api/organizations/members/invite')
@invitation_router.post(
'/invite',
response_model=BatchInvitationResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_invitation(
org_id: UUID,
invitation_data: InvitationCreate,
request: Request,
user_id: str = Depends(get_user_id),
):
"""Create organization invitations for multiple email addresses.
Sends emails to invitees with secure links to join the organization.
Supports batch invitations - some may succeed while others fail.
Permission rules:
- Only owners and admins can create invitations
- Admins can only invite with 'member' or 'admin' role (not 'owner')
- Owners can invite with any role
Args:
org_id: Organization UUID
invitation_data: Invitation details (emails array, role)
request: FastAPI request
user_id: Authenticated user ID (from dependency)
Returns:
BatchInvitationResponse: Lists of successful and failed invitations
Raises:
HTTPException 400: Invalid role or organization not found
HTTPException 403: User lacks permission to invite
HTTPException 429: Rate limit exceeded
"""
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
await check_rate_limit_by_user_id(
request=request,
key_prefix='org_invitation_create',
user_id=user_id,
user_rate_limit_seconds=6,
)
try:
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=[str(email) for email in invitation_data.emails],
role_name=invitation_data.role,
inviter_id=UUID(user_id),
)
logger.info(
'Batch organization invitations created',
extra={
'org_id': str(org_id),
'total_emails': len(invitation_data.emails),
'successful': len(successful),
'failed': len(failed),
'inviter_id': user_id,
},
)
return BatchInvitationResponse(
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
failed=[
InvitationFailure(email=email, error=error) for email, error in failed
],
)
except InsufficientPermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error creating batch invitations',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@accept_router.get('/accept')
async def accept_invitation(
token: str,
request: Request,
):
"""Accept an organization invitation via token.
This endpoint is accessed via the link in the invitation email.
Flow:
1. If user is authenticated: Accept invitation directly and redirect to home
2. If user is not authenticated: Redirect to login page with invitation token
- Frontend stores token and includes it in OAuth state during login
- After authentication, keycloak_callback processes the invitation
Args:
token: The invitation token from the email link
request: FastAPI request
Returns:
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
or home page with error query params on failure
"""
base_url = str(request.base_url).rstrip('/')
# Try to get user_id from auth (may not be authenticated)
user_id = None
try:
user_auth = await get_user_auth(request)
if user_auth:
user_id = await user_auth.get_user_id()
except Exception:
pass
if not user_id:
# User not authenticated - redirect to login page with invitation token
# Frontend will store the token and include it in OAuth state during login
logger.info(
'Invitation accept: redirecting unauthenticated user to login',
extra={'token_prefix': token[:10] + '...'},
)
login_url = f'{base_url}/login?invitation_token={token}'
return RedirectResponse(login_url, status_code=302)
# User is authenticated - process the invitation directly
try:
await OrgInvitationService.accept_invitation(token, UUID(user_id))
logger.info(
'Invitation accepted successfully',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
},
)
# Redirect to home page on success
return RedirectResponse(f'{base_url}/', status_code=302)
except InvitationExpiredError:
logger.warning(
'Invitation accept failed: expired',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
except InvitationInvalidError as e:
logger.warning(
'Invitation accept failed: invalid',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
except UserAlreadyMemberError:
logger.info(
'Invitation accept: user already member',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
except EmailMismatchError as e:
logger.warning(
'Invitation accept failed: email mismatch',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
except Exception as e:
logger.exception(
'Unexpected error accepting invitation',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)

View File

@@ -0,0 +1,131 @@
"""Email service for sending transactional emails via Resend."""
import os
try:
import resend
RESEND_AVAILABLE = True
except ImportError:
RESEND_AVAILABLE = False
from openhands.core.logger import openhands_logger as logger
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
class EmailService:
"""Service for sending transactional emails."""
@staticmethod
def _get_resend_client() -> bool:
"""Initialize and return the Resend client.
Returns:
bool: True if client is ready, False otherwise
"""
if not RESEND_AVAILABLE:
logger.warning('Resend library not installed, skipping email')
return False
resend_api_key = os.environ.get('RESEND_API_KEY')
if not resend_api_key:
logger.warning('RESEND_API_KEY not configured, skipping email')
return False
resend.api_key = resend_api_key
return True
@staticmethod
def send_invitation_email(
to_email: str,
org_name: str,
inviter_name: str,
role_name: str,
invitation_token: str,
invitation_id: int,
) -> None:
"""Send an organization invitation email.
Args:
to_email: Recipient's email address
org_name: Name of the organization
inviter_name: Display name of the person who sent the invite
role_name: Role being offered (e.g., 'member', 'admin')
invitation_token: The secure invitation token
invitation_id: The invitation ID for logging
"""
if not EmailService._get_resend_client():
return
# Build invitation URL
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
params = {
'from': from_email,
'to': [to_email],
'subject': f"You're invited to join {org_name} on OpenHands",
'html': f"""
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
<p>Hi,</p>
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
<p>Click the button below to accept the invitation:</p>
<p style="margin: 30px 0;">
<a href="{invitation_url}"
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
text-decoration: none; border-radius: 8px; display: inline-block;
font-size: 14px; font-weight: 600;">
Accept Invitation
</a>
</p>
<p style="color: #666; font-size: 14px;">
Or copy and paste this link into your browser:<br>
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
</p>
<p style="color: #666; font-size: 14px;">
This invitation will expire in 7 days.
</p>
<p style="color: #666; font-size: 14px;">
If you weren't expecting this invitation, you can safely ignore this email.
</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">
Best,<br>
The OpenHands Team
</p>
</div>
""",
}
try:
response = resend.Emails.send(params)
logger.info(
'Invitation email sent',
extra={
'invitation_id': invitation_id,
'email': to_email,
'response_id': response.get('id') if response else None,
},
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation_id,
'email': to_email,
'error': str(e),
},
)
raise

View File

@@ -0,0 +1,397 @@
"""Service for managing organization invitations."""
import asyncio
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import ROLE_ADMIN, ROLE_OWNER
from server.routes.org_invitation_models import (
EmailMismatchError,
InsufficientPermissionError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.services.email_service import EmailService
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import OrgInvitationStore
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgInvitationService:
"""Service for organization invitation operations."""
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_name: str,
inviter_id: UUID,
) -> OrgInvitation:
"""Create a new organization invitation.
This method:
1. Validates the organization exists
2. Validates this is not a personal workspace
3. Checks inviter has owner/admin role
4. Validates role assignment permissions
5. Checks if user is already a member
6. Creates the invitation
7. Sends the invitation email
Args:
org_id: Organization UUID
email: Invitee's email address
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitation
Returns:
OrgInvitation: The created invitation
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
UserAlreadyMemberError: If email is already a member
InvitationAlreadyExistsError: If pending invitation exists
"""
email = email.lower().strip()
logger.info(
'Creating organization invitation',
extra={
'org_id': str(org_id),
'email': email,
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
# Step 2: Check this is not a personal workspace
# A personal workspace has org_id matching the user's id
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
# Step 3: Check inviter is a member and has permission
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
# Step 4: Validate role assignment permissions
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
# Get the target role
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
if existing_user:
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
if existing_member:
raise UserAlreadyMemberError(
'User is already a member of this organization'
)
# Step 6: Create the invitation
invitation = await OrgInvitationStore.create_invitation(
org_id=org_id,
email=email,
role_id=target_role.id,
inviter_id=inviter_id,
)
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
EmailService.send_invitation_email(
to_email=email,
org_name=org.name,
inviter_name=inviter_name,
role_name=target_role.name,
invitation_token=invitation.token,
invitation_id=invitation.id,
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation.id,
'email': email,
'error': str(e),
},
)
# Don't fail the invitation creation if email fails
# The user can still access via direct link
return invitation
@staticmethod
async def create_invitations_batch(
org_id: UUID,
emails: list[str],
role_name: str,
inviter_id: UUID,
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
"""Create multiple organization invitations concurrently.
Validates permissions once upfront, then creates invitations in parallel.
Args:
org_id: Organization UUID
emails: List of invitee email addresses
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitations
Returns:
Tuple of (successful_invitations, failed_emails_with_errors)
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
"""
logger.info(
'Creating batch organization invitations',
extra={
'org_id': str(org_id),
'email_count': len(emails),
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate permissions upfront (shared for all emails)
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 2: Create invitations concurrently
async def create_single(
email: str,
) -> tuple[str, OrgInvitation | None, str | None]:
"""Create single invitation, return (email, invitation, error)."""
try:
invitation = await OrgInvitationService.create_invitation(
org_id=org_id,
email=email,
role_name=role_name,
inviter_id=inviter_id,
)
return (email, invitation, None)
except (UserAlreadyMemberError, ValueError) as e:
return (email, None, str(e))
results = await asyncio.gather(*[create_single(email) for email in emails])
# Step 3: Separate successes and failures
successful: list[OrgInvitation] = []
failed: list[tuple[str, str]] = []
for email, invitation, error in results:
if invitation:
successful.append(invitation)
elif error:
failed.append((email, error))
logger.info(
'Batch invitation creation completed',
extra={
'org_id': str(org_id),
'successful': len(successful),
'failed': len(failed),
},
)
return successful, failed
@staticmethod
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
"""Accept an organization invitation.
This method:
1. Validates the token and invitation status
2. Checks expiration
3. Verifies user is not already a member
4. Creates LiteLLM integration
5. Adds user to the organization
6. Marks invitation as accepted
Args:
token: The invitation token
user_id: The user accepting the invitation
Returns:
OrgInvitation: The accepted invitation
Raises:
InvitationInvalidError: If token is invalid or invitation not pending
InvitationExpiredError: If invitation has expired
UserAlreadyMemberError: If user is already a member
"""
logger.info(
'Accepting organization invitation',
extra={
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
'user_id': str(user_id),
},
)
# Step 1: Get and validate invitation
invitation = await OrgInvitationStore.get_invitation_by_token(token)
if not invitation:
raise InvitationInvalidError('Invalid invitation token')
if invitation.status != OrgInvitation.STATUS_PENDING:
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
raise InvitationInvalidError('Invitation has already been accepted')
elif invitation.status == OrgInvitation.STATUS_REVOKED:
raise InvitationInvalidError('Invitation has been revoked')
else:
raise InvitationInvalidError('Invitation is no longer valid')
# Step 2: Check expiration
if OrgInvitationStore.is_token_expired(invitation):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
if not user:
raise InvitationInvalidError('User not found')
user_email = user.email
# Fallback: fetch email from Keycloak if not in database (for existing users)
if not user_email:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
user_email = user_info.get('email') if user_info else None
if not user_email:
raise EmailMismatchError('Your account does not have an email address')
user_email = user_email.lower().strip()
invitation_email = invitation.email.lower().strip()
if user_email != invitation_email:
logger.warning(
'Email mismatch during invitation acceptance',
extra={
'user_id': str(user_id),
'user_email': user_email,
'invitation_email': invitation_email,
'invitation_id': invitation.id,
},
)
raise EmailMismatchError()
# Step 3: Check if user is already a member
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
if existing_member:
raise UserAlreadyMemberError(
'You are already a member of this organization'
)
# Step 4: Create LiteLLM integration for the user in the new org
try:
settings = await OrgService.create_litellm_integration(
invitation.org_id, str(user_id)
)
except Exception as e:
logger.error(
'Failed to create LiteLLM integration for invitation acceptance',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'error': str(e),
},
)
raise InvitationInvalidError(
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
OrgMemberStore.add_user_to_org(
org_id=invitation.org_id,
user_id=user_id,
role_id=invitation.role_id,
llm_api_key=settings.llm_api_key,
status='active',
)
# Step 6: Mark invitation as accepted
updated_invitation = await OrgInvitationStore.update_invitation_status(
invitation.id,
OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
logger.info(
'Organization invitation accepted',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'role_id': invitation.role_id,
},
)
return updated_invitation

View File

@@ -20,6 +20,7 @@ from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.role import Role
@@ -65,6 +66,7 @@ __all__ = [
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgInvitation',
'OrgMember',
'ProactiveConversation',
'Role',

View File

@@ -52,6 +52,7 @@ class Org(Base): # type: ignore
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
invitations = relationship('OrgInvitation', back_populates='org')
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'

View File

@@ -0,0 +1,59 @@
"""
SQLAlchemy model for Organization Invitation.
"""
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.orm import relationship
from storage.base import Base
class OrgInvitation(Base): # type: ignore
"""Organization invitation model.
Represents an invitation for a user to join an organization.
Invitations are created by organization owners/admins and contain
a secure token that can be used to accept the invitation.
"""
__tablename__ = 'org_invitation'
id = Column(Integer, primary_key=True, autoincrement=True)
token = Column(String(64), nullable=False, unique=True, index=True)
org_id = Column(
UUID(as_uuid=True),
ForeignKey('org.id', ondelete='CASCADE'),
nullable=False,
index=True,
)
email = Column(String(255), nullable=False, index=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
status = Column(
String(20),
nullable=False,
server_default=text("'pending'"),
)
created_at = Column(
DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
expires_at = Column(DateTime, nullable=False)
accepted_at = Column(DateTime, nullable=True)
accepted_by_user_id = Column(
UUID(as_uuid=True),
ForeignKey('user.id'),
nullable=True,
)
# Relationships
org = relationship('Org', back_populates='invitations')
role = relationship('Role')
inviter = relationship('User', foreign_keys=[inviter_id])
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
# Status constants
STATUS_PENDING = 'pending'
STATUS_ACCEPTED = 'accepted'
STATUS_REVOKED = 'revoked'
STATUS_EXPIRED = 'expired'

View File

@@ -0,0 +1,227 @@
"""
Store class for managing organization invitations.
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.org_invitation import OrgInvitation
from openhands.core.logger import openhands_logger as logger
# Invitation token configuration
INVITATION_TOKEN_PREFIX = 'inv-'
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
DEFAULT_EXPIRATION_DAYS = 7
class OrgInvitationStore:
"""Store for managing organization invitations."""
@staticmethod
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
"""Generate a secure invitation token.
Uses cryptographically secure random generation for tokens.
Pattern from api_key_store.py.
Args:
length: Length of the random part of the token
Returns:
str: Token with prefix (e.g., 'inv-aBcDeF123...')
"""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_id: int,
inviter_id: UUID,
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
) -> OrgInvitation:
"""Create a new organization invitation.
Args:
org_id: Organization UUID
email: Invitee's email address
role_id: Role ID to assign on acceptance
inviter_id: User ID of the person creating the invitation
expiration_days: Days until the invitation expires
Returns:
OrgInvitation: The created invitation record
"""
async with a_session_maker() as session:
token = OrgInvitationStore.generate_token()
# Use timezone-naive datetime for database compatibility
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
invitation = OrgInvitation(
token=token,
org_id=org_id,
email=email.lower().strip(),
role_id=role_id,
inviter_id=inviter_id,
status=OrgInvitation.STATUS_PENDING,
expires_at=expires_at,
)
session.add(invitation)
await session.commit()
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.role))
.filter(OrgInvitation.id == invitation.id)
)
invitation = result.scalars().first()
logger.info(
'Created organization invitation',
extra={
'invitation_id': invitation.id,
'org_id': str(org_id),
'email': email,
'inviter_id': str(inviter_id),
'expires_at': expires_at.isoformat(),
},
)
return invitation
@staticmethod
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
"""Get an invitation by its token.
Args:
token: The invitation token
Returns:
OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
.filter(OrgInvitation.token == token)
)
return result.scalars().first()
@staticmethod
async def get_pending_invitation(
org_id: UUID, email: str
) -> Optional[OrgInvitation]:
"""Get a pending invitation for an email in an organization.
Args:
org_id: Organization UUID
email: Email address to check
Returns:
OrgInvitation or None if no pending invitation exists
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(
and_(
OrgInvitation.org_id == org_id,
OrgInvitation.email == email.lower().strip(),
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
)
)
)
return result.scalars().first()
@staticmethod
async def update_invitation_status(
invitation_id: int,
status: str,
accepted_by_user_id: Optional[UUID] = None,
) -> Optional[OrgInvitation]:
"""Update an invitation's status.
Args:
invitation_id: The invitation ID
status: New status (pending, accepted, revoked, expired)
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
Returns:
Updated OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
)
invitation = result.scalars().first()
if not invitation:
return None
old_status = invitation.status
invitation.status = status
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
# Use timezone-naive datetime for database compatibility
invitation.accepted_at = datetime.utcnow()
invitation.accepted_by_user_id = accepted_by_user_id
await session.commit()
await session.refresh(invitation)
logger.info(
'Updated invitation status',
extra={
'invitation_id': invitation_id,
'old_status': old_status,
'new_status': status,
'accepted_by_user_id': (
str(accepted_by_user_id) if accepted_by_user_id else None
),
},
)
return invitation
@staticmethod
def is_token_expired(invitation: OrgInvitation) -> bool:
"""Check if an invitation token has expired.
Args:
invitation: The invitation to check
Returns:
bool: True if expired, False otherwise
"""
# Use timezone-naive datetime for comparison (database stores without timezone)
now = datetime.utcnow()
return invitation.expires_at < now
@staticmethod
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
"""Check if invitation is expired and update status if needed.
Args:
invitation: The invitation to check
Returns:
bool: True if invitation was marked as expired, False otherwise
"""
if (
invitation.status == OrgInvitation.STATUS_PENDING
and OrgInvitationStore.is_token_expired(invitation)
):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
return True
return False

View File

@@ -770,6 +770,30 @@ class UserStore:
finally:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
Args:
email: The email address to search for
Returns:
User: The user with the matching email, or None if not found
"""
if not email:
return None
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod
def list_users() -> list[User]:
"""List all users."""

View File

@@ -0,0 +1,181 @@
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
import pytest
class TestAuthCallbackInvitationEmailMismatch:
"""Test cases for EmailMismatchError handling during auth callback."""
@pytest.fixture
def mock_redirect_url(self):
"""Base redirect URL."""
return 'https://app.example.com/'
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return '87654321-4321-8765-4321-876543218765'
def test_email_mismatch_appends_to_url_without_query_params(
self, mock_redirect_url, mock_user_id
):
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
from server.routes.org_invitation_models import EmailMismatchError
# Simulate the logic from auth.py
redirect_url = mock_redirect_url
try:
raise EmailMismatchError('Your email does not match the invitation')
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
from server.routes.org_invitation_models import EmailMismatchError
redirect_url = 'https://app.example.com/?other_param=value'
try:
raise EmailMismatchError()
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert (
redirect_url
== 'https://app.example.com/?other_param=value&email_mismatch=true'
)
def test_email_mismatch_error_has_default_message(self):
"""Test that EmailMismatchError has the default message."""
from server.routes.org_invitation_models import EmailMismatchError
error = EmailMismatchError()
assert str(error) == 'Your email does not match the invitation'
def test_email_mismatch_error_accepts_custom_message(self):
"""Test that EmailMismatchError accepts a custom message."""
from server.routes.org_invitation_models import EmailMismatchError
custom_message = 'Custom error message'
error = EmailMismatchError(custom_message)
assert str(error) == custom_message
def test_email_mismatch_error_is_invitation_error(self):
"""Test that EmailMismatchError inherits from InvitationError."""
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationError,
)
error = EmailMismatchError()
assert isinstance(error, InvitationError)
class TestInvitationTokenInOAuthState:
"""Test cases for invitation token handling in OAuth state."""
def test_invitation_token_included_in_oauth_state(self):
"""Test that invitation token is included in OAuth state data."""
import base64
import json
# Simulate building OAuth state with invitation token
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
assert decoded_data['redirect_url'] == 'https://app.example.com/'
def test_invitation_token_extracted_from_oauth_state(self):
"""Test that invitation token can be extracted from OAuth state."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
# Simulate decoding in callback
decoded_state = json.loads(base64.b64decode(encoded_state))
invitation_token = decoded_state.get('invitation_token')
assert invitation_token == 'inv-test-token-12345'
def test_oauth_state_without_invitation_token(self):
"""Test that OAuth state works without invitation token."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert 'invitation_token' not in decoded_data
assert decoded_data['redirect_url'] == 'https://app.example.com/'
class TestAuthCallbackInvitationErrors:
"""Test cases for various invitation error scenarios in auth callback."""
def test_invitation_expired_appends_flag(self):
"""Test that invitation_expired=true is appended for expired invitations."""
from server.routes.org_invitation_models import InvitationExpiredError
redirect_url = 'https://app.example.com/'
try:
raise InvitationExpiredError()
except InvitationExpiredError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
def test_invitation_invalid_appends_flag(self):
"""Test that invitation_invalid=true is appended for invalid invitations."""
from server.routes.org_invitation_models import InvitationInvalidError
redirect_url = 'https://app.example.com/'
try:
raise InvitationInvalidError()
except InvitationInvalidError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
def test_already_member_appends_flag(self):
"""Test that already_member=true is appended when user is already a member."""
from server.routes.org_invitation_models import UserAlreadyMemberError
redirect_url = 'https://app.example.com/'
try:
raise UserAlreadyMemberError()
except UserAlreadyMemberError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
assert redirect_url == 'https://app.example.com/?already_member=true'

View File

@@ -0,0 +1,192 @@
"""Tests for email service."""
import os
from unittest.mock import MagicMock, patch
from server.services.email_service import (
DEFAULT_WEB_HOST,
EmailService,
)
class TestEmailServiceInvitationUrl:
"""Test cases for invitation URL generation."""
def test_invitation_url_uses_correct_endpoint(self):
"""Test that invitation URL points to the correct API endpoint."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
# Get the call arguments
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify the URL in the email HTML contains the correct endpoint
assert (
'/api/organizations/members/invite/accept?token='
in email_params['html']
)
assert 'inv-test-token-12345' in email_params['html']
def test_invitation_url_uses_web_host_env_var(self):
"""Test that invitation URL uses WEB_HOST environment variable."""
custom_host = 'https://custom.example.com'
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(
os.environ,
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
def test_invitation_url_uses_default_host_when_env_not_set(self):
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
# Remove WEB_HOST if it exists
env_without_web_host.pop('WEB_HOST', None)
with (
patch.dict(os.environ, env_without_web_host, clear=True),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
# Clear WEB_HOST from the environment
os.environ.pop('WEB_HOST', None)
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
class TestEmailServiceGetResendClient:
"""Test cases for Resend client initialization."""
def test_get_resend_client_returns_false_when_resend_not_available(self):
"""Test that _get_resend_client returns False when resend is not installed."""
with patch('server.services.email_service.RESEND_AVAILABLE', False):
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
"""Test that _get_resend_client returns False when API key is missing."""
with (
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch.dict(os.environ, {}, clear=True),
):
os.environ.pop('RESEND_API_KEY', None)
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_true_when_configured(self):
"""Test that _get_resend_client returns True when properly configured."""
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
result = EmailService._get_resend_client()
assert result is True
assert mock_resend.api_key == 'test-key'
class TestEmailServiceSendInvitationEmail:
"""Test cases for send_invitation_email method."""
def test_send_invitation_email_skips_when_client_not_ready(self):
"""Test that email sending is skipped when client is not ready."""
with patch.object(
EmailService, '_get_resend_client', return_value=False
) as mock_get_client:
# Should not raise, just return early
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token',
invitation_id=1,
)
mock_get_client.assert_called_once()
def test_send_invitation_email_includes_all_required_info(self):
"""Test that invitation email includes org name, inviter name, and role."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Acme Corp',
inviter_name='John Doe',
role_name='admin',
invitation_token='inv-test-token-12345',
invitation_id=42,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify email content
assert email_params['to'] == ['test@example.com']
assert 'Acme Corp' in email_params['subject']
assert 'John Doe' in email_params['html']
assert 'Acme Corp' in email_params['html']
assert 'admin' in email_params['html']

View File

@@ -0,0 +1,464 @@
"""Tests for organization invitation service - email validation."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from server.routes.org_invitation_models import (
EmailMismatchError,
)
from server.services.org_invitation_service import OrgInvitationService
from storage.org_invitation import OrgInvitation
class TestAcceptInvitationEmailValidation:
"""Test cases for email validation during invitation acceptance."""
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation with pending status."""
invitation = MagicMock(spec=OrgInvitation)
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.status = OrgInvitation.STATUS_PENDING
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
invitation.role_id = 1
return invitation
@pytest.fixture
def mock_user(self):
"""Create a mock user with email."""
user = MagicMock()
user.id = UUID('87654321-4321-8765-4321-876543218765')
user.email = 'alice@example.com'
return user
@pytest.mark.asyncio
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
"""Test that invitation is accepted when user email matches invitation email."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
with patch.object(
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
) as mock_accept:
mock_accept.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_accept.assert_called_once_with(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_email_mismatch_raises_error(
self, mock_invitation, mock_user
):
"""Test that EmailMismatchError is raised when emails don't match."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
mock_user.email = 'bob@example.com' # Different email
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Act & Assert
with pytest.raises(EmailMismatchError):
await OrgInvitationService.accept_invitation(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
self, mock_invitation
):
"""Test that Keycloak email is used when user has no email in database."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager instance
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(
return_value=mock_keycloak_user_info
)
mock_token_manager_class.return_value = mock_token_manager
mock_get_member.return_value = None # Not already a member
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
str(user_id)
)
@pytest.mark.asyncio
async def test_accept_invitation_no_email_anywhere_raises_error(
self, mock_invitation
):
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager to return no email
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
mock_token_manager_class.return_value = mock_token_manager
# Act & Assert
with pytest.raises(EmailMismatchError) as exc_info:
await OrgInvitationService.accept_invitation(token, user_id)
assert 'does not have an email address' in str(exc_info.value)
@pytest.mark.asyncio
async def test_accept_invitation_email_comparison_is_case_insensitive(
self, mock_invitation
):
"""Test that email comparison is case insensitive."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""
@pytest.fixture
def org_id(self):
"""Organization UUID for testing."""
return UUID('12345678-1234-5678-1234-567812345678')
@pytest.fixture
def inviter_id(self):
"""Inviter UUID for testing."""
return UUID('87654321-4321-8765-4321-876543218765')
@pytest.fixture
def mock_org(self):
"""Create a mock organization."""
org = MagicMock()
org.id = UUID('12345678-1234-5678-1234-567812345678')
org.name = 'Test Org'
return org
@pytest.fixture
def mock_inviter_member(self):
"""Create a mock inviter member with owner role."""
member = MagicMock()
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
member.role_id = 1
return member
@pytest.fixture
def mock_owner_role(self):
"""Create a mock owner role."""
role = MagicMock()
role.id = 1
role.name = 'owner'
return role
@pytest.fixture
def mock_member_role(self):
"""Create a mock member role."""
role = MagicMock()
role.id = 3
role.name = 'member'
return role
@pytest.mark.asyncio
async def test_batch_creates_all_invitations_successfully(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch creation succeeds for all valid emails."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
mock_invitation_1 = MagicMock(spec=OrgInvitation)
mock_invitation_1.id = 1
mock_invitation_2 = MagicMock(spec=OrgInvitation)
mock_invitation_2.id = 2
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation_1, mock_invitation_2],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 2
assert len(failed) == 0
@pytest.mark.asyncio
async def test_batch_handles_partial_success(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch returns partial results when some emails fail."""
# Arrange
from server.routes.org_invitation_models import UserAlreadyMemberError
emails = ['alice@example.com', 'existing@example.com']
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation, UserAlreadyMemberError()],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 1
assert len(failed) == 1
assert failed[0][0] == 'existing@example.com'
@pytest.mark.asyncio
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
"""Test that permission error fails the entire batch upfront."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
with patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=None, # Organization not found
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
assert 'not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_batch_fails_on_invalid_role(
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
):
"""Test that invalid role fails the entire batch."""
# Arrange
emails = ['alice@example.com']
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=None, # Invalid role
),
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='invalid_role',
inviter_id=inviter_id,
)
assert 'Invalid role' in str(exc_info.value)

View File

@@ -0,0 +1,308 @@
"""Tests for organization invitation store."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import (
INVITATION_TOKEN_LENGTH,
INVITATION_TOKEN_PREFIX,
OrgInvitationStore,
)
class TestGenerateToken:
"""Test cases for token generation."""
def test_generate_token_has_correct_prefix(self):
"""Test that generated tokens have the correct prefix."""
token = OrgInvitationStore.generate_token()
assert token.startswith(INVITATION_TOKEN_PREFIX)
def test_generate_token_has_correct_length(self):
"""Test that generated tokens have the correct total length."""
token = OrgInvitationStore.generate_token()
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
assert len(token) == expected_length
def test_generate_token_uses_alphanumeric_characters(self):
"""Test that generated tokens use only alphanumeric characters."""
token = OrgInvitationStore.generate_token()
# Remove prefix and check the rest is alphanumeric
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
assert random_part.isalnum()
def test_generate_token_is_unique(self):
"""Test that generated tokens are unique (probabilistically)."""
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
assert len(set(tokens)) == 100
class TestIsTokenExpired:
"""Test cases for token expiration checking."""
def test_token_not_expired_when_future(self):
"""Test that tokens with future expiration are not expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is False
def test_token_expired_when_past(self):
"""Test that tokens with past expiration are expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
def test_token_expired_at_exact_boundary(self):
"""Test that tokens at exact expiration time are expired."""
# A token that expires "now" should be expired
now = datetime.utcnow()
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = now - timedelta(microseconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
class TestCreateInvitation:
"""Test cases for invitation creation."""
@pytest.mark.asyncio
async def test_create_invitation_normalizes_email(self):
"""Test that email is normalized (lowercase, stripped) on creation."""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.execute = AsyncMock()
# Mock the result of the re-fetch query
mock_result = MagicMock()
mock_invitation = MagicMock()
mock_invitation.id = 1
mock_invitation.email = 'test@example.com'
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute.return_value = mock_result
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.create_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
role_id=1,
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
)
# Verify that the OrgInvitation was created with normalized email
add_call = mock_session.add.call_args
created_invitation = add_call[0][0]
assert created_invitation.email == 'test@example.com'
class TestGetInvitationByToken:
"""Test cases for getting invitation by token."""
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_invitation(self):
"""Test that get_invitation_by_token returns the invitation when found."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.token = 'inv-test-token-12345'
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-test-token-12345'
)
assert result == mock_invitation
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_none_when_not_found(self):
"""Test that get_invitation_by_token returns None when not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-nonexistent-token'
)
assert result is None
class TestGetPendingInvitation:
"""Test cases for getting pending invitation."""
@pytest.mark.asyncio
async def test_get_pending_invitation_normalizes_email(self):
"""Test that email is normalized when querying for pending invitations."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.get_pending_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
)
# Verify the query was called (email normalization happens in the filter)
assert mock_session.execute.called
class TestUpdateInvitationStatus:
"""Test cases for updating invitation status."""
@pytest.mark.asyncio
async def test_update_status_sets_accepted_at_for_accepted(self):
"""Test that accepted_at is set when status is accepted."""
from uuid import UUID
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
user_id = UUID('87654321-4321-8765-4321-876543218765')
await OrgInvitationStore.update_invitation_status(
invitation_id=1,
status=OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
assert mock_invitation.accepted_at is not None
assert mock_invitation.accepted_by_user_id == user_id
@pytest.mark.asyncio
async def test_update_status_returns_none_when_not_found(self):
"""Test that update returns None when invitation not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.update_invitation_status(
invitation_id=999,
status=OrgInvitation.STATUS_ACCEPTED,
)
assert result is None
class TestMarkExpiredIfNeeded:
"""Test cases for marking expired invitations."""
@pytest.mark.asyncio
async def test_marks_expired_when_pending_and_past_expiry(self):
"""Test that pending expired invitations are marked as expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is True
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
@pytest.mark.asyncio
async def test_does_not_mark_when_not_expired(self):
"""Test that non-expired invitations are not marked."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()
@pytest.mark.asyncio
async def test_does_not_mark_when_not_pending(self):
"""Test that non-pending invitations are not marked even if expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()

View File

@@ -0,0 +1,388 @@
"""Tests for organization invitations API router."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.routes.org_invitations import accept_router, invitation_router
@pytest.fixture
def app():
"""Create a FastAPI app with the invitation routers."""
app = FastAPI()
app.include_router(invitation_router)
app.include_router(accept_router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
class TestRouterPrefixes:
"""Test that router prefixes are configured correctly."""
def test_invitation_router_has_correct_prefix(self):
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
def test_accept_router_has_correct_prefix(self):
"""Test that accept_router has /api/organizations/members/invite prefix."""
assert accept_router.prefix == '/api/organizations/members/invite'
class TestAcceptInvitationEndpoint:
"""Test cases for the accept invitation endpoint."""
@pytest.fixture
def mock_user_auth(self):
"""Create a mock user auth."""
user_auth = MagicMock()
user_auth.get_user_id = AsyncMock(
return_value='87654321-4321-8765-4321-876543218765'
)
return user_auth
@pytest.mark.asyncio
async def test_accept_unauthenticated_redirects_to_login(self, client):
"""Test that unauthenticated users are redirected to login with invitation token."""
with patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=None,
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
'location', ''
)
@pytest.mark.asyncio
async def test_accept_authenticated_success_redirects_home(
self, client, mock_user_auth
):
"""Test that successful acceptance redirects to home page."""
mock_invitation = MagicMock()
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
return_value=mock_invitation,
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
location = response.headers.get('location', '')
assert location.endswith('/')
assert 'invitation_expired' not in location
assert 'invitation_invalid' not in location
assert 'email_mismatch' not in location
@pytest.mark.asyncio
async def test_accept_expired_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that expired invitation redirects with invitation_expired=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationExpiredError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_expired=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_invalid_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that invalid invitation redirects with invitation_invalid=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationInvalidError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_invalid=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_already_member_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that already member error redirects with already_member=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=UserAlreadyMemberError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'already_member=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_email_mismatch_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that email mismatch error redirects with email_mismatch=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=EmailMismatchError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'email_mismatch=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_unexpected_error_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that unexpected errors redirect with invitation_error=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=Exception('Unexpected error'),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_error=true' in response.headers.get('location', '')
class TestCreateInvitationBatchEndpoint:
"""Test cases for the batch invitation creation endpoint."""
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
# Override the get_user_id dependency
app.dependency_overrides[get_user_id] = (
lambda: '87654321-4321-8765-4321-876543218765'
)
return app
@pytest.fixture
def batch_client(self, batch_app):
"""Create a test client with dependency overrides."""
return TestClient(batch_app)
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation."""
from datetime import datetime
invitation = MagicMock()
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.role = MagicMock(name='member')
invitation.role.name = 'member'
invitation.role_id = 3
invitation.status = 'pending'
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
return invitation
@pytest.mark.asyncio
async def test_batch_create_returns_successful_invitations(
self, batch_client, mock_invitation
):
"""Test that batch creation returns successful invitations."""
mock_invitation_2 = MagicMock()
mock_invitation_2.id = 2
mock_invitation_2.email = 'bob@example.com'
mock_invitation_2.role = MagicMock()
mock_invitation_2.role.name = 'member'
mock_invitation_2.role_id = 3
mock_invitation_2.status = 'pending'
mock_invitation_2.created_at = mock_invitation.created_at
mock_invitation_2.expires_at = mock_invitation.expires_at
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation, mock_invitation_2], []),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'bob@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 2
assert len(data['failed']) == 0
@pytest.mark.asyncio
async def test_batch_create_returns_partial_success(
self, batch_client, mock_invitation
):
"""Test that batch creation returns both successful and failed invitations."""
failed_emails = [('existing@example.com', 'User is already a member')]
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation], failed_emails),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'existing@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 1
assert len(data['failed']) == 1
assert data['failed'][0]['email'] == 'existing@example.com'
assert 'already a member' in data['failed'][0]['error']
@pytest.mark.asyncio
async def test_batch_create_permission_denied_returns_403(self, batch_client):
"""Test that permission denied returns 403 for entire batch."""
from server.routes.org_invitation_models import InsufficientPermissionError
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=InsufficientPermissionError(
'Only owners and admins can invite'
),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'member'},
)
assert response.status_code == 403
assert 'owners and admins' in response.json()['detail']
@pytest.mark.asyncio
async def test_batch_create_invalid_role_returns_400(self, batch_client):
"""Test that invalid role returns 400."""
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=ValueError('Invalid role: superuser'),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'superuser'},
)
assert response.status_code == 400
assert 'Invalid role' in response.json()['detail']

View File

@@ -151,8 +151,9 @@ describe("LoginContent", () => {
await user.click(githubButton);
// Wait for async handleAuthRedirect to complete
// The URL includes state parameter added by handleAuthRedirect
await waitFor(() => {
expect(window.location.href).toBe(mockUrl);
expect(window.location.href).toContain(mockUrl);
});
});
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
});
it("should display invitation pending message when hasInvitation is true", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation
/>
</MemoryRouter>,
);
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
it("should not display invitation pending message when hasInvitation is false", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should call buildOAuthStateData when clicking auth button", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
expect(mockBuildOAuthStateData).toHaveBeenCalled();
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
expect(callArg).toHaveProperty("redirect_url");
});
});
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
const redirectUrl = window.location.href;
// The URL should contain an encoded state parameter
expect(redirectUrl).toContain("state=");
// Decode and verify the state contains invitation_token
const url = new URL(redirectUrl);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
});
});

View File

@@ -0,0 +1,170 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
// Mock setSearchParams function
const mockSetSearchParams = vi.fn();
// Default mock searchParams
let mockSearchParamsData: Record<string, string> = {};
// Mock react-router
vi.mock("react-router", () => ({
useSearchParams: () => [
{
get: (key: string) => mockSearchParamsData[key] || null,
has: (key: string) => key in mockSearchParamsData,
},
mockSetSearchParams,
],
}));
// Import after mocking
import { useInvitation } from "#/hooks/use-invitation";
describe("useInvitation", () => {
beforeEach(() => {
// Clear localStorage before each test
localStorage.clear();
// Reset mock data
mockSearchParamsData = {};
mockSetSearchParams.mockClear();
});
afterEach(() => {
vi.clearAllMocks();
});
describe("initialization", () => {
it("should initialize with null token when localStorage is empty", () => {
// Arrange - localStorage is empty (cleared in beforeEach)
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
it("should initialize with token from localStorage if present", () => {
// Arrange
const storedToken = "inv-stored-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBe(storedToken);
expect(result.current.hasInvitation).toBe(true);
});
});
describe("URL token capture", () => {
it("should capture invitation_token from URL and store in localStorage", () => {
// Arrange
const urlToken = "inv-url-token-67890";
mockSearchParamsData = { invitation_token: urlToken };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
expect(mockSetSearchParams).toHaveBeenCalled();
});
});
describe("completion cleanup", () => {
it("should clear localStorage when email_mismatch param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { email_mismatch: "true" };
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(mockSetSearchParams).toHaveBeenCalled();
});
it("should clear localStorage when invitation_success param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { invitation_success: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
it("should clear localStorage when invitation_expired param is present", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
mockSearchParamsData = { invitation_expired: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
});
describe("buildOAuthStateData", () => {
it("should include invitation_token in OAuth state when token is present", () => {
// Arrange
const token = "inv-oauth-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, token);
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBe(token);
expect(stateData.redirect_url).toBe("/dashboard");
});
it("should not include invitation_token when no token is present", () => {
// Arrange - no token in localStorage
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBeUndefined();
expect(stateData.redirect_url).toBe("/dashboard");
});
});
describe("clearInvitation", () => {
it("should remove token from localStorage when called", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
const { result } = renderHook(() => useInvitation());
// Act
act(() => {
result.current.clearInvitation();
});
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
});
});

View File

@@ -57,6 +57,22 @@ vi.mock("#/hooks/use-tracking", () => ({
}),
}));
const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({
useInvitationMock: vi.fn(() => ({
invitationToken: null as string | null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
})),
buildOAuthStateDataMock: vi.fn(
(baseState: Record<string, string>) => baseState,
),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => useInvitationMock(),
}));
const RouterStub = createRoutesStub([
{
Component: LoginPage,
@@ -234,7 +250,8 @@ describe("LoginPage", () => {
});
await user.click(githubButton);
expect(window.location.href).toBe(mockUrl);
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(mockUrl);
});
it("should redirect to GitLab auth URL when GitLab button is clicked", async () => {
@@ -255,7 +272,8 @@ describe("LoginPage", () => {
});
await user.click(gitlabButton);
expect(window.location.href).toBe("https://gitlab.com/oauth/authorize");
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain("https://gitlab.com/oauth/authorize");
});
it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => {
@@ -282,7 +300,8 @@ describe("LoginPage", () => {
});
await user.click(bitbucketButton);
expect(window.location.href).toBe(
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(
"https://bitbucket.org/site/oauth2/authorize",
);
});
@@ -479,4 +498,137 @@ describe("LoginPage", () => {
});
});
});
describe("Invitation Flow", () => {
it("should display invitation pending message when hasInvitation is true", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
it("should not display invitation pending message when hasInvitation is false", async () => {
useInvitationMock.mockReturnValue({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByTestId("login-content")).toBeInTheDocument();
});
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// buildOAuthStateData should have been called during the OAuth redirect
expect(mockBuildOAuthStateData).toHaveBeenCalled();
});
it("should include invitation token in OAuth state when invitation is present", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// Verify the redirect URL contains the state with invitation token
await waitFor(() => {
expect(window.location.href).toContain("state=");
});
// Decode and verify the state contains invitation_token
const url = new URL(window.location.href);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
it("should handle login with invitation_token URL parameter", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-url-token-67890",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login?invitation_token=inv-url-token-67890"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
});
});

View File

@@ -42,6 +42,15 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => ({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
}),
}));
function LoginStub() {
const [searchParams] = useSearchParams();
const emailVerificationRequired =
@@ -353,4 +362,68 @@ describe("MainApp", () => {
);
});
});
describe("Invitation URL Parameters", () => {
beforeEach(() => {
vi.spyOn(AuthService, "authenticate").mockRejectedValue({
response: { status: 401 },
isAxiosError: true,
});
});
it("should redirect to login when email_mismatch=true is in query params", async () => {
renderMainApp(["/?email_mismatch=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_success=true is in query params", async () => {
renderMainApp(["/?invitation_success=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_expired=true is in query params", async () => {
renderMainApp(["/?invitation_expired=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_invalid=true is in query params", async () => {
renderMainApp(["/?invitation_invalid=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when already_member=true is in query params", async () => {
renderMainApp(["/?already_member=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
});
});

View File

@@ -21,6 +21,10 @@ export interface LoginContentProps {
emailVerified?: boolean;
hasDuplicatedEmail?: boolean;
recaptchaBlocked?: boolean;
hasInvitation?: boolean;
buildOAuthStateData?: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
export function LoginContent({
@@ -31,6 +35,8 @@ export function LoginContent({
emailVerified = false,
hasDuplicatedEmail = false,
recaptchaBlocked = false,
hasInvitation = false,
buildOAuthStateData,
}: LoginContentProps) {
const { t } = useTranslation();
const { trackLoginButtonClick } = useTracking();
@@ -59,31 +65,36 @@ export function LoginContent({
) => {
trackLoginButtonClick({ provider });
if (!config?.recaptcha_site_key || !recaptchaReady) {
// No reCAPTCHA or token generation failed - redirect normally
window.location.href = redirectUrl;
return;
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Build base state data
let stateData: Record<string, string> = {
redirect_url: currentState,
};
// Add invitation token if present
if (buildOAuthStateData) {
stateData = buildOAuthStateData(stateData);
}
// If reCAPTCHA is configured, encode token in OAuth state
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Encode state with reCAPTCHA token for backend verification
const stateData = {
redirect_url: currentState,
recaptcha_token: token,
};
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
// If reCAPTCHA is configured, add token to state
if (config?.recaptcha_site_key && recaptchaReady) {
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
stateData.recaptcha_token = token;
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
return;
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
}
// Encode state and redirect
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
};
const handleGitHubAuth = () => {
@@ -123,6 +134,10 @@ export function LoginContent({
const buttonBaseClasses =
"w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed";
const buttonLabelClasses = "text-sm font-medium leading-5 px-1";
const shouldShownHelperText =
emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation;
return (
<div
className="flex flex-col items-center w-full gap-12.5"
@@ -136,20 +151,29 @@ export function LoginContent({
{t(I18nKey.AUTH$LETS_GET_STARTED)}
</h1>
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
{shouldShownHelperText && (
<div className="flex flex-col items-center gap-3">
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
)}
{hasInvitation && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$INVITATION_PENDING)}
</p>
)}
</div>
)}
<div className="flex flex-col items-center gap-3">

View File

@@ -0,0 +1,119 @@
import React from "react";
import { useSearchParams } from "react-router";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
interface UseInvitationReturn {
/** The invitation token, if present */
invitationToken: string | null;
/** Whether there is an active invitation */
hasInvitation: boolean;
/** Clear the stored invitation token */
clearInvitation: () => void;
/** Build OAuth state data including invitation token if present */
buildOAuthStateData: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
/**
* Hook to manage organization invitation tokens during the login flow.
*
* This hook:
* 1. Reads invitation_token from URL query params on mount
* 2. Persists the token in localStorage (survives page refresh and works across tabs)
* 3. Provides the token for inclusion in OAuth state
* 4. Provides cleanup method after successful authentication
*
* The invitation token flow:
* 1. User clicks invitation link → /api/invitations/accept?token=xxx
* 2. Backend redirects to /login?invitation_token=xxx
* 3. This hook captures token and stores in localStorage
* 4. When user clicks login button, token is included in OAuth state
* 5. After auth callback processes invitation, frontend clears the token
*
* Note: localStorage is used instead of sessionStorage to support scenarios where
* the user opens the email verification link in a new tab/browser window.
*/
export function useInvitation(): UseInvitationReturn {
const [searchParams, setSearchParams] = useSearchParams();
const [invitationToken, setInvitationToken] = React.useState<string | null>(
() => {
// Initialize from localStorage (persists across tabs and page refreshes)
if (typeof window !== "undefined") {
return localStorage.getItem(INVITATION_TOKEN_KEY);
}
return null;
},
);
// Capture invitation token from URL and persist to localStorage
// This only runs on the login page where the hook is used
React.useEffect(() => {
const tokenFromUrl = searchParams.get("invitation_token");
if (tokenFromUrl) {
// Store in localStorage for persistence across tabs and refreshes
localStorage.setItem(INVITATION_TOKEN_KEY, tokenFromUrl);
setInvitationToken(tokenFromUrl);
// Remove token from URL to clean up (prevents token exposure in browser history)
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_token");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
// Clear invitation token when invitation flow completes (success or failure)
// These query params are set by the backend after processing the invitation
React.useEffect(() => {
const invitationCompleted =
searchParams.has("invitation_success") ||
searchParams.has("invitation_expired") ||
searchParams.has("invitation_invalid") ||
searchParams.has("invitation_error") ||
searchParams.has("already_member") ||
searchParams.has("email_mismatch");
if (invitationCompleted) {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
// Remove invitation params from URL to clean up
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_success");
newSearchParams.delete("invitation_expired");
newSearchParams.delete("invitation_invalid");
newSearchParams.delete("invitation_error");
newSearchParams.delete("already_member");
newSearchParams.delete("email_mismatch");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
const clearInvitation = React.useCallback(() => {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
}, []);
const buildOAuthStateData = React.useCallback(
(baseStateData: Record<string, string>): Record<string, string> => {
const stateData = { ...baseStateData };
// Include invitation token in state if present
if (invitationToken) {
stateData.invitation_token = invitationToken;
}
return stateData;
},
[invitationToken],
);
return {
invitationToken,
hasInvitation: invitationToken !== null,
clearInvitation,
buildOAuthStateData,
};
}

View File

@@ -763,6 +763,7 @@ export enum I18nKey {
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED",
AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED",
AUTH$INVITATION_PENDING = "AUTH$INVITATION_PENDING",
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
COMMON$AND = "COMMON$AND",
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",

View File

@@ -12207,6 +12207,22 @@
"de": "Lass uns anfangen",
"uk": "Почнімо"
},
"AUTH$INVITATION_PENDING": {
"en": "Sign in to accept your organization invitation",
"ja": "組織への招待を受け入れるにはサインインしてください",
"zh-CN": "登录以接受您的组织邀请",
"zh-TW": "登入以接受您的組織邀請",
"ko-KR": "조직 초대를 수락하려면 로그인하세요",
"no": "Logg inn for å godta organisasjonsinvitasjonen din",
"it": "Accedi per accettare l'invito della tua organizzazione",
"pt": "Faça login para aceitar o convite da sua organização",
"es": "Inicia sesión para aceptar la invitación de tu organización",
"ar": "سجّل الدخول لقبول دعوة مؤسستك",
"fr": "Connectez-vous pour accepter l'invitation de votre organisation",
"tr": "Organizasyon davetinizi kabul etmek için giriş yapın",
"de": "Melden Sie sich an, um Ihre Organisationseinladung anzunehmen",
"uk": "Увійдіть, щоб прийняти запрошення до організації"
},
"COMMON$TERMS_OF_SERVICE": {
"en": "Terms of Service",
"ja": "利用規約",

View File

@@ -10,6 +10,7 @@ import "./tailwind.css";
import "./index.css";
import React from "react";
import { Toaster } from "react-hot-toast";
import { useInvitation } from "#/hooks/use-invitation";
export function Layout({ children }: { children: React.ReactNode }) {
return (
@@ -37,5 +38,9 @@ export const meta: MetaFunction = () => [
];
export default function App() {
// Handle invitation token cleanup when invitation flow completes
// This runs on all pages to catch redirects from auth callback
useInvitation();
return <Outlet />;
}

View File

@@ -4,6 +4,7 @@ import { useIsAuthed } from "#/hooks/query/use-is-authed";
import { useConfig } from "#/hooks/query/use-config";
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
import { useEmailVerification } from "#/hooks/use-email-verification";
import { useInvitation } from "#/hooks/use-invitation";
import { LoginContent } from "#/components/features/auth/login-content";
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
@@ -23,6 +24,8 @@ export default function LoginPage() {
userId,
} = useEmailVerification();
const { hasInvitation, buildOAuthStateData } = useInvitation();
const gitHubAuthUrl = useGitHubAuthUrl({
appMode: config.data?.app_mode || null,
authUrl: config.data?.auth_url,
@@ -69,6 +72,8 @@ export default function LoginPage() {
emailVerified={emailVerified}
hasDuplicatedEmail={hasDuplicatedEmail}
recaptchaBlocked={recaptchaBlocked}
hasInvitation={hasInvitation}
buildOAuthStateData={buildOAuthStateData}
/>
</main>