mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix(backend): user email capture (#12902)
Co-authored-by: OpenHands Bot <contact@all-hands.dev> Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -208,6 +208,7 @@ async def keycloak_callback(
|
||||
else:
|
||||
# Existing user — gradually backfill contact_name if it still has a username-style value
|
||||
await UserStore.backfill_contact_name(user_id, user_info)
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
|
||||
@@ -8,6 +8,7 @@ from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -62,6 +63,10 @@ async def update_email(
|
||||
},
|
||||
)
|
||||
|
||||
await UserStore.update_user_email(
|
||||
user_id=user_id, email=email, email_verified=False
|
||||
)
|
||||
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email = email
|
||||
@@ -144,6 +149,7 @@ async def verified_email(request: Request):
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email_verified = True
|
||||
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
|
||||
response = RedirectResponse(redirect_uri, status_code=302)
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.user_store import UserStore
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
@@ -115,13 +116,21 @@ async def saas_get_user(
|
||||
content='Failed to retrieve user_info.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
# Prefer email from DB; fall back to Keycloak if not yet persisted
|
||||
email = user_info.get('email') if user_info else None
|
||||
sub = user_info.get('sub') if user_info else ''
|
||||
if sub:
|
||||
db_user = await UserStore.get_user_by_id_async(sub)
|
||||
if db_user and db_user.email is not None:
|
||||
email = db_user.email
|
||||
|
||||
retval = await _check_idp(
|
||||
access_token=access_token,
|
||||
default_value=User(
|
||||
id=(user_info.get('sub') if user_info else '') or '',
|
||||
id=sub,
|
||||
login=(user_info.get('preferred_username') if user_info else '') or '',
|
||||
avatar_url='',
|
||||
email=user_info.get('email') if user_info else None,
|
||||
email=email,
|
||||
name=resolve_display_name(user_info) if user_info else None,
|
||||
company=user_info.get('company') if user_info else None,
|
||||
),
|
||||
|
||||
@@ -869,6 +869,88 @@ class UserStore:
|
||||
org.contact_name = real_name
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def update_user_email(
|
||||
user_id: str,
|
||||
email: str | None = None,
|
||||
email_verified: bool | None = None,
|
||||
) -> None:
|
||||
"""Unconditionally update User.email and/or email_verified.
|
||||
|
||||
Unlike backfill_user_email(), this overwrites existing values.
|
||||
No-op when both arguments are None.
|
||||
Missing user is logged as a warning and ignored.
|
||||
"""
|
||||
if email is None and email_verified is None:
|
||||
return
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.warning(
|
||||
'update_user_email:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if email_verified is not None:
|
||||
user.email_verified = email_verified
|
||||
|
||||
logger.info(
|
||||
'update_user_email:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'email_set': email is not None,
|
||||
'email_verified_set': email_verified is not None,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def backfill_user_email(user_id: str, user_info: dict) -> None:
|
||||
"""Set User.email and email_verified from IDP if they are still NULL.
|
||||
|
||||
Called during login to gradually fix existing users whose email
|
||||
was never persisted on the User record. Preserves non-NULL values
|
||||
(e.g. if a user manually changed their email).
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.debug(
|
||||
'backfill_user_email:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
updated = False
|
||||
if user.email is None:
|
||||
user.email = user_info.get('email')
|
||||
updated = True
|
||||
|
||||
if user.email_verified is None:
|
||||
user.email_verified = user_info.get('email_verified', False)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
logger.info(
|
||||
'backfill_user_email:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'email_set': user.email is not None,
|
||||
'email_verified_set': user.email_verified is not None,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import (
|
||||
EmailUpdate,
|
||||
ResendEmailVerificationRequest,
|
||||
resend_email_verification,
|
||||
update_email,
|
||||
verified_email,
|
||||
verify_email,
|
||||
)
|
||||
@@ -116,12 +118,15 @@ async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.email.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
@@ -140,12 +145,15 @@ async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.email.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
@@ -327,6 +335,62 @@ async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
|
||||
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_email_calls_update_user_email(mock_request, mock_user_auth):
|
||||
"""POST /api/email should call UserStore.update_user_email with new email and email_verified=False."""
|
||||
user_id = 'test-user-id'
|
||||
new_email = 'new@example.com'
|
||||
email_data = EmailUpdate(email=new_email)
|
||||
|
||||
mock_keycloak_admin = MagicMock()
|
||||
mock_keycloak_admin.get_user.return_value = {
|
||||
'enabled': True,
|
||||
'username': 'testuser',
|
||||
}
|
||||
mock_keycloak_admin.a_update_user = AsyncMock()
|
||||
mock_user_store = MagicMock()
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.email.UserStore', mock_user_store),
|
||||
):
|
||||
result = await update_email(
|
||||
email_data=email_data, request=mock_request, user_id=user_id
|
||||
)
|
||||
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_user_store.update_user_email.assert_awaited_once_with(
|
||||
user_id=user_id, email=new_email, email_verified=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_calls_update_user_email(mock_request, mock_user_auth):
|
||||
"""GET /api/email/verified should call UserStore.update_user_email with email_verified=True."""
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
mock_user_store = MagicMock()
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie'),
|
||||
patch('server.routes.email.UserStore', mock_user_store),
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
assert result.status_code == 302
|
||||
mock_user_store.update_user_email.assert_awaited_once_with(
|
||||
user_id='test-user-id', email_verified=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_body_none_uses_auth(mock_request):
|
||||
"""Test resend_email_verification uses auth when body is None."""
|
||||
|
||||
@@ -154,6 +154,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
@@ -190,6 +191,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -262,6 +264,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -310,6 +313,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -352,6 +356,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -587,6 +592,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
@@ -651,6 +657,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
@@ -715,6 +722,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
@@ -777,6 +785,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
@@ -823,6 +832,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -868,6 +878,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -926,6 +937,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -984,6 +996,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1045,6 +1058,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1202,6 +1216,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1267,6 +1282,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
@@ -1350,6 +1366,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1438,6 +1455,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1523,6 +1541,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1607,6 +1626,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1688,6 +1708,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1755,6 +1776,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1828,6 +1850,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1899,6 +1922,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
@@ -1918,3 +1942,57 @@ class TestKeycloakCallbackRecaptcha:
|
||||
assert call_kwargs[0][0] == 'recaptcha_blocked_at_callback'
|
||||
assert call_kwargs[1]['extra']['score'] == 0.2
|
||||
assert call_kwargs[1]['extra']['user_id'] == 'test_user_id'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_calls_backfill_user_email_for_existing_user(
|
||||
mock_request,
|
||||
):
|
||||
"""When an existing user logs in, backfill_user_email should be called."""
|
||||
user_info = {
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email': 'test@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog'),
|
||||
):
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(return_value=user_info)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
# backfill_user_email should have been called with the user_id and user_info
|
||||
mock_user_store.backfill_user_email.assert_called_once_with(
|
||||
'test_user_id', user_info
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ the endpoint constructs a User from OIDC claims. These tests verify that name an
|
||||
fields are correctly populated from Keycloak claims in this fallback path.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -33,9 +33,20 @@ def mock_check_idp():
|
||||
yield mock_fn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id_async to return None by default."""
|
||||
with patch(
|
||||
'server.routes.user.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
) as mock_fn:
|
||||
yield mock_fn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_includes_name_from_name_claim(
|
||||
mock_token_manager, mock_check_idp
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When Keycloak provides a 'name' claim, the fallback User should include it."""
|
||||
from server.routes.user import saas_get_user
|
||||
@@ -62,7 +73,7 @@ async def test_fallback_user_includes_name_from_name_claim(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_combines_given_and_family_name(
|
||||
mock_token_manager, mock_check_idp
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When 'name' is absent, combine given_name + family_name."""
|
||||
from server.routes.user import saas_get_user
|
||||
@@ -89,7 +100,7 @@ async def test_fallback_user_combines_given_and_family_name(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_name_is_none_when_no_name_claims(
|
||||
mock_token_manager, mock_check_idp
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When no name claims exist, name should be None."""
|
||||
from server.routes.user import saas_get_user
|
||||
@@ -113,7 +124,9 @@ async def test_fallback_user_name_is_none_when_no_name_claims(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_includes_company_claim(mock_token_manager, mock_check_idp):
|
||||
async def test_fallback_user_includes_company_claim(
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When Keycloak provides a 'company' claim, include it in the User."""
|
||||
from server.routes.user import saas_get_user
|
||||
|
||||
@@ -139,7 +152,7 @@ async def test_fallback_user_includes_company_claim(mock_token_manager, mock_che
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_company_is_none_when_absent(
|
||||
mock_token_manager, mock_check_idp
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When 'company' is not in Keycloak claims, company should be None."""
|
||||
from server.routes.user import saas_get_user
|
||||
@@ -161,3 +174,88 @@ async def test_fallback_user_company_is_none_when_absent(
|
||||
|
||||
assert isinstance(result, User)
|
||||
assert result.company is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_email_from_db_when_available(
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When User.email is stored in DB, use it instead of Keycloak's live email."""
|
||||
from server.routes.user import saas_get_user
|
||||
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': '248289761001',
|
||||
'preferred_username': 'j.doe',
|
||||
'email': 'keycloak@example.com',
|
||||
}
|
||||
)
|
||||
|
||||
mock_db_user = MagicMock()
|
||||
mock_db_user.email = 'db@example.com'
|
||||
mock_user_store.return_value = mock_db_user
|
||||
|
||||
result = await saas_get_user(
|
||||
provider_tokens=None,
|
||||
access_token=SecretStr('test-token'),
|
||||
user_id='248289761001',
|
||||
)
|
||||
|
||||
assert isinstance(result, User)
|
||||
assert result.email == 'db@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_email_falls_back_to_keycloak_when_db_null(
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When User.email is NULL in DB, fall back to Keycloak's email."""
|
||||
from server.routes.user import saas_get_user
|
||||
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': '248289761001',
|
||||
'preferred_username': 'j.doe',
|
||||
'email': 'keycloak@example.com',
|
||||
}
|
||||
)
|
||||
|
||||
mock_db_user = MagicMock()
|
||||
mock_db_user.email = None
|
||||
mock_user_store.return_value = mock_db_user
|
||||
|
||||
result = await saas_get_user(
|
||||
provider_tokens=None,
|
||||
access_token=SecretStr('test-token'),
|
||||
user_id='248289761001',
|
||||
)
|
||||
|
||||
assert isinstance(result, User)
|
||||
assert result.email == 'keycloak@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_user_email_falls_back_to_keycloak_when_no_db_user(
|
||||
mock_token_manager, mock_check_idp, mock_user_store
|
||||
):
|
||||
"""When DB user doesn't exist, fall back to Keycloak's email."""
|
||||
from server.routes.user import saas_get_user
|
||||
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': '248289761001',
|
||||
'preferred_username': 'j.doe',
|
||||
'email': 'keycloak@example.com',
|
||||
}
|
||||
)
|
||||
|
||||
# mock_user_store already returns None by default
|
||||
|
||||
result = await saas_get_user(
|
||||
provider_tokens=None,
|
||||
access_token=SecretStr('test-token'),
|
||||
user_id='248289761001',
|
||||
)
|
||||
|
||||
assert isinstance(result, User)
|
||||
assert result.email == 'keycloak@example.com'
|
||||
|
||||
@@ -639,6 +639,204 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
|
||||
assert org.contact_name == 'Custom Corp Name'
|
||||
|
||||
|
||||
# --- Tests for backfill_user_email on login ---
|
||||
# Existing users created before the email capture fix may have NULL
|
||||
# email in the User table. The backfill sets User.email from the IDP
|
||||
# when the user next logs in, but preserves manual changes (non-NULL).
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_user_email_sets_email_when_null(session_maker):
|
||||
"""When User.email is NULL, backfill_user_email should set it from user_info."""
|
||||
user_id = str(uuid.uuid4())
|
||||
with session_maker() as session:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_email='jdoe@example.com',
|
||||
)
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
email=None,
|
||||
email_verified=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
user_info = {
|
||||
'email': 'jdoe@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
assert user.email == 'jdoe@example.com'
|
||||
assert user.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_user_email_does_not_overwrite_existing(session_maker):
|
||||
"""When User.email is already set, backfill_user_email should NOT overwrite it."""
|
||||
user_id = str(uuid.uuid4())
|
||||
with session_maker() as session:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_email='original@example.com',
|
||||
)
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
email='custom@example.com',
|
||||
email_verified=True,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
user_info = {
|
||||
'email': 'different@example.com',
|
||||
'email_verified': False,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
assert user.email == 'custom@example.com'
|
||||
assert user.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_user_email_sets_verified_when_null(session_maker):
|
||||
"""When User.email is set but email_verified is NULL, backfill should set email_verified."""
|
||||
user_id = str(uuid.uuid4())
|
||||
with session_maker() as session:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_email='jdoe@example.com',
|
||||
)
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
email='jdoe@example.com',
|
||||
email_verified=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
user_info = {
|
||||
'email': 'different@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
# email should NOT be overwritten since it's non-NULL
|
||||
assert user.email == 'jdoe@example.com'
|
||||
# email_verified should be set since it was NULL
|
||||
assert user.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_sets_email_verified_false_from_user_info():
|
||||
"""When user_info has email_verified=False, create_user() should set User.email_verified=False."""
|
||||
user_id = str(uuid.uuid4())
|
||||
user_info = {
|
||||
'preferred_username': 'jsmith',
|
||||
'email': 'jsmith@example.com',
|
||||
'email_verified': False,
|
||||
}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_settings = Settings(
|
||||
language='en',
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
llm_base_url='http://test.url',
|
||||
)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.id = 1
|
||||
|
||||
with (
|
||||
patch('storage.user_store.session_maker', mock_sm),
|
||||
patch.object(
|
||||
UserStore,
|
||||
'create_default_settings',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_settings,
|
||||
),
|
||||
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
|
||||
patch(
|
||||
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_model': None, 'llm_base_url': None},
|
||||
),
|
||||
):
|
||||
mock_session.commit.side_effect = _StopAfterUserCreation
|
||||
with pytest.raises(_StopAfterUserCreation):
|
||||
await UserStore.create_user(user_id, user_info)
|
||||
|
||||
user = mock_session.add.call_args_list[1][0][0]
|
||||
assert isinstance(user, User)
|
||||
assert user.email == 'jsmith@example.com'
|
||||
assert user.email_verified is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_preserves_org_contact_email():
|
||||
"""create_user() must still set Org.contact_email (no regression)."""
|
||||
user_id = str(uuid.uuid4())
|
||||
user_info = {
|
||||
'preferred_username': 'jdoe',
|
||||
'email': 'jdoe@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch('storage.user_store.session_maker', mock_sm),
|
||||
patch.object(
|
||||
UserStore,
|
||||
'create_default_settings',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
await UserStore.create_user(user_id, user_info)
|
||||
|
||||
org = mock_session.add.call_args_list[0][0][0]
|
||||
assert isinstance(org, Org)
|
||||
assert org.contact_email == 'jdoe@example.com'
|
||||
|
||||
|
||||
def test_update_current_org_success(session_maker):
|
||||
"""
|
||||
GIVEN: User exists in database
|
||||
@@ -680,3 +878,100 @@ def test_update_current_org_user_not_found(session_maker):
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- Tests for update_user_email ---
|
||||
# update_user_email() should unconditionally overwrite User.email and/or email_verified.
|
||||
# Unlike backfill_user_email(), it does not check for NULL before writing.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_email_overwrites_existing(session_maker):
|
||||
"""update_user_email() should overwrite existing email and email_verified values."""
|
||||
user_id = str(uuid.uuid4())
|
||||
with session_maker() as session:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_email='old@example.com',
|
||||
)
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
email='old@example.com',
|
||||
email_verified=True,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.update_user_email(
|
||||
user_id, email='new@example.com', email_verified=False
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
assert user.email == 'new@example.com'
|
||||
assert user.email_verified is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_email_updates_only_email_verified(session_maker):
|
||||
"""update_user_email() with email=None should only update email_verified."""
|
||||
user_id = str(uuid.uuid4())
|
||||
with session_maker() as session:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_email='keep@example.com',
|
||||
)
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
email='keep@example.com',
|
||||
email_verified=False,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.update_user_email(user_id, email_verified=True)
|
||||
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
assert user.email == 'keep@example.com'
|
||||
assert user.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_email_noop_when_both_none():
|
||||
"""update_user_email() with both args None should not open a session."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_session_maker = MagicMock()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', mock_session_maker):
|
||||
await UserStore.update_user_email(user_id, email=None, email_verified=None)
|
||||
|
||||
mock_session_maker.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_email_missing_user_returns_without_error(session_maker):
|
||||
"""update_user_email() with a non-existent user_id should return without error."""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
with patch(
|
||||
'storage.user_store.a_session_maker',
|
||||
_wrap_sync_as_async_session_maker(session_maker),
|
||||
):
|
||||
await UserStore.update_user_email(
|
||||
user_id, email='new@example.com', email_verified=False
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user