feat: add user invitation logic (#12883)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Hiep Le
2026-02-18 13:24:19 +07:00
committed by GitHub
parent b18568da0b
commit 4d6f66ca28
28 changed files with 3666 additions and 53 deletions

View File

@@ -0,0 +1,110 @@
"""create org_invitation table
Revision ID: 094
Revises: 093
Create Date: 2026-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '094'
down_revision: Union[str, None] = '093'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create org_invitation table
op.create_table(
'org_invitation',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('token', sa.String(64), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
'status',
sa.String(20),
nullable=False,
server_default=sa.text("'pending'"),
),
sa.Column(
'created_at',
sa.DateTime,
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
# Foreign key constraints
sa.ForeignKeyConstraint(
['org_id'],
['org.id'],
name='org_invitation_org_fkey',
ondelete='CASCADE',
),
sa.ForeignKeyConstraint(
['role_id'],
['role.id'],
name='org_invitation_role_fkey',
),
sa.ForeignKeyConstraint(
['inviter_id'],
['user.id'],
name='org_invitation_inviter_fkey',
),
sa.ForeignKeyConstraint(
['accepted_by_user_id'],
['user.id'],
name='org_invitation_accepter_fkey',
),
)
# Create indexes
op.create_index(
'ix_org_invitation_token',
'org_invitation',
['token'],
unique=True,
)
op.create_index(
'ix_org_invitation_org_id',
'org_invitation',
['org_id'],
)
op.create_index(
'ix_org_invitation_email',
'org_invitation',
['email'],
)
op.create_index(
'ix_org_invitation_status',
'org_invitation',
['status'],
)
# Composite index for checking pending invitations
op.create_index(
'ix_org_invitation_org_email_status',
'org_invitation',
['org_id', 'email', 'status'],
)
def downgrade() -> None:
# Drop indexes
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
# Drop table
op.drop_table('org_invitation')

View File

@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402 from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402 from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402 from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.org_invitations import ( # noqa: E402
accept_router as invitation_accept_router,
)
from server.routes.org_invitations import ( # noqa: E402
invitation_router,
)
from server.routes.orgs import org_router # noqa: E402 from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402
@@ -99,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(api_keys_router) # Add routes for API key management base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app) add_github_proxy_routes(base_app)
add_debugging_routes( add_debugging_routes(
base_app base_app

View File

@@ -160,6 +160,7 @@ class SetAuthCookieMiddleware:
'/api/billing/customer-setup-success', '/api/billing/customer-setup-success',
'/api/billing/stripe-webhook', '/api/billing/stripe-webhook',
'/api/email/resend', '/api/email/resend',
'/api/organizations/members/invite/accept',
'/oauth/device/authorize', '/oauth/device/authorize',
'/oauth/device/token', '/oauth/device/token',
'/api/v1/web-client/config', '/api/v1/web-client/config',

View File

@@ -5,6 +5,7 @@ import warnings
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Annotated, Literal, Optional from typing import Annotated, Literal, Optional
from urllib.parse import quote from urllib.parse import quote
from uuid import UUID as parse_uuid
import posthog import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status from fastapi import APIRouter, Header, HTTPException, Request, Response, status
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
from server.config import sign_token from server.config import sign_token
from server.constants import IS_FEATURE_ENV from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker from storage.database import session_maker
from storage.user import User from storage.user import User
from storage.user_store import UserStore from storage.user_store import UserStore
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
) )
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
Returns:
Tuple of (redirect_url, recaptcha_token, invitation_token).
Tokens may be None.
"""
if not state:
return '', None, None
try:
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return (
state_data.get('redirect_url', ''),
state_data.get('recaptcha_token'),
state_data.get('invitation_token'),
)
except Exception:
# Old format - state is just the redirect URL
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]: def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state. """Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns: Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None. Tuple of (redirect_url, recaptcha_token). Token may be None.
""" """
if not state: redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return '', None return redirect_url, recaptcha_token
try:
# Try to decode as JSON (new format with reCAPTCHA)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
except Exception:
# Old format - state is just the redirect URL
return state, None
@oauth_router.get('/keycloak/callback') @oauth_router.get('/keycloak/callback')
@@ -130,8 +156,8 @@ async def keycloak_callback(
error: Optional[str] = None, error: Optional[str] = None,
error_description: Optional[str] = None, error_description: Optional[str] = None,
): ):
# Extract redirect URL and reCAPTCHA token from state # Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token = _extract_recaptcha_state(state) redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
if not redirect_url: if not redirect_url:
redirect_url = str(request.base_url) redirect_url = str(request.base_url)
@@ -302,8 +328,13 @@ async def keycloak_callback(
from server.routes.email import verify_email from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True) await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}' verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
response = RedirectResponse(redirect_url, status_code=302) # Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
f'{verification_redirect_url}&invitation_token={invitation_token}'
)
response = RedirectResponse(verification_redirect_url, status_code=302)
return response return response
# default to github IDP for now. # default to github IDP for now.
@@ -381,14 +412,90 @@ async def keycloak_callback(
) )
has_accepted_tos = user.accepted_tos is not None has_accepted_tos = user.accepted_tos is not None
# Process invitation token if present (after email verification but before TOS)
if invitation_token:
try:
logger.info(
'Processing invitation token during auth callback',
extra={
'user_id': user_id,
'invitation_token_prefix': invitation_token[:10] + '...',
},
)
await OrgInvitationService.accept_invitation(
invitation_token, parse_uuid(user_id)
)
logger.info(
'Invitation accepted during auth callback',
extra={'user_id': user_id},
)
except InvitationExpiredError:
logger.warning(
'Invitation expired during auth callback',
extra={'user_id': user_id},
)
# Add query param to redirect URL
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
except InvitationInvalidError as e:
logger.warning(
'Invalid invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
except UserAlreadyMemberError:
logger.info(
'User already member during invitation acceptance',
extra={'user_id': user_id},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
except EmailMismatchError as e:
logger.warning(
'Email mismatch during auth callback invitation acceptance',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
except Exception as e:
logger.exception(
'Unexpected error processing invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
# Don't fail the login if invitation processing fails
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_error=true'
else:
redirect_url = f'{redirect_url}?invitation_error=true'
# If the user hasn't accepted the TOS, redirect to the TOS page # If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos: if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='') encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = ( tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}' f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
) )
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302) response = RedirectResponse(tos_redirect_url, status_code=302)
else: else:
if invitation_token:
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302) response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie( set_response_cookie(

View File

@@ -0,0 +1,122 @@
"""
Pydantic models and custom exceptions for organization invitations.
"""
from pydantic import BaseModel, EmailStr
from storage.org_invitation import OrgInvitation
from storage.role_store import RoleStore
class InvitationError(Exception):
"""Base exception for invitation errors."""
pass
class InvitationAlreadyExistsError(InvitationError):
"""Raised when a pending invitation already exists for the email."""
def __init__(
self, message: str = 'A pending invitation already exists for this email'
):
super().__init__(message)
class UserAlreadyMemberError(InvitationError):
"""Raised when the user is already a member of the organization."""
def __init__(self, message: str = 'User is already a member of this organization'):
super().__init__(message)
class InvitationExpiredError(InvitationError):
"""Raised when the invitation has expired."""
def __init__(self, message: str = 'Invitation has expired'):
super().__init__(message)
class InvitationInvalidError(InvitationError):
"""Raised when the invitation is invalid or revoked."""
def __init__(self, message: str = 'Invitation is no longer valid'):
super().__init__(message)
class InsufficientPermissionError(InvitationError):
"""Raised when the user lacks permission to perform the action."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class EmailMismatchError(InvitationError):
"""Raised when the accepting user's email doesn't match the invitation email."""
def __init__(self, message: str = 'Your email does not match the invitation'):
super().__init__(message)
class InvitationCreate(BaseModel):
"""Request model for creating invitation(s)."""
emails: list[EmailStr]
role: str = 'member' # Default to member role
class InvitationResponse(BaseModel):
"""Response model for invitation details."""
id: int
email: str
role: str
status: str
created_at: str
expires_at: str
inviter_email: str | None = None
@classmethod
def from_invitation(
cls,
invitation: OrgInvitation,
inviter_email: str | None = None,
) -> 'InvitationResponse':
"""Create an InvitationResponse from an OrgInvitation entity.
Args:
invitation: The invitation entity to convert
inviter_email: Optional email of the inviter
Returns:
InvitationResponse: The response model instance
"""
role_name = ''
if invitation.role:
role_name = invitation.role.name
elif invitation.role_id:
role = RoleStore.get_role_by_id(invitation.role_id)
role_name = role.name if role else ''
return cls(
id=invitation.id,
email=invitation.email,
role=role_name,
status=invitation.status,
created_at=invitation.created_at.isoformat(),
expires_at=invitation.expires_at.isoformat(),
inviter_email=inviter_email,
)
class InvitationFailure(BaseModel):
"""Response model for a failed invitation."""
email: str
error: str
class BatchInvitationResponse(BaseModel):
"""Response model for batch invitation creation."""
successful: list[InvitationResponse]
failed: list[InvitationFailure]

