mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
Merge branch 'main' into feature/public-conversation-sharing
This commit is contained in:
commit
a62d77cf3a
109
enterprise/server/auth/email_validation.py
Normal file
109
enterprise/server/auth/email_validation.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Email validation utilities for preventing duplicate signups with + modifier."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_base_email(email: str) -> str | None:
|
||||
"""Extract base email from an email address.
|
||||
|
||||
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
|
||||
For emails without + modifier, returns the email as-is.
|
||||
|
||||
Examples:
|
||||
extract_base_email("joe+test@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
|
||||
|
||||
Args:
|
||||
email: The email address to process
|
||||
|
||||
Returns:
|
||||
The base email address, or None if email format is invalid
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = email.rsplit('@', 1)
|
||||
# Extract the part before + if it exists
|
||||
base_local = local_part.split('+', 1)[0]
|
||||
return f'{base_local}@{domain}'
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def has_plus_modifier(email: str) -> bool:
|
||||
"""Check if an email address contains a + modifier.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if email contains + before @, False otherwise
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, _ = email.rsplit('@', 1)
|
||||
return '+' in local_part
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def matches_base_email(email: str, base_email: str) -> bool:
|
||||
"""Check if an email matches a base email pattern.
|
||||
|
||||
An email matches if:
|
||||
- It is exactly the base email (e.g., joe@example.com)
|
||||
- It has the same base local part and domain, with or without + modifier
|
||||
(e.g., joe+test@example.com matches base joe@example.com)
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
base_email: The base email to match against
|
||||
|
||||
Returns:
|
||||
True if email matches the base pattern, False otherwise
|
||||
"""
|
||||
if not email or not base_email:
|
||||
return False
|
||||
|
||||
# Extract base from both emails for comparison
|
||||
email_base = extract_base_email(email)
|
||||
base_email_normalized = extract_base_email(base_email)
|
||||
|
||||
if not email_base or not base_email_normalized:
|
||||
return False
|
||||
|
||||
# Emails match if they have the same base
|
||||
return email_base.lower() == base_email_normalized.lower()
|
||||
|
||||
|
||||
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
|
||||
"""Generate a regex pattern to match emails with the same base.
|
||||
|
||||
For base_email "joe@example.com", the pattern will match:
|
||||
- joe@example.com
|
||||
- joe+anything@example.com
|
||||
|
||||
Args:
|
||||
base_email: The base email address
|
||||
|
||||
Returns:
|
||||
A compiled regex pattern, or None if base_email is invalid
|
||||
"""
|
||||
base = extract_base_email(base_email)
|
||||
if not base:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = base.rsplit('@', 1)
|
||||
# Escape special regex characters in local part and domain
|
||||
escaped_local = re.escape(local_part)
|
||||
escaped_domain = re.escape(domain)
|
||||
# Pattern: joe@example.com OR joe+anything@example.com
|
||||
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
@ -154,8 +154,10 @@ class SaasUserAuth(UserAuth):
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = session.query(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@ -25,6 +26,11 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
)
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
matches_base_email,
|
||||
)
|
||||
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@ -509,6 +515,183 @@ class TokenManager:
|
||||
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
|
||||
return keycloak_user_id
|
||||
|
||||
async def _query_users_by_wildcard_pattern(
|
||||
self, local_part: str, domain: str
|
||||
) -> dict[str, dict]:
|
||||
"""Query Keycloak for users matching a wildcard email pattern.
|
||||
|
||||
Tries multiple query methods to find users with emails matching
|
||||
the pattern {local_part}*@{domain}. This catches the base email
|
||||
and all + modifier variants.
|
||||
|
||||
Args:
|
||||
local_part: The local part of the email (before @)
|
||||
domain: The domain part of the email (after @)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user IDs to user objects
|
||||
"""
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
all_users = {}
|
||||
|
||||
# Query for users with emails matching the base pattern using wildcard
|
||||
# Pattern: {local_part}*@{domain} - catches base email and all + variants
|
||||
# This may also catch unintended matches (e.g., joesmith@example.com), but
|
||||
# they will be filtered out by the regex pattern check later
|
||||
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
|
||||
wildcard_queries = [
|
||||
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
|
||||
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
|
||||
]
|
||||
|
||||
for query_params in wildcard_queries:
|
||||
try:
|
||||
users = await keycloak_admin.a_get_users(query_params)
|
||||
for user in users:
|
||||
all_users[user.get('id')] = user
|
||||
break # Success, no need to try fallback
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
|
||||
)
|
||||
continue # Try next query method
|
||||
|
||||
return all_users
|
||||
|
||||
def _find_duplicate_in_users(
|
||||
self, users: dict[str, dict], base_email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if any user in the provided list matches the base email pattern.
|
||||
|
||||
Filters users to find duplicates that match the base email pattern,
|
||||
excluding the current user.
|
||||
|
||||
Args:
|
||||
users: Dictionary mapping user IDs to user objects
|
||||
base_email: The base email to match against
|
||||
current_user_id: The user ID to exclude from the check
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found, False otherwise
|
||||
"""
|
||||
regex_pattern = get_base_email_regex_pattern(base_email)
|
||||
if not regex_pattern:
|
||||
logger.warning(
|
||||
f'Could not generate regex pattern for base email: {base_email}'
|
||||
)
|
||||
# Fallback to simple matching
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '').lower()
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and matches_base_email(user_email, base_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
else:
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '')
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and regex_pattern.match(user_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def check_duplicate_base_email(
|
||||
self, email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if a user with the same base email already exists.
|
||||
|
||||
This method checks for duplicate signups using email + modifier.
|
||||
It checks if any user exists with the same base email, regardless of whether
|
||||
the provided email has a + modifier or not.
|
||||
|
||||
Examples:
|
||||
- If email is "joe+test@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
|
||||
- If email is "joe@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
|
||||
|
||||
Args:
|
||||
email: The email address to check (may or may not contain + modifier)
|
||||
current_user_id: The user ID of the current user (to exclude from check)
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found (excluding current user), False otherwise
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, domain = base_email.rsplit('@', 1)
|
||||
users = await self._query_users_by_wildcard_pattern(local_part, domain)
|
||||
return self._find_duplicate_in_users(users, base_email, current_user_id)
|
||||
|
||||
except KeycloakConnectionError:
|
||||
logger.exception('KeycloakConnectionError when checking duplicate email')
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error checking duplicate email: {e}')
|
||||
# On any error, allow signup to proceed (fail open)
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def delete_keycloak_user(self, user_id: str) -> bool:
|
||||
"""Delete a user from Keycloak.
|
||||
|
||||
This method is used to clean up user accounts that were created
|
||||
but should not exist (e.g., duplicate email signups).
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to delete
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Use the sync method (python-keycloak doesn't have async delete_user)
|
||||
# Run it in a thread executor to avoid blocking the event loop
|
||||
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
|
||||
logger.info(f'Successfully deleted Keycloak user {user_id}')
|
||||
return True
|
||||
except KeycloakConnectionError:
|
||||
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
|
||||
raise
|
||||
except KeycloakError as e:
|
||||
# User might not exist or already deleted
|
||||
logger.warning(
|
||||
f'KeycloakError when deleting user {user_id}: {e}',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
|
||||
return False
|
||||
|
||||
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
|
||||
@ -146,9 +146,11 @@ async def keycloak_callback(
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
@ -164,6 +166,54 @@ async def keycloak_callback(
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Check email verification status
|
||||
email_verified = user_info.get('email_verified', False)
|
||||
if not email_verified:
|
||||
# Send verification email
|
||||
# Import locally to avoid circular import with email.py
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}?email_verification_required=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@ -74,7 +74,7 @@ async def update_email(
|
||||
accepted_tos=user_auth.accepted_tos,
|
||||
)
|
||||
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Updating email address for {user_id} to {email}')
|
||||
return response
|
||||
@ -91,8 +91,10 @@ async def update_email(
|
||||
|
||||
|
||||
@api_router.put('/verify')
|
||||
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
async def resend_email_verification(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
):
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Resending verification email for {user_id}')
|
||||
return JSONResponse(
|
||||
@ -124,10 +126,14 @@ async def verified_email(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
async def _verify_email(request: Request, user_id: str):
|
||||
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
redirect_uri = (
|
||||
f'{scheme}://{request.url.netloc}?email_verified=true'
|
||||
if is_auth_flow
|
||||
else f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
)
|
||||
logger.info(f'Redirect URI: {redirect_uri}')
|
||||
await keycloak_admin.a_send_verify_email(
|
||||
user_id=user_id,
|
||||
|
||||
151
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
151
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
@ -0,0 +1,151 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import verified_email, verify_email
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.url = MagicMock()
|
||||
request.url.hostname = 'localhost'
|
||||
request.url.netloc = 'localhost:8000'
|
||||
request.url.path = '/api/email/verified'
|
||||
request.base_url = 'http://localhost:8000/'
|
||||
request.headers = {}
|
||||
request.cookies = {}
|
||||
request.query_params = MagicMock()
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock SaasUserAuth object."""
|
||||
auth = MagicMock(spec=SaasUserAuth)
|
||||
auth.access_token = SecretStr('test_access_token')
|
||||
auth.refresh_token = SecretStr('test_refresh_token')
|
||||
auth.email = 'test@example.com'
|
||||
auth.email_verified = False
|
||||
auth.accepted_tos = True
|
||||
auth.refresh = AsyncMock()
|
||||
return auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_default_behavior(mock_request):
|
||||
"""Test verify_email with default is_auth_flow=False."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_with_auth_flow(mock_request):
|
||||
"""Test verify_email with is_auth_flow=True."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_https_scheme(mock_request):
|
||||
"""Test verify_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['redirect_uri'].startswith('https://')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert result.headers['location'] == 'http://localhost:8000/settings/user'
|
||||
mock_user_auth.refresh.assert_called_once()
|
||||
mock_set_cookie.assert_called_once()
|
||||
assert mock_user_auth.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
"""Test verified_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.headers['location'].startswith('https://')
|
||||
mock_set_cookie.assert_called_once()
|
||||
# Verify secure flag is True for https
|
||||
call_kwargs = mock_set_cookie.call_args.kwargs
|
||||
assert call_kwargs['secure'] is True
|
||||
@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@ -214,6 +216,82 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_posthog.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
"""Test keycloak_callback when email is not verified."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': False,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
# email_verified field is missing
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
"""Test successful keycloak_callback without valid offline token."""
|
||||
@ -248,6 +326,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@ -513,6 +592,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@ -566,6 +646,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@ -615,6 +696,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
# No email field
|
||||
}
|
||||
)
|
||||
@ -635,3 +717,222 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
"""Test keycloak_callback when duplicate check raises exception."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(
|
||||
side_effect=Exception('Check failed')
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should proceed with normal flow despite exception (fail open)
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
"""Test keycloak_callback when no duplicate email is found."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
# Should not delete user when no duplicate found
|
||||
mock_token_manager.delete_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
"""Test keycloak_callback when email is not in user_info."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
# No email field
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
# Should not check for duplicate when email is missing
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
294
enterprise/tests/unit/test_email_validation.py
Normal file
294
enterprise/tests/unit/test_email_validation.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Tests for email validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
has_plus_modifier,
|
||||
matches_base_email,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseEmail:
|
||||
"""Test cases for extract_base_email function."""
|
||||
|
||||
def test_extract_base_email_with_plus_modifier(self):
|
||||
"""Test extracting base email from email with + modifier."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_without_plus_modifier(self):
|
||||
"""Test that email without + modifier is returned as-is."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_multiple_plus_signs(self):
|
||||
"""Test extracting base email when multiple + signs exist."""
|
||||
# Arrange
|
||||
email = 'joe+openhands+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns None."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_empty_string(self):
|
||||
"""Test that empty string returns None."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_none(self):
|
||||
"""Test that None input returns None."""
|
||||
# Arrange
|
||||
email = None
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHasPlusModifier:
|
||||
"""Test cases for has_plus_modifier function."""
|
||||
|
||||
def test_has_plus_modifier_true(self):
|
||||
"""Test detecting + modifier in email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_has_plus_modifier_false(self):
|
||||
"""Test that email without + modifier returns False."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_empty_string(self):
|
||||
"""Test that empty string returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMatchesBaseEmail:
|
||||
"""Test cases for matches_base_email function."""
|
||||
|
||||
def test_matches_base_email_exact_match(self):
|
||||
"""Test that exact base email matches."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_with_plus_variant(self):
|
||||
"""Test that email with + variant matches base email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_different_base(self):
|
||||
"""Test that different base emails do not match."""
|
||||
# Arrange
|
||||
email = 'jane@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_different_domain(self):
|
||||
"""Test that same local part but different domain does not match."""
|
||||
# Arrange
|
||||
email = 'joe@other.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_case_insensitive(self):
|
||||
"""Test that matching is case-insensitive."""
|
||||
# Arrange
|
||||
email = 'JOE+TEST@EXAMPLE.COM'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_empty_strings(self):
|
||||
"""Test that empty strings return False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBaseEmailRegexPattern:
|
||||
"""Test cases for get_base_email_regex_pattern function."""
|
||||
|
||||
def test_get_base_email_regex_pattern_valid(self):
|
||||
"""Test generating valid regex pattern for base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is not None
|
||||
assert isinstance(pattern, re.Pattern)
|
||||
assert pattern.match('joe@example.com') is not None
|
||||
assert pattern.match('joe+test@example.com') is not None
|
||||
assert pattern.match('joe+openhands@example.com') is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_matches_plus_variant(self):
|
||||
"""Test that regex pattern matches + variant."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe+test@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_base(self):
|
||||
"""Test that regex pattern rejects different base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('jane@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_domain(self):
|
||||
"""Test that regex pattern rejects different domain."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe@other.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_case_insensitive(self):
|
||||
"""Test that regex pattern is case-insensitive."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('JOE+TEST@EXAMPLE.COM')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_special_characters(self):
|
||||
"""Test that regex pattern handles special characters in email."""
|
||||
# Arrange
|
||||
base_email = 'user.name+tag@example-site.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('user.name+test@example-site.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_invalid_base_email(self):
|
||||
"""Test that invalid base email returns None."""
|
||||
# Arrange
|
||||
base_email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is None
|
||||
@ -1,6 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
|
||||
from server.auth.token_manager import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
|
||||
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager():
|
||||
with patch('server.config.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
return TokenManager(external=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_new_record(token_store, mock_session):
|
||||
# Setup
|
||||
@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
|
||||
assert isinstance(result, OfflineTokenStore)
|
||||
assert result.user_id == test_user_id
|
||||
assert result.config == mock_config
|
||||
|
||||
|
||||
class TestCheckDuplicateBaseEmail:
|
||||
"""Test cases for check_duplicate_base_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
|
||||
"""Test that emails without + modifier are still checked for duplicates."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_empty_email(self, token_manager):
|
||||
"""Test that empty email returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
|
||||
"""Test that invalid email returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
|
||||
"""Test that duplicate email is detected when found."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
existing_user = {
|
||||
'id': 'existing_user_id',
|
||||
'email': 'joe@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = True
|
||||
mock_query.return_value = {'existing_user_id': existing_user}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
|
||||
"""Test that no duplicate is found when none exists."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_keycloak_connection_error(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError is re-raised, which triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_general_exception(self, token_manager):
|
||||
"""Test that general exceptions are handled gracefully."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestQueryUsersByWildcardPattern:
|
||||
"""Test cases for _query_users_by_wildcard_pattern method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_success_with_search(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test successful query using search parameter."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [
|
||||
{'id': 'user1', 'email': 'joe@example.com'},
|
||||
{'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert 'user1' in result
|
||||
assert 'user2' in result
|
||||
mock_admin.a_get_users.assert_called_once_with(
|
||||
{'search': 'joe*@example.com'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
|
||||
"""Test fallback to q parameter when search fails."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
# First call fails, second succeeds
|
||||
mock_admin.a_get_users = AsyncMock(
|
||||
side_effect=[Exception('Search failed'), mock_users]
|
||||
)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert 'user1' in result
|
||||
assert mock_admin.a_get_users.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
|
||||
"""Test query returns empty dict when no users found."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=[])
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestFindDuplicateInUsers:
|
||||
"""Test cases for _find_duplicate_in_users method."""
|
||||
|
||||
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
|
||||
"""Test finding duplicate using regex pattern."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user3'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
|
||||
"""Test fallback to simple matching when regex pattern is None."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'invalid-email' # Will cause regex pattern to be None
|
||||
current_user_id = 'user2'
|
||||
|
||||
with patch(
|
||||
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
|
||||
):
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should use fallback matching, but invalid base_email won't match
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
|
||||
"""Test that current user is excluded from duplicate check."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1' # Same as user in users dict
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_no_match(self, token_manager):
|
||||
"""Test that no duplicate is found when emails don't match."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'jane@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user2'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_empty_dict(self, token_manager):
|
||||
"""Test that empty users dict returns False."""
|
||||
# Arrange
|
||||
users: dict[str, dict] = {}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestDeleteKeycloakUser:
|
||||
"""Test cases for delete_keycloak_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_success(self, token_manager):
|
||||
"""Test successful deletion of Keycloak user."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.return_value = None
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_connection_error(self, token_manager):
|
||||
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
|
||||
"""Test handling of KeycloakError (e.g., user not found)."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakError('User not found')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_general_exception(self, token_manager):
|
||||
"""Test handling of general exceptions."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
146
frontend/__tests__/MSW.md
Normal file
146
frontend/__tests__/MSW.md
Normal file
@ -0,0 +1,146 @@
|
||||
# Mock Service Worker (MSW) Guide
|
||||
|
||||
## Overview
|
||||
|
||||
[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code.
|
||||
|
||||
We use MSW in this project for:
|
||||
- **Testing**: Write reliable unit and integration tests without real network calls
|
||||
- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs
|
||||
|
||||
The same mock handlers work in both environments, so you write them once and reuse everywhere.
|
||||
|
||||
## Relevant Files
|
||||
|
||||
- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers
|
||||
- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.)
|
||||
- `src/mocks/browser.ts` - Browser setup for development mode
|
||||
- `src/mocks/node.ts` - Node.js setup for tests
|
||||
- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Running with Mocked APIs
|
||||
|
||||
```sh
|
||||
# Run with API mocking enabled
|
||||
npm run dev:mock
|
||||
|
||||
# Run with API mocking + SaaS mode simulation
|
||||
npm run dev:mock:saas
|
||||
```
|
||||
|
||||
These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests.
|
||||
|
||||
> [!NOTE]
|
||||
> **OSS vs SaaS Mode**
|
||||
>
|
||||
> OpenHands runs in two modes:
|
||||
> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually
|
||||
> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration
|
||||
>
|
||||
> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows.
|
||||
|
||||
|
||||
## Writing Tests
|
||||
|
||||
### Service Layer Mocking (Recommended)
|
||||
|
||||
For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear.
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
llm_model: "openai/gpt-4o",
|
||||
llm_api_key_set: true,
|
||||
// ... other settings
|
||||
});
|
||||
```
|
||||
|
||||
Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios:
|
||||
|
||||
```typescript
|
||||
getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings"));
|
||||
```
|
||||
|
||||
### Network Layer Mocking (Advanced)
|
||||
|
||||
For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios.
|
||||
|
||||
```typescript
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { server } from "#/mocks/node";
|
||||
|
||||
it("should handle server errors", async () => {
|
||||
server.use(
|
||||
http.get("/api/my-endpoint", () => {
|
||||
return new HttpResponse(null, { status: 500 });
|
||||
}),
|
||||
);
|
||||
// ... test code
|
||||
});
|
||||
```
|
||||
|
||||
For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities.
|
||||
|
||||
## Adding New API Mocks
|
||||
|
||||
When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend:
|
||||
|
||||
### 1. Add to `src/mocks/` (for development)
|
||||
|
||||
Create or update a domain-specific handler file:
|
||||
|
||||
```typescript
|
||||
// src/mocks/my-feature-handlers.ts
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
export const MY_FEATURE_HANDLERS = [
|
||||
http.get("/api/my-feature", () => {
|
||||
return HttpResponse.json({
|
||||
data: "mock response",
|
||||
});
|
||||
}),
|
||||
];
|
||||
```
|
||||
|
||||
Register in `handlers.ts`:
|
||||
|
||||
```typescript
|
||||
import { MY_FEATURE_HANDLERS } from "./my-feature-handlers";
|
||||
|
||||
export const handlers = [
|
||||
// ... existing handlers
|
||||
...MY_FEATURE_HANDLERS,
|
||||
];
|
||||
```
|
||||
|
||||
### 2. Mock in tests for specific scenarios
|
||||
|
||||
In your test files, spy on the service method to control responses per test case:
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import MyFeatureService from "#/api/my-feature-service.api";
|
||||
|
||||
const spy = vi.spyOn(MyFeatureService, "getData");
|
||||
spy.mockResolvedValue({ data: "test-specific response" });
|
||||
```
|
||||
|
||||
See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking.
|
||||
|
||||
> [!TIP]
|
||||
> For guidance on creating service APIs, see `src/api/README.md`.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Keep mocks close to real API contracts** - Update mocks when backend changes
|
||||
- **Use service layer mocking for most tests** - It's simpler and more explicit
|
||||
- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc.
|
||||
- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`)
|
||||
@ -1,6 +1,7 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AuthModal } from "#/components/features/waitlist/auth-modal";
|
||||
|
||||
// Mock the useAuthUrl hook
|
||||
@ -27,11 +28,13 @@ describe("AuthModal", () => {
|
||||
|
||||
it("should render the GitHub and GitLab buttons", () => {
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@ -49,11 +52,13 @@ describe("AuthModal", () => {
|
||||
const user = userEvent.setup();
|
||||
const mockUrl = "https://github.com/login/oauth/authorize";
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@ -65,10 +70,14 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should render Terms of Service and Privacy Policy text with correct links", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Find the terms of service section using data-testid
|
||||
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
// Check that all text content is present in the paragraph
|
||||
@ -105,8 +114,44 @@ describe("AuthModal", () => {
|
||||
expect(termsSection).toContainElement(privacyLink);
|
||||
});
|
||||
|
||||
it("should display email verified message when emailVerified prop is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={true}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display email verified message when emailVerified prop is false", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={false}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should open Terms of Service link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
@ -115,11 +160,58 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should open Privacy Policy link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
|
||||
describe("Duplicate email error message", () => {
|
||||
const renderAuthModalWithRouter = (initialEntries: string[]) => {
|
||||
const hasDuplicatedEmail = initialEntries.includes(
|
||||
"/?duplicated_email=true",
|
||||
);
|
||||
|
||||
return render(
|
||||
<MemoryRouter initialEntries={initialEntries}>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasDuplicatedEmail={hasDuplicatedEmail}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
};
|
||||
|
||||
it("should display error message when duplicated_email query parameter is true", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/?duplicated_email=true"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display error message when duplicated_email query parameter is missing", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach } from "vitest";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
|
||||
describe("EmailVerificationModal", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render the email verification message", () => {
|
||||
// Arrange & Act
|
||||
render(<EmailVerificationModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the TermsAndPrivacyNotice component", () => {
|
||||
// Arrange & Act
|
||||
render(<EmailVerificationModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@ -0,0 +1,48 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect } from "vitest";
|
||||
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
|
||||
|
||||
describe("TermsAndPrivacyNotice", () => {
|
||||
it("should render Terms of Service and Privacy Policy links", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
});
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
|
||||
expect(tosLink).toBeInTheDocument();
|
||||
expect(tosLink).toHaveAttribute("href", "https://www.all-hands.dev/tos");
|
||||
expect(tosLink).toHaveAttribute("target", "_blank");
|
||||
expect(tosLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
|
||||
expect(privacyLink).toBeInTheDocument();
|
||||
expect(privacyLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://www.all-hands.dev/privacy",
|
||||
);
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
expect(privacyLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should render all required text content", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toHaveTextContent(
|
||||
"AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
|
||||
);
|
||||
expect(termsSection).toHaveTextContent("COMMON$TERMS_OF_SERVICE");
|
||||
expect(termsSection).toHaveTextContent("COMMON$AND");
|
||||
expect(termsSection).toHaveTextContent("COMMON$PRIVACY_POLICY");
|
||||
});
|
||||
});
|
||||
227
frontend/__tests__/router.md
Normal file
227
frontend/__tests__/router.md
Normal file
@ -0,0 +1,227 @@
|
||||
# Testing with React Router
|
||||
|
||||
## Overview
|
||||
|
||||
React Router components and hooks require a routing context to function. In tests, we need to provide this context while maintaining control over the routing state.
|
||||
|
||||
This guide covers the two main approaches used in the OpenHands frontend:
|
||||
|
||||
1. **`createRoutesStub`** - Creates a complete route structure for testing components with their actual route configuration, loaders, and nested routes.
|
||||
2. **`MemoryRouter`** - Provides a minimal routing context for components that just need router hooks to work.
|
||||
|
||||
Choose your approach based on what your component actually needs from the router.
|
||||
|
||||
## When to Use Each Approach
|
||||
|
||||
### `createRoutesStub` (Recommended)
|
||||
|
||||
Use `createRoutesStub` when your component:
|
||||
- Relies on route parameters (`useParams`)
|
||||
- Uses loader data (`useLoaderData`) or `clientLoader`
|
||||
- Has nested routes or uses `<Outlet />`
|
||||
- Needs to test navigation between routes
|
||||
|
||||
> [!NOTE]
|
||||
> `createRoutesStub` is intended for unit testing **reusable components** that depend on router context. For testing full route/page components, consider E2E tests (Playwright, Cypress) instead.
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyRouteComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
**With nested routes and loaders:**
|
||||
|
||||
```typescript
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
clientLoader,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="llm-settings" />,
|
||||
path: "/settings",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings" />,
|
||||
path: "/settings/integrations",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/settings/integrations"]} />);
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> When using `clientLoader` from a Route module, you may encounter type mismatches. Use `@ts-expect-error` as a workaround:
|
||||
|
||||
```typescript
|
||||
import { clientLoader } from "@/routes/settings";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
path: "/settings",
|
||||
Component: SettingsScreen,
|
||||
// @ts-expect-error: loader types won't align between test and app code
|
||||
loader: clientLoader,
|
||||
},
|
||||
]);
|
||||
```
|
||||
|
||||
### `MemoryRouter`
|
||||
|
||||
Use `MemoryRouter` when your component:
|
||||
- Only needs basic routing context to render
|
||||
- Uses `<Link>` components but you don't need to test navigation
|
||||
- Doesn't depend on specific route parameters or loaders
|
||||
|
||||
```typescript
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
**With initial route:**
|
||||
|
||||
```typescript
|
||||
render(
|
||||
<MemoryRouter initialEntries={["/some/path"]}>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
## Anti-patterns to Avoid
|
||||
|
||||
### Using `BrowserRouter` in tests
|
||||
|
||||
`BrowserRouter` interacts with the actual browser history API, which can cause issues in test environments:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid
|
||||
render(
|
||||
<BrowserRouter>
|
||||
<MyComponent />
|
||||
</BrowserRouter>
|
||||
);
|
||||
|
||||
// ✅ Use MemoryRouter instead
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
### Mocking router hooks when `createRoutesStub` would work
|
||||
|
||||
Mocking hooks like `useParams` directly can be brittle and doesn't test the actual routing behavior:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid when possible
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useParams: () => ({ conversationId: "123" }),
|
||||
};
|
||||
});
|
||||
|
||||
// ✅ Prefer createRoutesStub - tests real routing behavior
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Combining with `QueryClientProvider`
|
||||
|
||||
Many components need both routing and TanStack Query context:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
```
|
||||
|
||||
### Testing navigation behavior
|
||||
|
||||
Verify that user interactions trigger the expected navigation:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: HomeScreen,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="settings-screen" />,
|
||||
path: "/settings",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
const user = userEvent.setup();
|
||||
await user.click(screen.getByRole("link", { name: /settings/i }));
|
||||
|
||||
expect(screen.getByTestId("settings-screen")).toBeInTheDocument();
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
### Codebase Examples
|
||||
|
||||
- [settings.test.tsx](__tests__/routes/settings.test.tsx) - `createRoutesStub` with nested routes and loaders
|
||||
- [home-screen.test.tsx](__tests__/routes/home-screen.test.tsx) - `createRoutesStub` with navigation testing
|
||||
- [chat-interface.test.tsx](__tests__/components/chat/chat-interface.test.tsx) - `MemoryRouter` usage
|
||||
|
||||
### Official Documentation
|
||||
|
||||
- [React Router Testing Guide](https://reactrouter.com/start/framework/testing) - Official guide on testing with `createRoutesStub`
|
||||
- [MemoryRouter API](https://reactrouter.com/api/declarative-routers/MemoryRouter) - API reference for `MemoryRouter`
|
||||
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
@ -0,0 +1,242 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// Mock other hooks that are not the focus of these tests
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
useGitHubAuthUrl: () => "https://github.com/oauth/authorize",
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-is-on-tos-page", () => ({
|
||||
useIsOnTosPage: () => false,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auto-login", () => ({
|
||||
useAutoLogin: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auth-callback", () => ({
|
||||
useAuthCallback: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-migrate-user-consent", () => ({
|
||||
useMigrateUserConsent: () => ({
|
||||
migrateUserConsent: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-reo-tracking", () => ({
|
||||
useReoTracking: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-sync-posthog-consent", () => ({
|
||||
useSyncPostHogConsent: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content">Content</div>,
|
||||
path: "/",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
describe("MainApp - Email Verification Flow", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Default mocks for services
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
GITHUB_CLIENT_ID: "test-client-id",
|
||||
POSTHOG_CLIENT_KEY: "test-posthog-key",
|
||||
PROVIDERS_CONFIGURED: ["github"],
|
||||
AUTH_URL: "https://auth.example.com",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
},
|
||||
});
|
||||
|
||||
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
language: "en",
|
||||
user_consents_to_analytics: true,
|
||||
llm_model: "",
|
||||
llm_base_url: "",
|
||||
agent: "",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Mock localStorage
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should display EmailVerificationModal when email_verification_required=true is in query params", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should set emailVerified state and pass to AuthModal when email_verified=true is in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/?email_verified=true"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Assert - Wait for AuthModal to render (since user is not authenticated)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle both email_verification_required and email_verified params together", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub
|
||||
initialEntries={[
|
||||
"/?email_verification_required=true&email_verified=true",
|
||||
]}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - EmailVerificationModal should take precedence
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remove query parameters from URL after processing", async () => {
|
||||
// Arrange & Act
|
||||
const { container } = render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - Wait for the modal to appear (which indicates processing happened)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify that the query parameter was processed by checking the modal appeared
|
||||
// The hook removes the parameter from the URL, so we verify the behavior indirectly
|
||||
expect(container).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display EmailVerificationModal when email_verification_required is not in query params", async () => {
|
||||
// Arrange - No query params set
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not display email verified message when email_verified is not in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert - AuthModal should render but without email verified message
|
||||
await waitFor(() => {
|
||||
const authModal = screen.queryByText(
|
||||
"AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER",
|
||||
);
|
||||
if (authModal) {
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
102
frontend/src/api/README.md
Normal file
102
frontend/src/api/README.md
Normal file
@ -0,0 +1,102 @@
|
||||
# API Services Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Services are the abstraction layer between frontend components and backend APIs. They encapsulate HTTP requests using the shared `openHands` axios instance and provide typed methods for each endpoint.
|
||||
|
||||
Each service is a plain object with async methods.
|
||||
|
||||
## Structure
|
||||
|
||||
Each service lives in its own directory:
|
||||
|
||||
```
|
||||
src/api/
|
||||
├── billing-service/
|
||||
│ ├── billing-service.api.ts # Service methods
|
||||
│ └── billing.types.ts # Types and interfaces
|
||||
├── organization-service/
|
||||
│ ├── organization-service.api.ts
|
||||
│ └── organization.types.ts
|
||||
└── open-hands-axios.ts # Shared axios instance
|
||||
```
|
||||
|
||||
## Creating a Service
|
||||
|
||||
Use an object literal with named export. Use object destructuring for parameters to make calls self-documenting.
|
||||
|
||||
```typescript
|
||||
// feature-service/feature-service.api.ts
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { Feature, CreateFeatureParams } from "./feature.types";
|
||||
|
||||
export const featureService = {
|
||||
getFeature: async ({ id }: { id: string }) => {
|
||||
const { data } = await openHands.get<Feature>(`/api/features/${id}`);
|
||||
return data;
|
||||
},
|
||||
|
||||
createFeature: async ({ name, description }: CreateFeatureParams) => {
|
||||
const { data } = await openHands.post<Feature>("/api/features", {
|
||||
name,
|
||||
description,
|
||||
});
|
||||
return data;
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
### Types
|
||||
|
||||
Define types in a separate file within the same directory:
|
||||
|
||||
```typescript
|
||||
// feature-service/feature.types.ts
|
||||
export interface Feature {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface CreateFeatureParams {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Don't call services directly in components.** Wrap them in TanStack Query hooks.
|
||||
>
|
||||
> Why? TanStack Query provides:
|
||||
> - **Caching** - Avoid redundant network requests
|
||||
> - **Deduplication** - Multiple components requesting the same data share one request
|
||||
> - **Loading/error states** - Built-in `isLoading`, `isError`, `data` states
|
||||
> - **Background refetching** - Data stays fresh automatically
|
||||
>
|
||||
> Hooks location:
|
||||
> - `src/hooks/query/` for data fetching (`useQuery`)
|
||||
> - `src/hooks/mutation/` for writes/updates (`useMutation`)
|
||||
|
||||
```typescript
|
||||
// src/hooks/query/use-feature.ts
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { featureService } from "#/api/feature-service/feature-service.api";
|
||||
|
||||
export const useFeature = (id: string) => {
|
||||
return useQuery({
|
||||
queryKey: ["feature", id],
|
||||
queryFn: () => featureService.getFeature({ id }),
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
## Naming Conventions
|
||||
|
||||
| Item | Convention | Example |
|
||||
|------|------------|---------|
|
||||
| Directory | `feature-service/` | `billing-service/` |
|
||||
| Service file | `feature-service.api.ts` | `billing-service.api.ts` |
|
||||
| Types file | `feature.types.ts` | `billing.types.ts` |
|
||||
| Export name | `featureService` | `billingService` |
|
||||
@ -13,12 +13,15 @@ import { useAuthUrl } from "#/hooks/use-auth-url";
|
||||
import { GetConfigResponse } from "#/api/option-service/option.types";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
|
||||
|
||||
interface AuthModalProps {
|
||||
githubAuthUrl: string | null;
|
||||
appMode?: GetConfigResponse["APP_MODE"] | null;
|
||||
authUrl?: GetConfigResponse["AUTH_URL"];
|
||||
providersConfigured?: Provider[];
|
||||
emailVerified?: boolean;
|
||||
hasDuplicatedEmail?: boolean;
|
||||
}
|
||||
|
||||
export function AuthModal({
|
||||
@ -26,6 +29,8 @@ export function AuthModal({
|
||||
appMode,
|
||||
authUrl,
|
||||
providersConfigured,
|
||||
emailVerified = false,
|
||||
hasDuplicatedEmail = false,
|
||||
}: AuthModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { trackLoginButtonClick } = useTracking();
|
||||
@ -123,6 +128,18 @@ export function AuthModal({
|
||||
<ModalBackdrop>
|
||||
<ModalBody className="border border-tertiary">
|
||||
<OpenHandsLogo width={68} height={46} />
|
||||
{emailVerified && (
|
||||
<div className="flex flex-col gap-2 w-full items-center text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
{hasDuplicatedEmail && (
|
||||
<div className="text-center text-danger text-sm mt-2 mb-2">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-col gap-2 w-full items-center text-center">
|
||||
<h1 className="text-2xl font-bold">
|
||||
{t(I18nKey.AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER)}
|
||||
@ -198,30 +215,7 @@ export function AuthModal({
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p
|
||||
className="mt-4 text-xs text-center text-muted-foreground"
|
||||
data-testid="auth-modal-terms-of-service"
|
||||
>
|
||||
{t(I18nKey.AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR)}{" "}
|
||||
<a
|
||||
href="https://www.all-hands.dev/tos"
|
||||
target="_blank"
|
||||
className="underline hover:text-primary"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
{t(I18nKey.COMMON$TERMS_OF_SERVICE)}
|
||||
</a>{" "}
|
||||
{t(I18nKey.COMMON$AND)}{" "}
|
||||
<a
|
||||
href="https://www.all-hands.dev/privacy"
|
||||
target="_blank"
|
||||
className="underline hover:text-primary"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
{t(I18nKey.COMMON$PRIVACY_POLICY)}
|
||||
</a>
|
||||
.
|
||||
</p>
|
||||
<TermsAndPrivacyNotice />
|
||||
</ModalBody>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
|
||||
|
||||
interface EmailVerificationModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function EmailVerificationModal({
|
||||
onClose,
|
||||
}: EmailVerificationModalProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<ModalBackdrop onClose={onClose}>
|
||||
<ModalBody className="border border-tertiary">
|
||||
<OpenHandsLogo width={68} height={46} />
|
||||
<div className="flex flex-col gap-2 w-full items-center text-center">
|
||||
<h1 className="text-2xl font-bold">
|
||||
{t(I18nKey.AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY)}
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
<TermsAndPrivacyNotice />
|
||||
</ModalBody>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
}
|
||||
37
frontend/src/components/shared/terms-and-privacy-notice.tsx
Normal file
37
frontend/src/components/shared/terms-and-privacy-notice.tsx
Normal file
@ -0,0 +1,37 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
interface TermsAndPrivacyNoticeProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function TermsAndPrivacyNotice({
|
||||
className = "mt-4 text-xs text-center text-muted-foreground",
|
||||
}: TermsAndPrivacyNoticeProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<p className={className} data-testid="terms-and-privacy-notice">
|
||||
{t(I18nKey.AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR)}{" "}
|
||||
<a
|
||||
href="https://www.all-hands.dev/tos"
|
||||
target="_blank"
|
||||
className="underline hover:text-primary"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
{t(I18nKey.COMMON$TERMS_OF_SERVICE)}
|
||||
</a>{" "}
|
||||
{t(I18nKey.COMMON$AND)}{" "}
|
||||
<a
|
||||
href="https://www.all-hands.dev/privacy"
|
||||
target="_blank"
|
||||
className="underline hover:text-primary"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
{t(I18nKey.COMMON$PRIVACY_POLICY)}
|
||||
</a>
|
||||
.
|
||||
</p>
|
||||
);
|
||||
}
|
||||
63
frontend/src/hooks/use-email-verification.ts
Normal file
63
frontend/src/hooks/use-email-verification.ts
Normal file
@ -0,0 +1,63 @@
|
||||
import React from "react";
|
||||
import { useSearchParams } from "react-router";
|
||||
|
||||
/**
|
||||
* Hook to handle email verification logic from URL query parameters.
|
||||
* Manages the email verification modal state and email verified state
|
||||
* based on query parameters in the URL.
|
||||
*
|
||||
* @returns An object containing:
|
||||
* - emailVerificationModalOpen: boolean state for modal visibility
|
||||
* - setEmailVerificationModalOpen: function to control modal visibility
|
||||
* - emailVerified: boolean state for email verification status
|
||||
* - setEmailVerified: function to control email verification status
|
||||
* - hasDuplicatedEmail: boolean state for duplicate email error status
|
||||
*/
|
||||
export function useEmailVerification() {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [emailVerificationModalOpen, setEmailVerificationModalOpen] =
|
||||
React.useState(false);
|
||||
const [emailVerified, setEmailVerified] = React.useState(false);
|
||||
const [hasDuplicatedEmail, setHasDuplicatedEmail] = React.useState(false);
|
||||
|
||||
// Check for email verification query parameters
|
||||
React.useEffect(() => {
|
||||
const emailVerificationRequired = searchParams.get(
|
||||
"email_verification_required",
|
||||
);
|
||||
const emailVerifiedParam = searchParams.get("email_verified");
|
||||
const duplicatedEmailParam = searchParams.get("duplicated_email");
|
||||
let shouldUpdate = false;
|
||||
|
||||
if (emailVerificationRequired === "true") {
|
||||
setEmailVerificationModalOpen(true);
|
||||
searchParams.delete("email_verification_required");
|
||||
shouldUpdate = true;
|
||||
}
|
||||
|
||||
if (emailVerifiedParam === "true") {
|
||||
setEmailVerified(true);
|
||||
searchParams.delete("email_verified");
|
||||
shouldUpdate = true;
|
||||
}
|
||||
|
||||
if (duplicatedEmailParam === "true") {
|
||||
setHasDuplicatedEmail(true);
|
||||
searchParams.delete("duplicated_email");
|
||||
shouldUpdate = true;
|
||||
}
|
||||
|
||||
// Clean up the URL by removing parameters if any were found
|
||||
if (shouldUpdate) {
|
||||
setSearchParams(searchParams, { replace: true });
|
||||
}
|
||||
}, [searchParams, setSearchParams]);
|
||||
|
||||
return {
|
||||
emailVerificationModalOpen,
|
||||
setEmailVerificationModalOpen,
|
||||
emailVerified,
|
||||
setEmailVerified,
|
||||
hasDuplicatedEmail,
|
||||
};
|
||||
}
|
||||
@ -730,6 +730,9 @@ export enum I18nKey {
|
||||
MICROAGENT_MANAGEMENT$USE_MICROAGENTS = "MICROAGENT_MANAGEMENT$USE_MICROAGENTS",
|
||||
AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR = "AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
|
||||
AUTH$NO_PROVIDERS_CONFIGURED = "AUTH$NO_PROVIDERS_CONFIGURED",
|
||||
AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY",
|
||||
AUTH$EMAIL_VERIFIED_PLEASE_LOGIN = "AUTH$EMAIL_VERIFIED_PLEASE_LOGIN",
|
||||
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
|
||||
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
|
||||
COMMON$AND = "COMMON$AND",
|
||||
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",
|
||||
|
||||
@ -11679,6 +11679,54 @@
|
||||
"de": "Mindestens ein Identitätsanbieter muss konfiguriert werden (z.B. GitHub)",
|
||||
"uk": "Принаймні один постачальник ідентифікації має бути налаштований (наприклад, GitHub)"
|
||||
},
|
||||
"AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY": {
|
||||
"en": "Please check your email to verify your account.",
|
||||
"ja": "アカウントを確認するためにメールを確認してください。",
|
||||
"zh-CN": "请检查您的电子邮件以验证您的账户。",
|
||||
"zh-TW": "請檢查您的電子郵件以驗證您的帳戶。",
|
||||
"ko-KR": "계정을 확인하려면 이메일을 확인하세요.",
|
||||
"no": "Vennligst sjekk e-posten din for å bekrefte kontoen din.",
|
||||
"it": "Controlla la tua email per verificare il tuo account.",
|
||||
"pt": "Por favor, verifique seu e-mail para verificar sua conta.",
|
||||
"es": "Por favor, verifica tu correo electrónico para verificar tu cuenta.",
|
||||
"ar": "يرجى التحقق من بريدك الإلكتروني للتحقق من حسابك.",
|
||||
"fr": "Veuillez vérifier votre e-mail pour vérifier votre compte.",
|
||||
"tr": "Hesabınızı doğrulamak için lütfen e-postanızı kontrol edin.",
|
||||
"de": "Bitte überprüfen Sie Ihre E-Mail, um Ihr Konto zu verifizieren.",
|
||||
"uk": "Будь ласка, перевірте свою електронну пошту, щоб підтвердити свій обліковий запис."
|
||||
},
|
||||
"AUTH$EMAIL_VERIFIED_PLEASE_LOGIN": {
|
||||
"en": "Your email has been verified. Please login below.",
|
||||
"ja": "メールアドレスが確認されました。下記からログインしてください。",
|
||||
"zh-CN": "您的电子邮件已验证。请在下方登录。",
|
||||
"zh-TW": "您的電子郵件已驗證。請在下方登錄。",
|
||||
"ko-KR": "이메일이 확인되었습니다. 아래에서 로그인하세요.",
|
||||
"no": "E-posten din er bekreftet. Vennligst logg inn nedenfor.",
|
||||
"it": "La tua email è stata verificata. Effettua il login qui sotto.",
|
||||
"pt": "Seu e-mail foi verificado. Por favor, faça login abaixo.",
|
||||
"es": "Tu correo electrónico ha sido verificado. Por favor, inicia sesión a continuación.",
|
||||
"ar": "تم التحقق من بريدك الإلكتروني. يرجى تسجيل الدخول أدناه.",
|
||||
"fr": "Votre e-mail a été vérifié. Veuillez vous connecter ci-dessous.",
|
||||
"tr": "E-postanız doğrulandı. Lütfen aşağıdan giriş yapın.",
|
||||
"de": "Ihre E-Mail wurde verifiziert. Bitte melden Sie sich unten an.",
|
||||
"uk": "Вашу електронну пошту підтверджено. Будь ласка, увійдіть нижче."
|
||||
},
|
||||
"AUTH$DUPLICATE_EMAIL_ERROR": {
|
||||
"en": "Your account is unable to be created. Please use a different login or try again.",
|
||||
"ja": "アカウントを作成できません。別のログインを使用するか、もう一度お試しください。",
|
||||
"zh-CN": "无法创建您的账户。请使用其他登录方式或重试。",
|
||||
"zh-TW": "無法建立您的帳戶。請使用其他登入方式或重試。",
|
||||
"ko-KR": "계정을 생성할 수 없습니다. 다른 로그인을 사용하거나 다시 시도해 주세요.",
|
||||
"no": "Kontoen din kan ikke opprettes. Vennligst bruk en annen innlogging eller prøv igjen.",
|
||||
"it": "Impossibile creare il tuo account. Utilizza un altro accesso o riprova.",
|
||||
"pt": "Não foi possível criar sua conta. Use um login diferente ou tente novamente.",
|
||||
"es": "No se puede crear su cuenta. Utilice un inicio de sesión diferente o inténtelo de nuevo.",
|
||||
"ar": "لا يمكن إنشاء حسابك. يرجى استخدام تسجيل دخول مختلف أو المحاولة مرة أخرى.",
|
||||
"fr": "Votre compte ne peut pas être créé. Veuillez utiliser une autre connexion ou réessayer.",
|
||||
"tr": "Hesabınız oluşturulamadı. Lütfen farklı bir giriş kullanın veya tekrar deneyin.",
|
||||
"de": "Ihr Konto kann nicht erstellt werden. Bitte verwenden Sie eine andere Anmeldung oder versuchen Sie es erneut.",
|
||||
"uk": "Ваш обліковий запис не може бути створений. Будь ласка, використовуйте інший спосіб входу або спробуйте ще раз."
|
||||
},
|
||||
"COMMON$TERMS_OF_SERVICE": {
|
||||
"en": "Terms of Service",
|
||||
"ja": "利用規約",
|
||||
|
||||
@ -15,6 +15,7 @@ import { useConfig } from "#/hooks/query/use-config";
|
||||
import { Sidebar } from "#/components/features/sidebar/sidebar";
|
||||
import { AuthModal } from "#/components/features/waitlist/auth-modal";
|
||||
import { ReauthModal } from "#/components/features/waitlist/reauth-modal";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useMigrateUserConsent } from "#/hooks/use-migrate-user-consent";
|
||||
@ -26,6 +27,7 @@ import { useAutoLogin } from "#/hooks/use-auto-login";
|
||||
import { useAuthCallback } from "#/hooks/use-auth-callback";
|
||||
import { useReoTracking } from "#/hooks/use-reo-tracking";
|
||||
import { useSyncPostHogConsent } from "#/hooks/use-sync-posthog-consent";
|
||||
import { useEmailVerification } from "#/hooks/use-email-verification";
|
||||
import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage";
|
||||
import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard";
|
||||
import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner";
|
||||
@ -91,6 +93,12 @@ export default function MainApp() {
|
||||
const effectiveGitHubAuthUrl = isOnTosPage ? null : gitHubAuthUrl;
|
||||
|
||||
const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false);
|
||||
const {
|
||||
emailVerificationModalOpen,
|
||||
setEmailVerificationModalOpen,
|
||||
emailVerified,
|
||||
hasDuplicatedEmail,
|
||||
} = useEmailVerification();
|
||||
|
||||
// Auto-login if login method is stored in local storage
|
||||
useAutoLogin();
|
||||
@ -236,9 +244,18 @@ export default function MainApp() {
|
||||
appMode={config.data?.APP_MODE}
|
||||
providersConfigured={config.data?.PROVIDERS_CONFIGURED}
|
||||
authUrl={config.data?.AUTH_URL}
|
||||
emailVerified={emailVerified}
|
||||
hasDuplicatedEmail={hasDuplicatedEmail}
|
||||
/>
|
||||
)}
|
||||
{renderReAuthModal && <ReauthModal />}
|
||||
{emailVerificationModalOpen && (
|
||||
<EmailVerificationModal
|
||||
onClose={() => {
|
||||
setEmailVerificationModalOpen(false);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{config.data?.APP_MODE === "oss" && consentFormIsOpen && (
|
||||
<AnalyticsConsentFormModal
|
||||
onClose={() => {
|
||||
|
||||
@ -187,7 +187,7 @@ class RemoteSandboxService(SandboxService):
|
||||
return SandboxStatus.MISSING
|
||||
|
||||
status = None
|
||||
pod_status = runtime['pod_status'].lower()
|
||||
pod_status = (runtime.get('pod_status') or '').lower()
|
||||
if pod_status:
|
||||
status = POD_STATUS_MAPPING.get(pod_status, None)
|
||||
|
||||
@ -356,7 +356,7 @@ class RemoteSandboxService(SandboxService):
|
||||
StoredRemoteSandbox.id == runtime.get('session_id')
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
sandbox = result.first()
|
||||
sandbox = result.scalar_one_or_none()
|
||||
if sandbox is None:
|
||||
raise ValueError('sandbox_not_found')
|
||||
return self._to_sandbox_info(sandbox, runtime)
|
||||
|
||||
@ -21,6 +21,7 @@ from litellm import completion as litellm_completion
|
||||
from litellm import completion_cost as litellm_completion_cost
|
||||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
BadGatewayError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
@ -45,6 +46,7 @@ LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
BadGatewayError,
|
||||
litellm.Timeout,
|
||||
litellm.InternalServerError,
|
||||
LLMNoResponseError,
|
||||
|
||||
@ -510,6 +510,10 @@ async def delete_conversation(
|
||||
if v1_result is not None:
|
||||
return v1_result
|
||||
|
||||
# Close connections
|
||||
await db_session.close()
|
||||
await httpx_client.aclose()
|
||||
|
||||
# V0 conversation logic
|
||||
return await _delete_v0_conversation(conversation_id, user_id)
|
||||
|
||||
@ -551,11 +555,8 @@ async def _try_delete_v1_conversation(
|
||||
httpx_client,
|
||||
)
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
# Not a valid UUID, continue with V0 logic
|
||||
pass
|
||||
except Exception:
|
||||
# Some other error, continue with V0 logic
|
||||
# Continue with V0 logic
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from litellm.exceptions import APIConnectionError
|
||||
from litellm.exceptions import APIConnectionError, BadGatewayError
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.llm.llm import LLM
|
||||
@ -86,3 +86,25 @@ def test_completion_max_retries_api_connection_error(
|
||||
# The exception doesn't contain retry information in the current implementation
|
||||
# Just verify that we got an APIConnectionError
|
||||
assert 'API connection error' in str(excinfo.value)
|
||||
|
||||
|
||||
@patch('openhands.llm.llm.litellm_completion')
|
||||
def test_completion_retries_bad_gateway_error(mock_litellm_completion, default_config):
|
||||
"""Test that BadGatewayError is properly retried."""
|
||||
mock_litellm_completion.side_effect = [
|
||||
BadGatewayError(
|
||||
message='Bad gateway',
|
||||
llm_provider='test_provider',
|
||||
model='test_model',
|
||||
),
|
||||
{'choices': [{'message': {'content': 'Retry successful'}}]},
|
||||
]
|
||||
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response['choices'][0]['message']['content'] == 'Retry successful'
|
||||
assert mock_litellm_completion.call_count == 2
|
||||
|
||||
@ -946,6 +946,10 @@ async def test_delete_conversation():
|
||||
# Create a mock sandbox service
|
||||
mock_sandbox_service = MagicMock()
|
||||
|
||||
# Create mock db_session and httpx_client
|
||||
mock_db_session = AsyncMock()
|
||||
mock_httpx_client = AsyncMock()
|
||||
|
||||
# Mock the conversation manager
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.conversation_manager'
|
||||
@ -969,6 +973,8 @@ async def test_delete_conversation():
|
||||
app_conversation_service=mock_app_conversation_service,
|
||||
app_conversation_info_service=mock_app_conversation_info_service,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
db_session=mock_db_session,
|
||||
httpx_client=mock_httpx_client,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
@ -1090,6 +1096,10 @@ async def test_delete_v1_conversation_not_found():
|
||||
)
|
||||
mock_service.delete_app_conversation = AsyncMock(return_value=False)
|
||||
|
||||
# Create mock db_session and httpx_client
|
||||
mock_db_session = AsyncMock()
|
||||
mock_httpx_client = AsyncMock()
|
||||
|
||||
# Call delete_conversation with V1 conversation ID
|
||||
result = await delete_conversation(
|
||||
request=MagicMock(),
|
||||
@ -1098,6 +1108,8 @@ async def test_delete_v1_conversation_not_found():
|
||||
app_conversation_service=mock_service,
|
||||
app_conversation_info_service=mock_info_service,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
db_session=mock_db_session,
|
||||
httpx_client=mock_httpx_client,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
@ -1171,6 +1183,10 @@ async def test_delete_v1_conversation_invalid_uuid():
|
||||
mock_sandbox_service = MagicMock()
|
||||
mock_sandbox_service_dep.return_value = mock_sandbox_service
|
||||
|
||||
# Create mock db_session and httpx_client
|
||||
mock_db_session = AsyncMock()
|
||||
mock_httpx_client = AsyncMock()
|
||||
|
||||
# Call delete_conversation
|
||||
result = await delete_conversation(
|
||||
request=MagicMock(),
|
||||
@ -1179,6 +1195,8 @@ async def test_delete_v1_conversation_invalid_uuid():
|
||||
app_conversation_service=mock_service,
|
||||
app_conversation_info_service=mock_info_service,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
db_session=mock_db_session,
|
||||
httpx_client=mock_httpx_client,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
@ -1264,6 +1282,10 @@ async def test_delete_v1_conversation_service_error():
|
||||
mock_runtime_cls.delete = AsyncMock()
|
||||
mock_get_runtime_cls.return_value = mock_runtime_cls
|
||||
|
||||
# Create mock db_session and httpx_client
|
||||
mock_db_session = AsyncMock()
|
||||
mock_httpx_client = AsyncMock()
|
||||
|
||||
# Call delete_conversation
|
||||
result = await delete_conversation(
|
||||
request=MagicMock(),
|
||||
@ -1272,6 +1294,8 @@ async def test_delete_v1_conversation_service_error():
|
||||
app_conversation_service=mock_service,
|
||||
app_conversation_info_service=mock_info_service,
|
||||
sandbox_service=mock_sandbox_service,
|
||||
db_session=mock_db_session,
|
||||
httpx_client=mock_httpx_client,
|
||||
)
|
||||
|
||||
# Verify the result (should fallback to V0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user