fix(backend): user email capture (#12902)

Co-authored-by: OpenHands Bot <contact@all-hands.dev>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
sp.wack
2026-02-26 17:29:30 +04:00
committed by GitHub
parent 409df1287d
commit d4b9fb1d03
8 changed files with 641 additions and 8 deletions

View File

@@ -208,6 +208,7 @@ async def keycloak_callback(
else:
# Existing user — gradually backfill contact_name if it still has a username-style value
await UserStore.backfill_contact_name(user_id, user_info)
await UserStore.backfill_user_email(user_id, user_info)
if not user:
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')

View File

@@ -8,6 +8,7 @@ from server.auth.keycloak_manager import get_keycloak_admin
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.auth import set_response_cookie
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -62,6 +63,10 @@ async def update_email(
},
)
await UserStore.update_user_email(
user_id=user_id, email=email, email_verified=False
)
user_auth: SaasUserAuth = await get_user_auth(request)
await user_auth.refresh() # refresh so access token has updated email
user_auth.email = email
@@ -144,6 +149,7 @@ async def verified_email(request: Request):
user_auth: SaasUserAuth = await get_user_auth(request)
await user_auth.refresh() # refresh so access token has updated email
user_auth.email_verified = True
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
response = RedirectResponse(redirect_uri, status_code=302)

View File

@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, Query, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from storage.user_store import UserStore
from utils.identity import resolve_display_name
from openhands.integrations.provider import (
@@ -115,13 +116,21 @@ async def saas_get_user(
content='Failed to retrieve user_info.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
# Prefer email from DB; fall back to Keycloak if not yet persisted
email = user_info.get('email') if user_info else None
sub = user_info.get('sub') if user_info else ''
if sub:
db_user = await UserStore.get_user_by_id_async(sub)
if db_user and db_user.email is not None:
email = db_user.email
retval = await _check_idp(
access_token=access_token,
default_value=User(
id=(user_info.get('sub') if user_info else '') or '',
id=sub,
login=(user_info.get('preferred_username') if user_info else '') or '',
avatar_url='',
email=user_info.get('email') if user_info else None,
email=email,
name=resolve_display_name(user_info) if user_info else None,
company=user_info.get('company') if user_info else None,
),

View File

@@ -869,6 +869,88 @@ class UserStore:
org.contact_name = real_name
await session.commit()
@staticmethod
async def update_user_email(
user_id: str,
email: str | None = None,
email_verified: bool | None = None,
) -> None:
"""Unconditionally update User.email and/or email_verified.
Unlike backfill_user_email(), this overwrites existing values.
No-op when both arguments are None.
Missing user is logged as a warning and ignored.
"""
if email is None and email_verified is None:
return
async with a_session_maker() as session:
result = await session.execute(
select(User).filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if not user:
logger.warning(
'update_user_email:user_not_found',
extra={'user_id': user_id},
)
return
if email is not None:
user.email = email
if email_verified is not None:
user.email_verified = email_verified
logger.info(
'update_user_email:updated',
extra={
'user_id': user_id,
'email_set': email is not None,
'email_verified_set': email_verified is not None,
},
)
await session.commit()
@staticmethod
async def backfill_user_email(user_id: str, user_info: dict) -> None:
"""Set User.email and email_verified from IDP if they are still NULL.
Called during login to gradually fix existing users whose email
was never persisted on the User record. Preserves non-NULL values
(e.g. if a user manually changed their email).
"""
async with a_session_maker() as session:
result = await session.execute(
select(User).filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if not user:
logger.debug(
'backfill_user_email:user_not_found',
extra={'user_id': user_id},
)
return
updated = False
if user.email is None:
user.email = user_info.get('email')
updated = True
if user.email_verified is None:
user.email_verified = user_info.get('email_verified', False)
updated = True
if updated:
logger.info(
'backfill_user_email:updated',
extra={
'user_id': user_id,
'email_set': user.email is not None,
'email_verified_set': user.email_verified is not None,
},
)
await session.commit()
# Prevent circular imports
from typing import TYPE_CHECKING

View File

@@ -6,8 +6,10 @@ from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.email import (
EmailUpdate,
ResendEmailVerificationRequest,
resend_email_verification,
update_email,
verified_email,
verify_email,
)
@@ -116,12 +118,15 @@ async def test_verified_email_default_redirect(mock_request, mock_user_auth):
"""Test verified_email redirects to /settings/user by default."""
# Arrange
mock_request.query_params.get.return_value = None
mock_user_auth.user_id = 'test-user-id'
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
patch('server.routes.email.UserStore') as mock_user_store,
):
mock_user_store.update_user_email = AsyncMock()
result = await verified_email(mock_request)
# Assert
@@ -140,12 +145,15 @@ async def test_verified_email_https_scheme(mock_request, mock_user_auth):
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com'
mock_request.query_params.get.return_value = None
mock_user_auth.user_id = 'test-user-id'
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
patch('server.routes.email.UserStore') as mock_user_store,
):
mock_user_store.update_user_email = AsyncMock()
result = await verified_email(mock_request)
# Assert
@@ -327,6 +335,62 @@ async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
@pytest.mark.asyncio
async def test_update_email_calls_update_user_email(mock_request, mock_user_auth):
"""POST /api/email should call UserStore.update_user_email with new email and email_verified=False."""
user_id = 'test-user-id'
new_email = 'new@example.com'
email_data = EmailUpdate(email=new_email)
mock_keycloak_admin = MagicMock()
mock_keycloak_admin.get_user.return_value = {
'enabled': True,
'username': 'testuser',
}
mock_keycloak_admin.a_update_user = AsyncMock()
mock_user_store = MagicMock()
mock_user_store.update_user_email = AsyncMock()
with (
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.email.UserStore', mock_user_store),
):
result = await update_email(
email_data=email_data, request=mock_request, user_id=user_id
)
assert result.status_code == status.HTTP_200_OK
mock_user_store.update_user_email.assert_awaited_once_with(
user_id=user_id, email=new_email, email_verified=False
)
@pytest.mark.asyncio
async def test_verified_email_calls_update_user_email(mock_request, mock_user_auth):
"""GET /api/email/verified should call UserStore.update_user_email with email_verified=True."""
mock_user_auth.user_id = 'test-user-id'
mock_user_store = MagicMock()
mock_user_store.update_user_email = AsyncMock()
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie'),
patch('server.routes.email.UserStore', mock_user_store),
):
result = await verified_email(mock_request)
assert result.status_code == 302
mock_user_store.update_user_email.assert_awaited_once_with(
user_id='test-user-id', email_verified=True
)
@pytest.mark.asyncio
async def test_resend_email_verification_body_none_uses_auth(mock_request):
"""Test resend_email_verification uses auth when body is None."""

View File

@@ -154,6 +154,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = False
@@ -190,6 +191,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -262,6 +264,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
# Act
result = await keycloak_callback(
@@ -310,6 +313,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
# Act
result = await keycloak_callback(
@@ -352,6 +356,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -587,6 +592,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
@@ -651,6 +657,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
@@ -715,6 +722,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = False
mock_domain_blocker.is_domain_blocked.return_value = False
@@ -777,6 +785,7 @@ async def test_keycloak_callback_missing_email(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = True
@@ -823,6 +832,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
# Act
result = await keycloak_callback(
@@ -868,6 +878,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
# Act
result = await keycloak_callback(
@@ -926,6 +937,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -984,6 +996,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1045,6 +1058,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1202,6 +1216,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1267,6 +1282,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_domain_blocked.return_value = False
@@ -1350,6 +1366,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1438,6 +1455,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1523,6 +1541,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1607,6 +1626,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1688,6 +1708,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1755,6 +1776,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1828,6 +1850,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1899,6 +1922,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_domain_blocked.return_value = False
@@ -1918,3 +1942,57 @@ class TestKeycloakCallbackRecaptcha:
assert call_kwargs[0][0] == 'recaptcha_blocked_at_callback'
assert call_kwargs[1]['extra']['score'] == 0.2
assert call_kwargs[1]['extra']['user_id'] == 'test_user_id'
@pytest.mark.asyncio
async def test_keycloak_callback_calls_backfill_user_email_for_existing_user(
mock_request,
):
"""When an existing user logs in, backfill_user_email should be called."""
user_info = {
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email': 'test@example.com',
'email_verified': True,
}
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog'),
):
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(return_value=user_info)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
# backfill_user_email should have been called with the user_id and user_info
mock_user_store.backfill_user_email.assert_called_once_with(
'test_user_id', user_info
)

View File

@@ -5,7 +5,7 @@ the endpoint constructs a User from OIDC claims. These tests verify that name an
fields are correctly populated from Keycloak claims in this fallback path.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
@@ -33,9 +33,20 @@ def mock_check_idp():
yield mock_fn
@pytest.fixture
def mock_user_store():
"""Mock UserStore.get_user_by_id_async to return None by default."""
with patch(
'server.routes.user.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=None,
) as mock_fn:
yield mock_fn
@pytest.mark.asyncio
async def test_fallback_user_includes_name_from_name_claim(
mock_token_manager, mock_check_idp
mock_token_manager, mock_check_idp, mock_user_store
):
"""When Keycloak provides a 'name' claim, the fallback User should include it."""
from server.routes.user import saas_get_user
@@ -62,7 +73,7 @@ async def test_fallback_user_includes_name_from_name_claim(
@pytest.mark.asyncio
async def test_fallback_user_combines_given_and_family_name(
mock_token_manager, mock_check_idp
mock_token_manager, mock_check_idp, mock_user_store
):
"""When 'name' is absent, combine given_name + family_name."""
from server.routes.user import saas_get_user
@@ -89,7 +100,7 @@ async def test_fallback_user_combines_given_and_family_name(
@pytest.mark.asyncio
async def test_fallback_user_name_is_none_when_no_name_claims(
mock_token_manager, mock_check_idp
mock_token_manager, mock_check_idp, mock_user_store
):
"""When no name claims exist, name should be None."""
from server.routes.user import saas_get_user
@@ -113,7 +124,9 @@ async def test_fallback_user_name_is_none_when_no_name_claims(
@pytest.mark.asyncio
async def test_fallback_user_includes_company_claim(mock_token_manager, mock_check_idp):
async def test_fallback_user_includes_company_claim(
mock_token_manager, mock_check_idp, mock_user_store
):
"""When Keycloak provides a 'company' claim, include it in the User."""
from server.routes.user import saas_get_user
@@ -139,7 +152,7 @@ async def test_fallback_user_includes_company_claim(mock_token_manager, mock_che
@pytest.mark.asyncio
async def test_fallback_user_company_is_none_when_absent(
mock_token_manager, mock_check_idp
mock_token_manager, mock_check_idp, mock_user_store
):
"""When 'company' is not in Keycloak claims, company should be None."""
from server.routes.user import saas_get_user
@@ -161,3 +174,88 @@ async def test_fallback_user_company_is_none_when_absent(
assert isinstance(result, User)
assert result.company is None
@pytest.mark.asyncio
async def test_fallback_user_email_from_db_when_available(
mock_token_manager, mock_check_idp, mock_user_store
):
"""When User.email is stored in DB, use it instead of Keycloak's live email."""
from server.routes.user import saas_get_user
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': '248289761001',
'preferred_username': 'j.doe',
'email': 'keycloak@example.com',
}
)
mock_db_user = MagicMock()
mock_db_user.email = 'db@example.com'
mock_user_store.return_value = mock_db_user
result = await saas_get_user(
provider_tokens=None,
access_token=SecretStr('test-token'),
user_id='248289761001',
)
assert isinstance(result, User)
assert result.email == 'db@example.com'
@pytest.mark.asyncio
async def test_fallback_user_email_falls_back_to_keycloak_when_db_null(
mock_token_manager, mock_check_idp, mock_user_store
):
"""When User.email is NULL in DB, fall back to Keycloak's email."""
from server.routes.user import saas_get_user
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': '248289761001',
'preferred_username': 'j.doe',
'email': 'keycloak@example.com',
}
)
mock_db_user = MagicMock()
mock_db_user.email = None
mock_user_store.return_value = mock_db_user
result = await saas_get_user(
provider_tokens=None,
access_token=SecretStr('test-token'),
user_id='248289761001',
)
assert isinstance(result, User)
assert result.email == 'keycloak@example.com'
@pytest.mark.asyncio
async def test_fallback_user_email_falls_back_to_keycloak_when_no_db_user(
mock_token_manager, mock_check_idp, mock_user_store
):
"""When DB user doesn't exist, fall back to Keycloak's email."""
from server.routes.user import saas_get_user
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': '248289761001',
'preferred_username': 'j.doe',
'email': 'keycloak@example.com',
}
)
# mock_user_store already returns None by default
result = await saas_get_user(
provider_tokens=None,
access_token=SecretStr('test-token'),
user_id='248289761001',
)
assert isinstance(result, User)
assert result.email == 'keycloak@example.com'

View File

@@ -639,6 +639,204 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
assert org.contact_name == 'Custom Corp Name'
# --- Tests for backfill_user_email on login ---
# Existing users created before the email capture fix may have NULL
# email in the User table. The backfill sets User.email from the IDP
# when the user next logs in, but preserves manual changes (non-NULL).
@pytest.mark.asyncio
async def test_backfill_user_email_sets_email_when_null(session_maker):
"""When User.email is NULL, backfill_user_email should set it from user_info."""
user_id = str(uuid.uuid4())
with session_maker() as session:
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_email='jdoe@example.com',
)
session.add(org)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
email=None,
email_verified=None,
)
session.add(user)
session.commit()
user_info = {
'email': 'jdoe@example.com',
'email_verified': True,
}
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.backfill_user_email(user_id, user_info)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
assert user.email == 'jdoe@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_backfill_user_email_does_not_overwrite_existing(session_maker):
"""When User.email is already set, backfill_user_email should NOT overwrite it."""
user_id = str(uuid.uuid4())
with session_maker() as session:
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_email='original@example.com',
)
session.add(org)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
email='custom@example.com',
email_verified=True,
)
session.add(user)
session.commit()
user_info = {
'email': 'different@example.com',
'email_verified': False,
}
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.backfill_user_email(user_id, user_info)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
assert user.email == 'custom@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_backfill_user_email_sets_verified_when_null(session_maker):
"""When User.email is set but email_verified is NULL, backfill should set email_verified."""
user_id = str(uuid.uuid4())
with session_maker() as session:
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_email='jdoe@example.com',
)
session.add(org)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
email='jdoe@example.com',
email_verified=None,
)
session.add(user)
session.commit()
user_info = {
'email': 'different@example.com',
'email_verified': True,
}
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.backfill_user_email(user_id, user_info)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
# email should NOT be overwritten since it's non-NULL
assert user.email == 'jdoe@example.com'
# email_verified should be set since it was NULL
assert user.email_verified is True
@pytest.mark.asyncio
async def test_create_user_sets_email_verified_false_from_user_info():
"""When user_info has email_verified=False, create_user() should set User.email_verified=False."""
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'jsmith',
'email': 'jsmith@example.com',
'email_verified': False,
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(
language='en',
llm_api_key=SecretStr('test-key'),
llm_base_url='http://test.url',
)
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
):
mock_session.commit.side_effect = _StopAfterUserCreation
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'jsmith@example.com'
assert user.email_verified is False
@pytest.mark.asyncio
async def test_create_user_preserves_org_contact_email():
"""create_user() must still set Org.contact_email (no regression)."""
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'jdoe',
'email': 'jdoe@example.com',
'email_verified': True,
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=None,
),
):
await UserStore.create_user(user_id, user_info)
org = mock_session.add.call_args_list[0][0][0]
assert isinstance(org, Org)
assert org.contact_email == 'jdoe@example.com'
def test_update_current_org_success(session_maker):
"""
GIVEN: User exists in database
@@ -680,3 +878,100 @@ def test_update_current_org_user_not_found(session_maker):
# Assert
assert result is None
# --- Tests for update_user_email ---
# update_user_email() should unconditionally overwrite User.email and/or email_verified.
# Unlike backfill_user_email(), it does not check for NULL before writing.
@pytest.mark.asyncio
async def test_update_user_email_overwrites_existing(session_maker):
"""update_user_email() should overwrite existing email and email_verified values."""
user_id = str(uuid.uuid4())
with session_maker() as session:
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_email='old@example.com',
)
session.add(org)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
email='old@example.com',
email_verified=True,
)
session.add(user)
session.commit()
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.update_user_email(
user_id, email='new@example.com', email_verified=False
)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
assert user.email == 'new@example.com'
assert user.email_verified is False
@pytest.mark.asyncio
async def test_update_user_email_updates_only_email_verified(session_maker):
"""update_user_email() with email=None should only update email_verified."""
user_id = str(uuid.uuid4())
with session_maker() as session:
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_email='keep@example.com',
)
session.add(org)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
email='keep@example.com',
email_verified=False,
)
session.add(user)
session.commit()
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.update_user_email(user_id, email_verified=True)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
assert user.email == 'keep@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_update_user_email_noop_when_both_none():
"""update_user_email() with both args None should not open a session."""
user_id = str(uuid.uuid4())
mock_session_maker = MagicMock()
with patch('storage.user_store.a_session_maker', mock_session_maker):
await UserStore.update_user_email(user_id, email=None, email_verified=None)
mock_session_maker.assert_not_called()
@pytest.mark.asyncio
async def test_update_user_email_missing_user_returns_without_error(session_maker):
"""update_user_email() with a non-existent user_id should return without error."""
user_id = str(uuid.uuid4())
with patch(
'storage.user_store.a_session_maker',
_wrap_sync_as_async_session_maker(session_maker),
):
await UserStore.update_user_email(
user_id, email='new@example.com', email_verified=False
)