View File

@@ -0,0 +1,226 @@
"""API routes for organization invitations."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from server.routes.org_invitation_models import (
BatchInvitationResponse,
EmailMismatchError,
InsufficientPermissionError,
InvitationCreate,
InvitationExpiredError,
InvitationFailure,
InvitationInvalidError,
InvitationResponse,
UserAlreadyMemberError,
)
from server.services.org_invitation_service import OrgInvitationService
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
# Router for accepting invitations (no org_id required)
accept_router = APIRouter(prefix='/api/organizations/members/invite')
@invitation_router.post(
'/invite',
response_model=BatchInvitationResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_invitation(
org_id: UUID,
invitation_data: InvitationCreate,
request: Request,
user_id: str = Depends(get_user_id),
):
"""Create organization invitations for multiple email addresses.
Sends emails to invitees with secure links to join the organization.
Supports batch invitations - some may succeed while others fail.
Permission rules:
- Only owners and admins can create invitations
- Admins can only invite with 'member' or 'admin' role (not 'owner')
- Owners can invite with any role
Args:
org_id: Organization UUID
invitation_data: Invitation details (emails array, role)
request: FastAPI request
user_id: Authenticated user ID (from dependency)
Returns:
BatchInvitationResponse: Lists of successful and failed invitations
Raises:
HTTPException 400: Invalid role or organization not found
HTTPException 403: User lacks permission to invite
HTTPException 429: Rate limit exceeded
"""
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
await check_rate_limit_by_user_id(
request=request,
key_prefix='org_invitation_create',
user_id=user_id,
user_rate_limit_seconds=6,
)
try:
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=[str(email) for email in invitation_data.emails],
role_name=invitation_data.role,
inviter_id=UUID(user_id),
)
logger.info(
'Batch organization invitations created',
extra={
'org_id': str(org_id),
'total_emails': len(invitation_data.emails),
'successful': len(successful),
'failed': len(failed),
'inviter_id': user_id,
},
)
return BatchInvitationResponse(
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
failed=[
InvitationFailure(email=email, error=error) for email, error in failed
],
)
except InsufficientPermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error creating batch invitations',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@accept_router.get('/accept')
async def accept_invitation(
token: str,
request: Request,
):
"""Accept an organization invitation via token.
This endpoint is accessed via the link in the invitation email.
Flow:
1. If user is authenticated: Accept invitation directly and redirect to home
2. If user is not authenticated: Redirect to login page with invitation token
- Frontend stores token and includes it in OAuth state during login
- After authentication, keycloak_callback processes the invitation
Args:
token: The invitation token from the email link
request: FastAPI request
Returns:
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
or home page with error query params on failure
"""
base_url = str(request.base_url).rstrip('/')
# Try to get user_id from auth (may not be authenticated)
user_id = None
try:
user_auth = await get_user_auth(request)
if user_auth:
user_id = await user_auth.get_user_id()
except Exception:
pass
if not user_id:
# User not authenticated - redirect to login page with invitation token
# Frontend will store the token and include it in OAuth state during login
logger.info(
'Invitation accept: redirecting unauthenticated user to login',
extra={'token_prefix': token[:10] + '...'},
)
login_url = f'{base_url}/login?invitation_token={token}'
return RedirectResponse(login_url, status_code=302)
# User is authenticated - process the invitation directly
try:
await OrgInvitationService.accept_invitation(token, UUID(user_id))
logger.info(
'Invitation accepted successfully',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
},
)
# Redirect to home page on success
return RedirectResponse(f'{base_url}/', status_code=302)
except InvitationExpiredError:
logger.warning(
'Invitation accept failed: expired',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
except InvitationInvalidError as e:
logger.warning(
'Invitation accept failed: invalid',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
except UserAlreadyMemberError:
logger.info(
'Invitation accept: user already member',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
except EmailMismatchError as e:
logger.warning(
'Invitation accept failed: email mismatch',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
except Exception as e:
logger.exception(
'Unexpected error accepting invitation',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)

View File

@@ -0,0 +1,131 @@
"""Email service for sending transactional emails via Resend."""
import os
try:
import resend
RESEND_AVAILABLE = True
except ImportError:
RESEND_AVAILABLE = False
from openhands.core.logger import openhands_logger as logger
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
class EmailService:
"""Service for sending transactional emails."""
@staticmethod
def _get_resend_client() -> bool:
"""Initialize and return the Resend client.
Returns:
bool: True if client is ready, False otherwise
"""
if not RESEND_AVAILABLE:
logger.warning('Resend library not installed, skipping email')
return False
resend_api_key = os.environ.get('RESEND_API_KEY')
if not resend_api_key:
logger.warning('RESEND_API_KEY not configured, skipping email')
return False
resend.api_key = resend_api_key
return True
@staticmethod
def send_invitation_email(
to_email: str,
org_name: str,
inviter_name: str,
role_name: str,
invitation_token: str,
invitation_id: int,
) -> None:
"""Send an organization invitation email.
Args:
to_email: Recipient's email address
org_name: Name of the organization
inviter_name: Display name of the person who sent the invite
role_name: Role being offered (e.g., 'member', 'admin')
invitation_token: The secure invitation token
invitation_id: The invitation ID for logging
"""
if not EmailService._get_resend_client():
return
# Build invitation URL
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
params = {
'from': from_email,
'to': [to_email],
'subject': f"You're invited to join {org_name} on OpenHands",
'html': f"""
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
<p>Hi,</p>
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
<p>Click the button below to accept the invitation:</p>
<p style="margin: 30px 0;">
<a href="{invitation_url}"
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
text-decoration: none; border-radius: 8px; display: inline-block;
font-size: 14px; font-weight: 600;">
Accept Invitation
</a>
</p>
<p style="color: #666; font-size: 14px;">
Or copy and paste this link into your browser:<br>
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
</p>
<p style="color: #666; font-size: 14px;">
This invitation will expire in 7 days.
</p>
<p style="color: #666; font-size: 14px;">
If you weren't expecting this invitation, you can safely ignore this email.
</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">
Best,<br>
The OpenHands Team
</p>
</div>
""",
}
try:
response = resend.Emails.send(params)
logger.info(
'Invitation email sent',
extra={
'invitation_id': invitation_id,
'email': to_email,
'response_id': response.get('id') if response else None,
},
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation_id,
'email': to_email,
'error': str(e),
},
)
raise

View File

@@ -0,0 +1,397 @@
"""Service for managing organization invitations."""
import asyncio
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import ROLE_ADMIN, ROLE_OWNER
from server.routes.org_invitation_models import (
EmailMismatchError,
InsufficientPermissionError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.services.email_service import EmailService
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import OrgInvitationStore
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgInvitationService:
"""Service for organization invitation operations."""
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_name: str,
inviter_id: UUID,
) -> OrgInvitation:
"""Create a new organization invitation.
This method:
1. Validates the organization exists
2. Validates this is not a personal workspace
3. Checks inviter has owner/admin role
4. Validates role assignment permissions
5. Checks if user is already a member
6. Creates the invitation
7. Sends the invitation email
Args:
org_id: Organization UUID
email: Invitee's email address
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitation
Returns:
OrgInvitation: The created invitation
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
UserAlreadyMemberError: If email is already a member
InvitationAlreadyExistsError: If pending invitation exists
"""
email = email.lower().strip()
logger.info(
'Creating organization invitation',
extra={
'org_id': str(org_id),
'email': email,
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
# Step 2: Check this is not a personal workspace
# A personal workspace has org_id matching the user's id
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
# Step 3: Check inviter is a member and has permission
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
# Step 4: Validate role assignment permissions
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
# Get the target role
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
if existing_user:
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
if existing_member:
raise UserAlreadyMemberError(
'User is already a member of this organization'
)
# Step 6: Create the invitation
invitation = await OrgInvitationStore.create_invitation(
org_id=org_id,
email=email,
role_id=target_role.id,
inviter_id=inviter_id,
)
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
EmailService.send_invitation_email(
to_email=email,
org_name=org.name,
inviter_name=inviter_name,
role_name=target_role.name,
invitation_token=invitation.token,
invitation_id=invitation.id,
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation.id,
'email': email,
'error': str(e),
},
)
# Don't fail the invitation creation if email fails
# The user can still access via direct link
return invitation
@staticmethod
async def create_invitations_batch(
org_id: UUID,
emails: list[str],
role_name: str,
inviter_id: UUID,
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
"""Create multiple organization invitations concurrently.
Validates permissions once upfront, then creates invitations in parallel.
Args:
org_id: Organization UUID
emails: List of invitee email addresses
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitations
Returns:
Tuple of (successful_invitations, failed_emails_with_errors)
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
"""
logger.info(
'Creating batch organization invitations',
extra={
'org_id': str(org_id),
'email_count': len(emails),
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate permissions upfront (shared for all emails)
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 2: Create invitations concurrently
async def create_single(
email: str,
) -> tuple[str, OrgInvitation | None, str | None]:
"""Create single invitation, return (email, invitation, error)."""
try:
invitation = await OrgInvitationService.create_invitation(
org_id=org_id,
email=email,
role_name=role_name,
inviter_id=inviter_id,
)
return (email, invitation, None)
except (UserAlreadyMemberError, ValueError) as e:
return (email, None, str(e))
results = await asyncio.gather(*[create_single(email) for email in emails])
# Step 3: Separate successes and failures
successful: list[OrgInvitation] = []
failed: list[tuple[str, str]] = []
for email, invitation, error in results:
if invitation:
successful.append(invitation)
elif error:
failed.append((email, error))
logger.info(
'Batch invitation creation completed',
extra={
'org_id': str(org_id),
'successful': len(successful),
'failed': len(failed),
},
)
return successful, failed
@staticmethod
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
"""Accept an organization invitation.
This method:
1. Validates the token and invitation status
2. Checks expiration
3. Verifies user is not already a member
4. Creates LiteLLM integration
5. Adds user to the organization
6. Marks invitation as accepted
Args:
token: The invitation token
user_id: The user accepting the invitation
Returns:
OrgInvitation: The accepted invitation
Raises:
InvitationInvalidError: If token is invalid or invitation not pending
InvitationExpiredError: If invitation has expired
UserAlreadyMemberError: If user is already a member
"""
logger.info(
'Accepting organization invitation',
extra={
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
'user_id': str(user_id),
},
)
# Step 1: Get and validate invitation
invitation = await OrgInvitationStore.get_invitation_by_token(token)
if not invitation:
raise InvitationInvalidError('Invalid invitation token')
if invitation.status != OrgInvitation.STATUS_PENDING:
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
raise InvitationInvalidError('Invitation has already been accepted')
elif invitation.status == OrgInvitation.STATUS_REVOKED:
raise InvitationInvalidError('Invitation has been revoked')
else:
raise InvitationInvalidError('Invitation is no longer valid')
# Step 2: Check expiration
if OrgInvitationStore.is_token_expired(invitation):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
if not user:
raise InvitationInvalidError('User not found')
user_email = user.email
# Fallback: fetch email from Keycloak if not in database (for existing users)
if not user_email:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
user_email = user_info.get('email') if user_info else None
if not user_email:
raise EmailMismatchError('Your account does not have an email address')
user_email = user_email.lower().strip()
invitation_email = invitation.email.lower().strip()
if user_email != invitation_email:
logger.warning(
'Email mismatch during invitation acceptance',
extra={
'user_id': str(user_id),
'user_email': user_email,
'invitation_email': invitation_email,
'invitation_id': invitation.id,
},
)
raise EmailMismatchError()
# Step 3: Check if user is already a member
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
if existing_member:
raise UserAlreadyMemberError(
'You are already a member of this organization'
)
# Step 4: Create LiteLLM integration for the user in the new org
try:
settings = await OrgService.create_litellm_integration(
invitation.org_id, str(user_id)
)
except Exception as e:
logger.error(
'Failed to create LiteLLM integration for invitation acceptance',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'error': str(e),
},
)
raise InvitationInvalidError(
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
OrgMemberStore.add_user_to_org(
org_id=invitation.org_id,
user_id=user_id,
role_id=invitation.role_id,
llm_api_key=settings.llm_api_key,
status='active',
)
# Step 6: Mark invitation as accepted
updated_invitation = await OrgInvitationStore.update_invitation_status(
invitation.id,
OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
logger.info(
'Organization invitation accepted',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'role_id': invitation.role_id,
},
)
return updated_invitation

View File

@@ -20,6 +20,7 @@ from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR from storage.openhands_pr import OpenhandsPR
from storage.org import Org from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation from storage.proactive_convos import ProactiveConversation
from storage.role import Role from storage.role import Role
@@ -65,6 +66,7 @@ __all__ = [
'MaintenanceTaskStatus', 'MaintenanceTaskStatus',
'OpenhandsPR', 'OpenhandsPR',
'Org', 'Org',
'OrgInvitation',
'OrgMember', 'OrgMember',
'ProactiveConversation', 'ProactiveConversation',
'Role', 'Role',

View File

@@ -52,6 +52,7 @@ class Org(Base): # type: ignore
# Relationships # Relationships
org_members = relationship('OrgMember', back_populates='org') org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org') current_users = relationship('User', back_populates='current_org')
invitations = relationship('OrgInvitation', back_populates='org')
billing_sessions = relationship('BillingSession', back_populates='org') billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship( stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org' 'StoredConversationMetadataSaas', back_populates='org'

View File

@@ -0,0 +1,59 @@
"""
SQLAlchemy model for Organization Invitation.
"""
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.orm import relationship
from storage.base import Base
class OrgInvitation(Base): # type: ignore
"""Organization invitation model.
Represents an invitation for a user to join an organization.
Invitations are created by organization owners/admins and contain
a secure token that can be used to accept the invitation.
"""
__tablename__ = 'org_invitation'
id = Column(Integer, primary_key=True, autoincrement=True)
token = Column(String(64), nullable=False, unique=True, index=True)
org_id = Column(
UUID(as_uuid=True),
ForeignKey('org.id', ondelete='CASCADE'),
nullable=False,
index=True,
)
email = Column(String(255), nullable=False, index=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
status = Column(
String(20),
nullable=False,
server_default=text("'pending'"),
)
created_at = Column(
DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
expires_at = Column(DateTime, nullable=False)
accepted_at = Column(DateTime, nullable=True)
accepted_by_user_id = Column(
UUID(as_uuid=True),
ForeignKey('user.id'),
nullable=True,
)
# Relationships
org = relationship('Org', back_populates='invitations')
role = relationship('Role')
inviter = relationship('User', foreign_keys=[inviter_id])
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
# Status constants
STATUS_PENDING = 'pending'
STATUS_ACCEPTED = 'accepted'
STATUS_REVOKED = 'revoked'
STATUS_EXPIRED = 'expired'

View File

@@ -0,0 +1,227 @@
"""
Store class for managing organization invitations.
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.org_invitation import OrgInvitation
from openhands.core.logger import openhands_logger as logger
# Invitation token configuration
INVITATION_TOKEN_PREFIX = 'inv-'
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
DEFAULT_EXPIRATION_DAYS = 7
class OrgInvitationStore:
"""Store for managing organization invitations."""
@staticmethod
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
"""Generate a secure invitation token.
Uses cryptographically secure random generation for tokens.
Pattern from api_key_store.py.
Args:
length: Length of the random part of the token
Returns:
str: Token with prefix (e.g., 'inv-aBcDeF123...')
"""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_id: int,
inviter_id: UUID,
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
) -> OrgInvitation:
"""Create a new organization invitation.
Args:
org_id: Organization UUID
email: Invitee's email address
role_id: Role ID to assign on acceptance
inviter_id: User ID of the person creating the invitation
expiration_days: Days until the invitation expires
Returns:
OrgInvitation: The created invitation record
"""
async with a_session_maker() as session:
token = OrgInvitationStore.generate_token()
# Use timezone-naive datetime for database compatibility
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
invitation = OrgInvitation(
token=token,
org_id=org_id,
email=email.lower().strip(),
role_id=role_id,
inviter_id=inviter_id,
status=OrgInvitation.STATUS_PENDING,
expires_at=expires_at,
)
session.add(invitation)
await session.commit()
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.role))
.filter(OrgInvitation.id == invitation.id)
)
invitation = result.scalars().first()
logger.info(
'Created organization invitation',
extra={
'invitation_id': invitation.id,
'org_id': str(org_id),
'email': email,
'inviter_id': str(inviter_id),
'expires_at': expires_at.isoformat(),
},
)
return invitation
@staticmethod
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
"""Get an invitation by its token.
Args:
token: The invitation token
Returns:
OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
.filter(OrgInvitation.token == token)
)
return result.scalars().first()
@staticmethod
async def get_pending_invitation(
org_id: UUID, email: str
) -> Optional[OrgInvitation]:
"""Get a pending invitation for an email in an organization.
Args:
org_id: Organization UUID
email: Email address to check
Returns:
OrgInvitation or None if no pending invitation exists
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(
and_(
OrgInvitation.org_id == org_id,
OrgInvitation.email == email.lower().strip(),
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
)
)
)
return result.scalars().first()
@staticmethod
async def update_invitation_status(
invitation_id: int,
status: str,
accepted_by_user_id: Optional[UUID] = None,
) -> Optional[OrgInvitation]:
"""Update an invitation's status.
Args:
invitation_id: The invitation ID
status: New status (pending, accepted, revoked, expired)
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
Returns:
Updated OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
)
invitation = result.scalars().first()
if not invitation:
return None
old_status = invitation.status
invitation.status = status
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
# Use timezone-naive datetime for database compatibility
invitation.accepted_at = datetime.utcnow()
invitation.accepted_by_user_id = accepted_by_user_id
await session.commit()
await session.refresh(invitation)
logger.info(
'Updated invitation status',
extra={
'invitation_id': invitation_id,
'old_status': old_status,
'new_status': status,
'accepted_by_user_id': (
str(accepted_by_user_id) if accepted_by_user_id else None
),
},
)
return invitation
@staticmethod
def is_token_expired(invitation: OrgInvitation) -> bool:
"""Check if an invitation token has expired.
Args:
invitation: The invitation to check
Returns:
bool: True if expired, False otherwise
"""
# Use timezone-naive datetime for comparison (database stores without timezone)
now = datetime.utcnow()
return invitation.expires_at < now
@staticmethod
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
"""Check if invitation is expired and update status if needed.
Args:
invitation: The invitation to check
Returns:
bool: True if invitation was marked as expired, False otherwise
"""
if (
invitation.status == OrgInvitation.STATUS_PENDING
and OrgInvitationStore.is_token_expired(invitation)
):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
return True
return False

View File

@@ -770,6 +770,30 @@ class UserStore:
finally: finally:
await UserStore._release_user_creation_lock(user_id) await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
Args:
email: The email address to search for
Returns:
User: The user with the matching email, or None if not found
"""
if not email:
return None
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod @staticmethod
def list_users() -> list[User]: def list_users() -> list[User]:
"""List all users.""" """List all users."""

View File

@@ -0,0 +1,181 @@
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
import pytest
class TestAuthCallbackInvitationEmailMismatch:
"""Test cases for EmailMismatchError handling during auth callback."""
@pytest.fixture
def mock_redirect_url(self):
"""Base redirect URL."""
return 'https://app.example.com/'
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return '87654321-4321-8765-4321-876543218765'
def test_email_mismatch_appends_to_url_without_query_params(
self, mock_redirect_url, mock_user_id
):
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
from server.routes.org_invitation_models import EmailMismatchError
# Simulate the logic from auth.py
redirect_url = mock_redirect_url
try:
raise EmailMismatchError('Your email does not match the invitation')
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
from server.routes.org_invitation_models import EmailMismatchError
redirect_url = 'https://app.example.com/?other_param=value'
try:
raise EmailMismatchError()
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert (
redirect_url
== 'https://app.example.com/?other_param=value&email_mismatch=true'
)
def test_email_mismatch_error_has_default_message(self):
"""Test that EmailMismatchError has the default message."""
from server.routes.org_invitation_models import EmailMismatchError
error = EmailMismatchError()
assert str(error) == 'Your email does not match the invitation'
def test_email_mismatch_error_accepts_custom_message(self):
"""Test that EmailMismatchError accepts a custom message."""
from server.routes.org_invitation_models import EmailMismatchError
custom_message = 'Custom error message'
error = EmailMismatchError(custom_message)
assert str(error) == custom_message
def test_email_mismatch_error_is_invitation_error(self):
"""Test that EmailMismatchError inherits from InvitationError."""
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationError,
)
error = EmailMismatchError()
assert isinstance(error, InvitationError)
class TestInvitationTokenInOAuthState:
"""Test cases for invitation token handling in OAuth state."""
def test_invitation_token_included_in_oauth_state(self):
"""Test that invitation token is included in OAuth state data."""
import base64
import json
# Simulate building OAuth state with invitation token
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
assert decoded_data['redirect_url'] == 'https://app.example.com/'
def test_invitation_token_extracted_from_oauth_state(self):
"""Test that invitation token can be extracted from OAuth state."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
# Simulate decoding in callback
decoded_state = json.loads(base64.b64decode(encoded_state))
invitation_token = decoded_state.get('invitation_token')
assert invitation_token == 'inv-test-token-12345'
def test_oauth_state_without_invitation_token(self):
"""Test that OAuth state works without invitation token."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert 'invitation_token' not in decoded_data
assert decoded_data['redirect_url'] == 'https://app.example.com/'
class TestAuthCallbackInvitationErrors:
"""Test cases for various invitation error scenarios in auth callback."""
def test_invitation_expired_appends_flag(self):
"""Test that invitation_expired=true is appended for expired invitations."""
from server.routes.org_invitation_models import InvitationExpiredError
redirect_url = 'https://app.example.com/'
try:
raise InvitationExpiredError()
except InvitationExpiredError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
def test_invitation_invalid_appends_flag(self):
"""Test that invitation_invalid=true is appended for invalid invitations."""
from server.routes.org_invitation_models import InvitationInvalidError
redirect_url = 'https://app.example.com/'
try:
raise InvitationInvalidError()
except InvitationInvalidError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
def test_already_member_appends_flag(self):
"""Test that already_member=true is appended when user is already a member."""
from server.routes.org_invitation_models import UserAlreadyMemberError
redirect_url = 'https://app.example.com/'
try:
raise UserAlreadyMemberError()
except UserAlreadyMemberError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
assert redirect_url == 'https://app.example.com/?already_member=true'

View File

@@ -0,0 +1,192 @@
"""Tests for email service."""
import os
from unittest.mock import MagicMock, patch
from server.services.email_service import (
DEFAULT_WEB_HOST,
EmailService,
)
class TestEmailServiceInvitationUrl:
"""Test cases for invitation URL generation."""
def test_invitation_url_uses_correct_endpoint(self):
"""Test that invitation URL points to the correct API endpoint."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
# Get the call arguments
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify the URL in the email HTML contains the correct endpoint
assert (
'/api/organizations/members/invite/accept?token='
in email_params['html']
)
assert 'inv-test-token-12345' in email_params['html']
def test_invitation_url_uses_web_host_env_var(self):
"""Test that invitation URL uses WEB_HOST environment variable."""
custom_host = 'https://custom.example.com'
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(
os.environ,
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
def test_invitation_url_uses_default_host_when_env_not_set(self):
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
# Remove WEB_HOST if it exists
env_without_web_host.pop('WEB_HOST', None)
with (
patch.dict(os.environ, env_without_web_host, clear=True),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
# Clear WEB_HOST from the environment
os.environ.pop('WEB_HOST', None)
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
class TestEmailServiceGetResendClient:
"""Test cases for Resend client initialization."""
def test_get_resend_client_returns_false_when_resend_not_available(self):
"""Test that _get_resend_client returns False when resend is not installed."""
with patch('server.services.email_service.RESEND_AVAILABLE', False):
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
"""Test that _get_resend_client returns False when API key is missing."""
with (
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch.dict(os.environ, {}, clear=True),
):
os.environ.pop('RESEND_API_KEY', None)
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_true_when_configured(self):
"""Test that _get_resend_client returns True when properly configured."""
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
result = EmailService._get_resend_client()
assert result is True
assert mock_resend.api_key == 'test-key'
class TestEmailServiceSendInvitationEmail:
"""Test cases for send_invitation_email method."""
def test_send_invitation_email_skips_when_client_not_ready(self):
"""Test that email sending is skipped when client is not ready."""
with patch.object(
EmailService, '_get_resend_client', return_value=False
) as mock_get_client:
# Should not raise, just return early
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token',
invitation_id=1,
)
mock_get_client.assert_called_once()
def test_send_invitation_email_includes_all_required_info(self):
"""Test that invitation email includes org name, inviter name, and role."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Acme Corp',
inviter_name='John Doe',
role_name='admin',
invitation_token='inv-test-token-12345',
invitation_id=42,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify email content
assert email_params['to'] == ['test@example.com']
assert 'Acme Corp' in email_params['subject']
assert 'John Doe' in email_params['html']
assert 'Acme Corp' in email_params['html']
assert 'admin' in email_params['html']

View File

@@ -0,0 +1,464 @@
"""Tests for organization invitation service - email validation."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from server.routes.org_invitation_models import (
EmailMismatchError,
)
from server.services.org_invitation_service import OrgInvitationService
from storage.org_invitation import OrgInvitation
class TestAcceptInvitationEmailValidation:
"""Test cases for email validation during invitation acceptance."""
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation with pending status."""
invitation = MagicMock(spec=OrgInvitation)
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.status = OrgInvitation.STATUS_PENDING
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
invitation.role_id = 1
return invitation
@pytest.fixture
def mock_user(self):
"""Create a mock user with email."""
user = MagicMock()
user.id = UUID('87654321-4321-8765-4321-876543218765')
user.email = 'alice@example.com'
return user
@pytest.mark.asyncio
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
"""Test that invitation is accepted when user email matches invitation email."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
with patch.object(
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
) as mock_accept:
mock_accept.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_accept.assert_called_once_with(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_email_mismatch_raises_error(
self, mock_invitation, mock_user
):
"""Test that EmailMismatchError is raised when emails don't match."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
mock_user.email = 'bob@example.com' # Different email
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Act & Assert
with pytest.raises(EmailMismatchError):
await OrgInvitationService.accept_invitation(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
self, mock_invitation
):
"""Test that Keycloak email is used when user has no email in database."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager instance
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(
return_value=mock_keycloak_user_info
)
mock_token_manager_class.return_value = mock_token_manager
mock_get_member.return_value = None # Not already a member
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
str(user_id)
)
@pytest.mark.asyncio
async def test_accept_invitation_no_email_anywhere_raises_error(
self, mock_invitation
):
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager to return no email
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
mock_token_manager_class.return_value = mock_token_manager
# Act & Assert
with pytest.raises(EmailMismatchError) as exc_info:
await OrgInvitationService.accept_invitation(token, user_id)
assert 'does not have an email address' in str(exc_info.value)
@pytest.mark.asyncio
async def test_accept_invitation_email_comparison_is_case_insensitive(
self, mock_invitation
):
"""Test that email comparison is case insensitive."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""
@pytest.fixture
def org_id(self):
"""Organization UUID for testing."""
return UUID('12345678-1234-5678-1234-567812345678')
@pytest.fixture
def inviter_id(self):
"""Inviter UUID for testing."""
return UUID('87654321-4321-8765-4321-876543218765')
@pytest.fixture
def mock_org(self):
"""Create a mock organization."""
org = MagicMock()
org.id = UUID('12345678-1234-5678-1234-567812345678')
org.name = 'Test Org'
return org
@pytest.fixture
def mock_inviter_member(self):
"""Create a mock inviter member with owner role."""
member = MagicMock()
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
member.role_id = 1
return member
@pytest.fixture
def mock_owner_role(self):
"""Create a mock owner role."""
role = MagicMock()
role.id = 1
role.name = 'owner'
return role
@pytest.fixture
def mock_member_role(self):
"""Create a mock member role."""
role = MagicMock()
role.id = 3
role.name = 'member'
return role
@pytest.mark.asyncio
async def test_batch_creates_all_invitations_successfully(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch creation succeeds for all valid emails."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
mock_invitation_1 = MagicMock(spec=OrgInvitation)
mock_invitation_1.id = 1
mock_invitation_2 = MagicMock(spec=OrgInvitation)
mock_invitation_2.id = 2
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation_1, mock_invitation_2],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 2
assert len(failed) == 0
@pytest.mark.asyncio
async def test_batch_handles_partial_success(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch returns partial results when some emails fail."""
# Arrange
from server.routes.org_invitation_models import UserAlreadyMemberError
emails = ['alice@example.com', 'existing@example.com']
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation, UserAlreadyMemberError()],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 1
assert len(failed) == 1
assert failed[0][0] == 'existing@example.com'
@pytest.mark.asyncio
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
"""Test that permission error fails the entire batch upfront."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
with patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=None, # Organization not found
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
assert 'not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_batch_fails_on_invalid_role(
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
):
"""Test that invalid role fails the entire batch."""
# Arrange
emails = ['alice@example.com']
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=None, # Invalid role
),
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='invalid_role',
inviter_id=inviter_id,
)
assert 'Invalid role' in str(exc_info.value)

View File

@@ -0,0 +1,308 @@
"""Tests for organization invitation store."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import (
INVITATION_TOKEN_LENGTH,
INVITATION_TOKEN_PREFIX,
OrgInvitationStore,
)
class TestGenerateToken:
"""Test cases for token generation."""
def test_generate_token_has_correct_prefix(self):
"""Test that generated tokens have the correct prefix."""
token = OrgInvitationStore.generate_token()
assert token.startswith(INVITATION_TOKEN_PREFIX)
def test_generate_token_has_correct_length(self):
"""Test that generated tokens have the correct total length."""
token = OrgInvitationStore.generate_token()
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
assert len(token) == expected_length
def test_generate_token_uses_alphanumeric_characters(self):
"""Test that generated tokens use only alphanumeric characters."""
token = OrgInvitationStore.generate_token()
# Remove prefix and check the rest is alphanumeric
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
assert random_part.isalnum()
def test_generate_token_is_unique(self):
"""Test that generated tokens are unique (probabilistically)."""
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
assert len(set(tokens)) == 100
class TestIsTokenExpired:
"""Test cases for token expiration checking."""
def test_token_not_expired_when_future(self):
"""Test that tokens with future expiration are not expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is False
def test_token_expired_when_past(self):
"""Test that tokens with past expiration are expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
def test_token_expired_at_exact_boundary(self):
"""Test that tokens at exact expiration time are expired."""
# A token that expires "now" should be expired
now = datetime.utcnow()
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = now - timedelta(microseconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
class TestCreateInvitation:
"""Test cases for invitation creation."""
@pytest.mark.asyncio
async def test_create_invitation_normalizes_email(self):
"""Test that email is normalized (lowercase, stripped) on creation."""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.execute = AsyncMock()
# Mock the result of the re-fetch query
mock_result = MagicMock()
mock_invitation = MagicMock()
mock_invitation.id = 1
mock_invitation.email = 'test@example.com'
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute.return_value = mock_result
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.create_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
role_id=1,
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
)
# Verify that the OrgInvitation was created with normalized email
add_call = mock_session.add.call_args
created_invitation = add_call[0][0]
assert created_invitation.email == 'test@example.com'
class TestGetInvitationByToken:
"""Test cases for getting invitation by token."""
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_invitation(self):
"""Test that get_invitation_by_token returns the invitation when found."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.token = 'inv-test-token-12345'
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-test-token-12345'
)
assert result == mock_invitation
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_none_when_not_found(self):
"""Test that get_invitation_by_token returns None when not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-nonexistent-token'
)
assert result is None
class TestGetPendingInvitation:
"""Test cases for getting pending invitation."""
@pytest.mark.asyncio
async def test_get_pending_invitation_normalizes_email(self):
"""Test that email is normalized when querying for pending invitations."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.get_pending_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
)
# Verify the query was called (email normalization happens in the filter)
assert mock_session.execute.called
class TestUpdateInvitationStatus:
"""Test cases for updating invitation status."""
@pytest.mark.asyncio
async def test_update_status_sets_accepted_at_for_accepted(self):
"""Test that accepted_at is set when status is accepted."""
from uuid import UUID
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
user_id = UUID('87654321-4321-8765-4321-876543218765')
await OrgInvitationStore.update_invitation_status(
invitation_id=1,
status=OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
assert mock_invitation.accepted_at is not None
assert mock_invitation.accepted_by_user_id == user_id
@pytest.mark.asyncio
async def test_update_status_returns_none_when_not_found(self):
"""Test that update returns None when invitation not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.update_invitation_status(
invitation_id=999,
status=OrgInvitation.STATUS_ACCEPTED,
)
assert result is None
class TestMarkExpiredIfNeeded:
"""Test cases for marking expired invitations."""
@pytest.mark.asyncio
async def test_marks_expired_when_pending_and_past_expiry(self):
"""Test that pending expired invitations are marked as expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is True
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
@pytest.mark.asyncio
async def test_does_not_mark_when_not_expired(self):
"""Test that non-expired invitations are not marked."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()
@pytest.mark.asyncio
async def test_does_not_mark_when_not_pending(self):
"""Test that non-pending invitations are not marked even if expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()

View File

@@ -0,0 +1,388 @@
"""Tests for organization invitations API router."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.routes.org_invitations import accept_router, invitation_router
@pytest.fixture
def app():
"""Create a FastAPI app with the invitation routers."""
app = FastAPI()
app.include_router(invitation_router)
app.include_router(accept_router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
class TestRouterPrefixes:
"""Test that router prefixes are configured correctly."""
def test_invitation_router_has_correct_prefix(self):
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
def test_accept_router_has_correct_prefix(self):
"""Test that accept_router has /api/organizations/members/invite prefix."""
assert accept_router.prefix == '/api/organizations/members/invite'
class TestAcceptInvitationEndpoint:
"""Test cases for the accept invitation endpoint."""
@pytest.fixture
def mock_user_auth(self):
"""Create a mock user auth."""
user_auth = MagicMock()
user_auth.get_user_id = AsyncMock(
return_value='87654321-4321-8765-4321-876543218765'
)
return user_auth
@pytest.mark.asyncio
async def test_accept_unauthenticated_redirects_to_login(self, client):
"""Test that unauthenticated users are redirected to login with invitation token."""
with patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=None,
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
'location', ''
)
@pytest.mark.asyncio
async def test_accept_authenticated_success_redirects_home(
self, client, mock_user_auth
):
"""Test that successful acceptance redirects to home page."""
mock_invitation = MagicMock()
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
return_value=mock_invitation,
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
location = response.headers.get('location', '')
assert location.endswith('/')
assert 'invitation_expired' not in location
assert 'invitation_invalid' not in location
assert 'email_mismatch' not in location
@pytest.mark.asyncio
async def test_accept_expired_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that expired invitation redirects with invitation_expired=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationExpiredError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_expired=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_invalid_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that invalid invitation redirects with invitation_invalid=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationInvalidError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_invalid=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_already_member_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that already member error redirects with already_member=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=UserAlreadyMemberError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'already_member=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_email_mismatch_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that email mismatch error redirects with email_mismatch=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=EmailMismatchError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'email_mismatch=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_unexpected_error_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that unexpected errors redirect with invitation_error=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=Exception('Unexpected error'),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_error=true' in response.headers.get('location', '')
class TestCreateInvitationBatchEndpoint:
"""Test cases for the batch invitation creation endpoint."""
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
# Override the get_user_id dependency
app.dependency_overrides[get_user_id] = (
lambda: '87654321-4321-8765-4321-876543218765'
)
return app
@pytest.fixture
def batch_client(self, batch_app):
"""Create a test client with dependency overrides."""
return TestClient(batch_app)
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation."""
from datetime import datetime
invitation = MagicMock()
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.role = MagicMock(name='member')
invitation.role.name = 'member'
invitation.role_id = 3
invitation.status = 'pending'
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
return invitation
@pytest.mark.asyncio
async def test_batch_create_returns_successful_invitations(
self, batch_client, mock_invitation
):
"""Test that batch creation returns successful invitations."""
mock_invitation_2 = MagicMock()
mock_invitation_2.id = 2
mock_invitation_2.email = 'bob@example.com'
mock_invitation_2.role = MagicMock()
mock_invitation_2.role.name = 'member'
mock_invitation_2.role_id = 3
mock_invitation_2.status = 'pending'
mock_invitation_2.created_at = mock_invitation.created_at
mock_invitation_2.expires_at = mock_invitation.expires_at
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation, mock_invitation_2], []),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'bob@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 2
assert len(data['failed']) == 0
@pytest.mark.asyncio
async def test_batch_create_returns_partial_success(
self, batch_client, mock_invitation
):
"""Test that batch creation returns both successful and failed invitations."""
failed_emails = [('existing@example.com', 'User is already a member')]
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation], failed_emails),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'existing@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 1
assert len(data['failed']) == 1
assert data['failed'][0]['email'] == 'existing@example.com'
assert 'already a member' in data['failed'][0]['error']
@pytest.mark.asyncio
async def test_batch_create_permission_denied_returns_403(self, batch_client):
"""Test that permission denied returns 403 for entire batch."""
from server.routes.org_invitation_models import InsufficientPermissionError
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=InsufficientPermissionError(
'Only owners and admins can invite'
),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'member'},
)
assert response.status_code == 403
assert 'owners and admins' in response.json()['detail']
@pytest.mark.asyncio
async def test_batch_create_invalid_role_returns_400(self, batch_client):
"""Test that invalid role returns 400."""
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=ValueError('Invalid role: superuser'),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'superuser'},
)
assert response.status_code == 400
assert 'Invalid role' in response.json()['detail']

View File

@@ -151,8 +151,9 @@ describe("LoginContent", () => {
await user.click(githubButton); await user.click(githubButton);
// Wait for async handleAuthRedirect to complete // Wait for async handleAuthRedirect to complete
// The URL includes state parameter added by handleAuthRedirect
await waitFor(() => { await waitFor(() => {
expect(window.location.href).toBe(mockUrl); expect(window.location.href).toContain(mockUrl);
}); });
}); });
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument(); expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
}); });
it("should display invitation pending message when hasInvitation is true", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation
/>
</MemoryRouter>,
);
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
it("should not display invitation pending message when hasInvitation is false", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should call buildOAuthStateData when clicking auth button", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
expect(mockBuildOAuthStateData).toHaveBeenCalled();
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
expect(callArg).toHaveProperty("redirect_url");
});
});
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
const redirectUrl = window.location.href;
// The URL should contain an encoded state parameter
expect(redirectUrl).toContain("state=");
// Decode and verify the state contains invitation_token
const url = new URL(redirectUrl);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
});
}); });

View File

@@ -0,0 +1,170 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
// Mock setSearchParams function
const mockSetSearchParams = vi.fn();
// Default mock searchParams
let mockSearchParamsData: Record<string, string> = {};
// Mock react-router
vi.mock("react-router", () => ({
useSearchParams: () => [
{
get: (key: string) => mockSearchParamsData[key] || null,
has: (key: string) => key in mockSearchParamsData,
},
mockSetSearchParams,
],
}));
// Import after mocking
import { useInvitation } from "#/hooks/use-invitation";
describe("useInvitation", () => {
beforeEach(() => {
// Clear localStorage before each test
localStorage.clear();
// Reset mock data
mockSearchParamsData = {};
mockSetSearchParams.mockClear();
});
afterEach(() => {
vi.clearAllMocks();
});
describe("initialization", () => {
it("should initialize with null token when localStorage is empty", () => {
// Arrange - localStorage is empty (cleared in beforeEach)
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
it("should initialize with token from localStorage if present", () => {
// Arrange
const storedToken = "inv-stored-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBe(storedToken);
expect(result.current.hasInvitation).toBe(true);
});
});
describe("URL token capture", () => {
it("should capture invitation_token from URL and store in localStorage", () => {
// Arrange
const urlToken = "inv-url-token-67890";
mockSearchParamsData = { invitation_token: urlToken };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
expect(mockSetSearchParams).toHaveBeenCalled();
});
});
describe("completion cleanup", () => {
it("should clear localStorage when email_mismatch param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { email_mismatch: "true" };
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(mockSetSearchParams).toHaveBeenCalled();
});
it("should clear localStorage when invitation_success param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { invitation_success: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
it("should clear localStorage when invitation_expired param is present", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
mockSearchParamsData = { invitation_expired: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
});
describe("buildOAuthStateData", () => {
it("should include invitation_token in OAuth state when token is present", () => {
// Arrange
const token = "inv-oauth-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, token);
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBe(token);
expect(stateData.redirect_url).toBe("/dashboard");
});
it("should not include invitation_token when no token is present", () => {
// Arrange - no token in localStorage
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBeUndefined();
expect(stateData.redirect_url).toBe("/dashboard");
});
});
describe("clearInvitation", () => {
it("should remove token from localStorage when called", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
const { result } = renderHook(() => useInvitation());
// Act
act(() => {
result.current.clearInvitation();
});
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
});
});

View File

@@ -57,6 +57,22 @@ vi.mock("#/hooks/use-tracking", () => ({
}), }),
})); }));
const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({
useInvitationMock: vi.fn(() => ({
invitationToken: null as string | null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
})),
buildOAuthStateDataMock: vi.fn(
(baseState: Record<string, string>) => baseState,
),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => useInvitationMock(),
}));
const RouterStub = createRoutesStub([ const RouterStub = createRoutesStub([
{ {
Component: LoginPage, Component: LoginPage,
@@ -234,7 +250,8 @@ describe("LoginPage", () => {
}); });
await user.click(githubButton); await user.click(githubButton);
expect(window.location.href).toBe(mockUrl); // URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(mockUrl);
}); });
it("should redirect to GitLab auth URL when GitLab button is clicked", async () => { it("should redirect to GitLab auth URL when GitLab button is clicked", async () => {
@@ -255,7 +272,8 @@ describe("LoginPage", () => {
}); });
await user.click(gitlabButton); await user.click(gitlabButton);
expect(window.location.href).toBe("https://gitlab.com/oauth/authorize"); // URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain("https://gitlab.com/oauth/authorize");
}); });
it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => { it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => {
@@ -282,7 +300,8 @@ describe("LoginPage", () => {
}); });
await user.click(bitbucketButton); await user.click(bitbucketButton);
expect(window.location.href).toBe( // URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(
"https://bitbucket.org/site/oauth2/authorize", "https://bitbucket.org/site/oauth2/authorize",
); );
}); });
@@ -479,4 +498,137 @@ describe("LoginPage", () => {
}); });
}); });
}); });
describe("Invitation Flow", () => {
it("should display invitation pending message when hasInvitation is true", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
it("should not display invitation pending message when hasInvitation is false", async () => {
useInvitationMock.mockReturnValue({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByTestId("login-content")).toBeInTheDocument();
});
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// buildOAuthStateData should have been called during the OAuth redirect
expect(mockBuildOAuthStateData).toHaveBeenCalled();
});
it("should include invitation token in OAuth state when invitation is present", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// Verify the redirect URL contains the state with invitation token
await waitFor(() => {
expect(window.location.href).toContain("state=");
});
// Decode and verify the state contains invitation_token
const url = new URL(window.location.href);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
it("should handle login with invitation_token URL parameter", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-url-token-67890",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login?invitation_token=inv-url-token-67890"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
});
}); });

View File

@@ -42,6 +42,15 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(), displaySuccessToast: vi.fn(),
})); }));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => ({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
}),
}));
function LoginStub() { function LoginStub() {
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const emailVerificationRequired = const emailVerificationRequired =
@@ -353,4 +362,68 @@ describe("MainApp", () => {
); );
}); });
}); });
describe("Invitation URL Parameters", () => {
beforeEach(() => {
vi.spyOn(AuthService, "authenticate").mockRejectedValue({
response: { status: 401 },
isAxiosError: true,
});
});
it("should redirect to login when email_mismatch=true is in query params", async () => {
renderMainApp(["/?email_mismatch=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_success=true is in query params", async () => {
renderMainApp(["/?invitation_success=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_expired=true is in query params", async () => {
renderMainApp(["/?invitation_expired=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_invalid=true is in query params", async () => {
renderMainApp(["/?invitation_invalid=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when already_member=true is in query params", async () => {
renderMainApp(["/?already_member=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
});
}); });

View File

@@ -21,6 +21,10 @@ export interface LoginContentProps {
emailVerified?: boolean; emailVerified?: boolean;
hasDuplicatedEmail?: boolean; hasDuplicatedEmail?: boolean;
recaptchaBlocked?: boolean; recaptchaBlocked?: boolean;
hasInvitation?: boolean;
buildOAuthStateData?: (
baseStateData: Record<string, string>,
) => Record<string, string>;
} }
export function LoginContent({ export function LoginContent({
@@ -31,6 +35,8 @@ export function LoginContent({
emailVerified = false, emailVerified = false,
hasDuplicatedEmail = false, hasDuplicatedEmail = false,
recaptchaBlocked = false, recaptchaBlocked = false,
hasInvitation = false,
buildOAuthStateData,
}: LoginContentProps) { }: LoginContentProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const { trackLoginButtonClick } = useTracking(); const { trackLoginButtonClick } = useTracking();
@@ -59,31 +65,36 @@ export function LoginContent({
) => { ) => {
trackLoginButtonClick({ provider }); trackLoginButtonClick({ provider });
if (!config?.recaptcha_site_key || !recaptchaReady) { const url = new URL(redirectUrl);
// No reCAPTCHA or token generation failed - redirect normally const currentState =
window.location.href = redirectUrl; url.searchParams.get("state") || window.location.origin;
return;
// Build base state data
let stateData: Record<string, string> = {
redirect_url: currentState,
};
// Add invitation token if present
if (buildOAuthStateData) {
stateData = buildOAuthStateData(stateData);
} }
// If reCAPTCHA is configured, encode token in OAuth state // If reCAPTCHA is configured, add token to state
try { if (config?.recaptcha_site_key && recaptchaReady) {
const token = await executeRecaptcha("LOGIN"); try {
if (token) { const token = await executeRecaptcha("LOGIN");
const url = new URL(redirectUrl); if (token) {
const currentState = stateData.recaptcha_token = token;
url.searchParams.get("state") || window.location.origin; }
} catch (err) {
// Encode state with reCAPTCHA token for backend verification displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
const stateData = { return;
redirect_url: currentState,
recaptcha_token: token,
};
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
} }
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
} }
// Encode state and redirect
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
}; };
const handleGitHubAuth = () => { const handleGitHubAuth = () => {
@@ -123,6 +134,10 @@ export function LoginContent({
const buttonBaseClasses = const buttonBaseClasses =
"w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed"; "w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed";
const buttonLabelClasses = "text-sm font-medium leading-5 px-1"; const buttonLabelClasses = "text-sm font-medium leading-5 px-1";
const shouldShownHelperText =
emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation;
return ( return (
<div <div
className="flex flex-col items-center w-full gap-12.5" className="flex flex-col items-center w-full gap-12.5"
@@ -136,20 +151,29 @@ export function LoginContent({
{t(I18nKey.AUTH$LETS_GET_STARTED)} {t(I18nKey.AUTH$LETS_GET_STARTED)}
</h1> </h1>
{emailVerified && ( {shouldShownHelperText && (
<p className="text-sm text-muted-foreground text-center"> <div className="flex flex-col items-center gap-3">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)} {emailVerified && (
</p> <p className="text-sm text-muted-foreground text-center">
)} {t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
{hasDuplicatedEmail && ( </p>
<p className="text-sm text-danger text-center"> )}
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)} {hasDuplicatedEmail && (
</p> <p className="text-sm text-danger text-center">
)} {t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
{recaptchaBlocked && ( </p>
<p className="text-sm text-danger text-center max-w-125"> )}
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)} {recaptchaBlocked && (
</p> <p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
)}
{hasInvitation && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$INVITATION_PENDING)}
</p>
)}
</div>
)} )}
<div className="flex flex-col items-center gap-3"> <div className="flex flex-col items-center gap-3">

View File

@@ -0,0 +1,119 @@
import React from "react";
import { useSearchParams } from "react-router";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
interface UseInvitationReturn {
/** The invitation token, if present */
invitationToken: string | null;
/** Whether there is an active invitation */
hasInvitation: boolean;
/** Clear the stored invitation token */
clearInvitation: () => void;
/** Build OAuth state data including invitation token if present */
buildOAuthStateData: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
/**
* Hook to manage organization invitation tokens during the login flow.
*
* This hook:
* 1. Reads invitation_token from URL query params on mount
* 2. Persists the token in localStorage (survives page refresh and works across tabs)
* 3. Provides the token for inclusion in OAuth state
* 4. Provides cleanup method after successful authentication
*
* The invitation token flow:
* 1. User clicks invitation link → /api/invitations/accept?token=xxx
* 2. Backend redirects to /login?invitation_token=xxx
* 3. This hook captures token and stores in localStorage
* 4. When user clicks login button, token is included in OAuth state
* 5. After auth callback processes invitation, frontend clears the token
*
* Note: localStorage is used instead of sessionStorage to support scenarios where
* the user opens the email verification link in a new tab/browser window.
*/
export function useInvitation(): UseInvitationReturn {
const [searchParams, setSearchParams] = useSearchParams();
const [invitationToken, setInvitationToken] = React.useState<string | null>(
() => {
// Initialize from localStorage (persists across tabs and page refreshes)
if (typeof window !== "undefined") {
return localStorage.getItem(INVITATION_TOKEN_KEY);
}
return null;
},
);
// Capture invitation token from URL and persist to localStorage
// This only runs on the login page where the hook is used
React.useEffect(() => {
const tokenFromUrl = searchParams.get("invitation_token");
if (tokenFromUrl) {
// Store in localStorage for persistence across tabs and refreshes
localStorage.setItem(INVITATION_TOKEN_KEY, tokenFromUrl);
setInvitationToken(tokenFromUrl);
// Remove token from URL to clean up (prevents token exposure in browser history)
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_token");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
// Clear invitation token when invitation flow completes (success or failure)
// These query params are set by the backend after processing the invitation
React.useEffect(() => {
const invitationCompleted =
searchParams.has("invitation_success") ||
searchParams.has("invitation_expired") ||
searchParams.has("invitation_invalid") ||
searchParams.has("invitation_error") ||
searchParams.has("already_member") ||
searchParams.has("email_mismatch");
if (invitationCompleted) {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
// Remove invitation params from URL to clean up
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_success");
newSearchParams.delete("invitation_expired");
newSearchParams.delete("invitation_invalid");
newSearchParams.delete("invitation_error");
newSearchParams.delete("already_member");
newSearchParams.delete("email_mismatch");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
const clearInvitation = React.useCallback(() => {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
}, []);
const buildOAuthStateData = React.useCallback(
(baseStateData: Record<string, string>): Record<string, string> => {
const stateData = { ...baseStateData };
// Include invitation token in state if present
if (invitationToken) {
stateData.invitation_token = invitationToken;
}
return stateData;
},
[invitationToken],
);
return {
invitationToken,
hasInvitation: invitationToken !== null,
clearInvitation,
buildOAuthStateData,
};
}

View File

@@ -763,6 +763,7 @@ export enum I18nKey {
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR", AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED", AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED",
AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED", AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED",
AUTH$INVITATION_PENDING = "AUTH$INVITATION_PENDING",
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE", COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
COMMON$AND = "COMMON$AND", COMMON$AND = "COMMON$AND",
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY", COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",

View File

@@ -12207,6 +12207,22 @@
"de": "Lass uns anfangen", "de": "Lass uns anfangen",
"uk": "Почнімо" "uk": "Почнімо"
}, },
"AUTH$INVITATION_PENDING": {
"en": "Sign in to accept your organization invitation",
"ja": "組織への招待を受け入れるにはサインインしてください",
"zh-CN": "登录以接受您的组织邀请",
"zh-TW": "登入以接受您的組織邀請",
"ko-KR": "조직 초대를 수락하려면 로그인하세요",
"no": "Logg inn for å godta organisasjonsinvitasjonen din",
"it": "Accedi per accettare l'invito della tua organizzazione",
"pt": "Faça login para aceitar o convite da sua organização",
"es": "Inicia sesión para aceptar la invitación de tu organización",
"ar": "سجّل الدخول لقبول دعوة مؤسستك",
"fr": "Connectez-vous pour accepter l'invitation de votre organisation",
"tr": "Organizasyon davetinizi kabul etmek için giriş yapın",
"de": "Melden Sie sich an, um Ihre Organisationseinladung anzunehmen",
"uk": "Увійдіть, щоб прийняти запрошення до організації"
},
"COMMON$TERMS_OF_SERVICE": { "COMMON$TERMS_OF_SERVICE": {
"en": "Terms of Service", "en": "Terms of Service",
"ja": "利用規約", "ja": "利用規約",

View File

@@ -10,6 +10,7 @@ import "./tailwind.css";
import "./index.css"; import "./index.css";
import React from "react"; import React from "react";
import { Toaster } from "react-hot-toast"; import { Toaster } from "react-hot-toast";
import { useInvitation } from "#/hooks/use-invitation";
export function Layout({ children }: { children: React.ReactNode }) { export function Layout({ children }: { children: React.ReactNode }) {
return ( return (
@@ -37,5 +38,9 @@ export const meta: MetaFunction = () => [
]; ];
export default function App() { export default function App() {
// Handle invitation token cleanup when invitation flow completes
// This runs on all pages to catch redirects from auth callback
useInvitation();
return <Outlet />; return <Outlet />;
} }

View File

@@ -4,6 +4,7 @@ import { useIsAuthed } from "#/hooks/query/use-is-authed";
import { useConfig } from "#/hooks/query/use-config"; import { useConfig } from "#/hooks/query/use-config";
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url"; import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
import { useEmailVerification } from "#/hooks/use-email-verification"; import { useEmailVerification } from "#/hooks/use-email-verification";
import { useInvitation } from "#/hooks/use-invitation";
import { LoginContent } from "#/components/features/auth/login-content"; import { LoginContent } from "#/components/features/auth/login-content";
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal"; import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
@@ -23,6 +24,8 @@ export default function LoginPage() {
userId, userId,
} = useEmailVerification(); } = useEmailVerification();
const { hasInvitation, buildOAuthStateData } = useInvitation();
const gitHubAuthUrl = useGitHubAuthUrl({ const gitHubAuthUrl = useGitHubAuthUrl({
appMode: config.data?.app_mode || null, appMode: config.data?.app_mode || null,
authUrl: config.data?.auth_url, authUrl: config.data?.auth_url,
@@ -69,6 +72,8 @@ export default function LoginPage() {
emailVerified={emailVerified} emailVerified={emailVerified}
hasDuplicatedEmail={hasDuplicatedEmail} hasDuplicatedEmail={hasDuplicatedEmail}
recaptchaBlocked={recaptchaBlocked} recaptchaBlocked={recaptchaBlocked}
hasInvitation={hasInvitation}
buildOAuthStateData={buildOAuthStateData}
/> />
</main> </main>