From d4b9fb1d034ae3aa37fc73224cb20f1c08617687 Mon Sep 17 00:00:00 2001 From: "sp.wack" <83104063+amanape@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:29:30 +0400 Subject: [PATCH] fix(backend): user email capture (#12902) Co-authored-by: OpenHands Bot Co-authored-by: openhands --- enterprise/server/routes/auth.py | 1 + enterprise/server/routes/email.py | 6 + enterprise/server/routes/user.py | 13 +- enterprise/storage/user_store.py | 82 +++++ .../unit/server/routes/test_email_routes.py | 64 ++++ enterprise/tests/unit/test_auth_routes.py | 78 +++++ .../tests/unit/test_user_route_fallback.py | 110 ++++++- enterprise/tests/unit/test_user_store.py | 295 ++++++++++++++++++ 8 files changed, 641 insertions(+), 8 deletions(-) diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index a4dc968ebe..5460a37ea3 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -208,6 +208,7 @@ async def keycloak_callback( else: # Existing user — gradually backfill contact_name if it still has a username-style value await UserStore.backfill_contact_name(user_id, user_info) + await UserStore.backfill_user_email(user_id, user_info) if not user: logger.error(f'Failed to authenticate user {user_info["preferred_username"]}') diff --git a/enterprise/server/routes/email.py b/enterprise/server/routes/email.py index 8b31d3f171..a2fe613a5f 100644 --- a/enterprise/server/routes/email.py +++ b/enterprise/server/routes/email.py @@ -8,6 +8,7 @@ from server.auth.keycloak_manager import get_keycloak_admin from server.auth.saas_user_auth import SaasUserAuth from server.routes.auth import set_response_cookie from server.utils.rate_limit_utils import check_rate_limit_by_user_id +from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id @@ -62,6 +63,10 @@ async def update_email( }, ) + await UserStore.update_user_email( + user_id=user_id, email=email, email_verified=False + ) + user_auth: SaasUserAuth = await get_user_auth(request) await user_auth.refresh() # refresh so access token has updated email user_auth.email = email @@ -144,6 +149,7 @@ async def verified_email(request: Request): user_auth: SaasUserAuth = await get_user_auth(request) await user_auth.refresh() # refresh so access token has updated email user_auth.email_verified = True + await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True) scheme = 'http' if request.url.hostname == 'localhost' else 'https' redirect_uri = f'{scheme}://{request.url.netloc}/settings/user' response = RedirectResponse(redirect_uri, status_code=302) diff --git a/enterprise/server/routes/user.py b/enterprise/server/routes/user.py index cadd787fb4..84de6c60cf 100644 --- a/enterprise/server/routes/user.py +++ b/enterprise/server/routes/user.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, Query, status from fastapi.responses import JSONResponse from pydantic import SecretStr from server.auth.token_manager import TokenManager +from storage.user_store import UserStore from utils.identity import resolve_display_name from openhands.integrations.provider import ( @@ -115,13 +116,21 @@ async def saas_get_user( content='Failed to retrieve user_info.', status_code=status.HTTP_401_UNAUTHORIZED, ) + # Prefer email from DB; fall back to Keycloak if not yet persisted + email = user_info.get('email') if user_info else None + sub = user_info.get('sub') if user_info else '' + if sub: + db_user = await UserStore.get_user_by_id_async(sub) + if db_user and db_user.email is not None: + email = db_user.email + retval = await _check_idp( access_token=access_token, default_value=User( - id=(user_info.get('sub') if user_info else '') or '', + id=sub, login=(user_info.get('preferred_username') if user_info else '') or '', avatar_url='', - email=user_info.get('email') if user_info else None, + email=email, name=resolve_display_name(user_info) if user_info else None, company=user_info.get('company') if user_info else None, ), diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index 3a97321d49..224fd45ab1 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -869,6 +869,88 @@ class UserStore: org.contact_name = real_name await session.commit() + @staticmethod + async def update_user_email( + user_id: str, + email: str | None = None, + email_verified: bool | None = None, + ) -> None: + """Unconditionally update User.email and/or email_verified. + + Unlike backfill_user_email(), this overwrites existing values. + No-op when both arguments are None. + Missing user is logged as a warning and ignored. + """ + if email is None and email_verified is None: + return + + async with a_session_maker() as session: + result = await session.execute( + select(User).filter(User.id == uuid.UUID(user_id)) + ) + user = result.scalars().first() + if not user: + logger.warning( + 'update_user_email:user_not_found', + extra={'user_id': user_id}, + ) + return + + if email is not None: + user.email = email + if email_verified is not None: + user.email_verified = email_verified + + logger.info( + 'update_user_email:updated', + extra={ + 'user_id': user_id, + 'email_set': email is not None, + 'email_verified_set': email_verified is not None, + }, + ) + await session.commit() + + @staticmethod + async def backfill_user_email(user_id: str, user_info: dict) -> None: + """Set User.email and email_verified from IDP if they are still NULL. + + Called during login to gradually fix existing users whose email + was never persisted on the User record. Preserves non-NULL values + (e.g. if a user manually changed their email). + """ + async with a_session_maker() as session: + result = await session.execute( + select(User).filter(User.id == uuid.UUID(user_id)) + ) + user = result.scalars().first() + if not user: + logger.debug( + 'backfill_user_email:user_not_found', + extra={'user_id': user_id}, + ) + return + + updated = False + if user.email is None: + user.email = user_info.get('email') + updated = True + + if user.email_verified is None: + user.email_verified = user_info.get('email_verified', False) + updated = True + + if updated: + logger.info( + 'backfill_user_email:updated', + extra={ + 'user_id': user_id, + 'email_set': user.email is not None, + 'email_verified_set': user.email_verified is not None, + }, + ) + await session.commit() + # Prevent circular imports from typing import TYPE_CHECKING diff --git a/enterprise/tests/unit/server/routes/test_email_routes.py b/enterprise/tests/unit/server/routes/test_email_routes.py index cc8d5ac892..24b15385e2 100644 --- a/enterprise/tests/unit/server/routes/test_email_routes.py +++ b/enterprise/tests/unit/server/routes/test_email_routes.py @@ -6,8 +6,10 @@ from fastapi.responses import JSONResponse, RedirectResponse from pydantic import SecretStr from server.auth.saas_user_auth import SaasUserAuth from server.routes.email import ( + EmailUpdate, ResendEmailVerificationRequest, resend_email_verification, + update_email, verified_email, verify_email, ) @@ -116,12 +118,15 @@ async def test_verified_email_default_redirect(mock_request, mock_user_auth): """Test verified_email redirects to /settings/user by default.""" # Arrange mock_request.query_params.get.return_value = None + mock_user_auth.user_id = 'test-user-id' # Act with ( patch('server.routes.email.get_user_auth', return_value=mock_user_auth), patch('server.routes.email.set_response_cookie') as mock_set_cookie, + patch('server.routes.email.UserStore') as mock_user_store, ): + mock_user_store.update_user_email = AsyncMock() result = await verified_email(mock_request) # Assert @@ -140,12 +145,15 @@ async def test_verified_email_https_scheme(mock_request, mock_user_auth): mock_request.url.hostname = 'example.com' mock_request.url.netloc = 'example.com' mock_request.query_params.get.return_value = None + mock_user_auth.user_id = 'test-user-id' # Act with ( patch('server.routes.email.get_user_auth', return_value=mock_user_auth), patch('server.routes.email.set_response_cookie') as mock_set_cookie, + patch('server.routes.email.UserStore') as mock_user_store, ): + mock_user_store.update_user_email = AsyncMock() result = await verified_email(mock_request) # Assert @@ -327,6 +335,62 @@ async def test_resend_email_verification_with_is_auth_flow_false(mock_request): assert '/api/email/verified' in call_args.kwargs['redirect_uri'] +@pytest.mark.asyncio +async def test_update_email_calls_update_user_email(mock_request, mock_user_auth): + """POST /api/email should call UserStore.update_user_email with new email and email_verified=False.""" + user_id = 'test-user-id' + new_email = 'new@example.com' + email_data = EmailUpdate(email=new_email) + + mock_keycloak_admin = MagicMock() + mock_keycloak_admin.get_user.return_value = { + 'enabled': True, + 'username': 'testuser', + } + mock_keycloak_admin.a_update_user = AsyncMock() + mock_user_store = MagicMock() + mock_user_store.update_user_email = AsyncMock() + + with ( + patch( + 'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin + ), + patch('server.routes.email.get_user_auth', return_value=mock_user_auth), + patch('server.routes.email.set_response_cookie'), + patch('server.routes.email.verify_email', new_callable=AsyncMock), + patch('server.routes.email.UserStore', mock_user_store), + ): + result = await update_email( + email_data=email_data, request=mock_request, user_id=user_id + ) + + assert result.status_code == status.HTTP_200_OK + mock_user_store.update_user_email.assert_awaited_once_with( + user_id=user_id, email=new_email, email_verified=False + ) + + +@pytest.mark.asyncio +async def test_verified_email_calls_update_user_email(mock_request, mock_user_auth): + """GET /api/email/verified should call UserStore.update_user_email with email_verified=True.""" + mock_user_auth.user_id = 'test-user-id' + + mock_user_store = MagicMock() + mock_user_store.update_user_email = AsyncMock() + + with ( + patch('server.routes.email.get_user_auth', return_value=mock_user_auth), + patch('server.routes.email.set_response_cookie'), + patch('server.routes.email.UserStore', mock_user_store), + ): + result = await verified_email(mock_request) + + assert result.status_code == 302 + mock_user_store.update_user_email.assert_awaited_once_with( + user_id='test-user-id', email_verified=True + ) + + @pytest.mark.asyncio async def test_resend_email_verification_body_none_uses_auth(mock_request): """Test resend_email_verification uses auth when body is None.""" diff --git a/enterprise/tests/unit/test_auth_routes.py b/enterprise/tests/unit/test_auth_routes.py index 60b81a0e51..02e0583af7 100644 --- a/enterprise/tests/unit/test_auth_routes.py +++ b/enterprise/tests/unit/test_auth_routes.py @@ -154,6 +154,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request): mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = False @@ -190,6 +191,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request): mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_token_manager.get_keycloak_tokens = AsyncMock( return_value=('test_access_token', 'test_refresh_token') @@ -262,6 +264,7 @@ async def test_keycloak_callback_email_not_verified(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() # Act result = await keycloak_callback( @@ -310,6 +313,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() # Act result = await keycloak_callback( @@ -352,6 +356,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request): mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.migrate_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_token_manager.get_keycloak_tokens = AsyncMock( return_value=('test_access_token', 'test_refresh_token') @@ -587,6 +592,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = True @@ -651,6 +657,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_active.return_value = True mock_domain_blocker.is_domain_blocked.return_value = False @@ -715,6 +722,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_active.return_value = False mock_domain_blocker.is_domain_blocked.return_value = False @@ -777,6 +785,7 @@ async def test_keycloak_callback_missing_email(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_active.return_value = True @@ -823,6 +832,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() # Act result = await keycloak_callback( @@ -868,6 +878,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() # Act result = await keycloak_callback( @@ -926,6 +937,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -984,6 +996,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1045,6 +1058,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request): mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1202,6 +1216,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1267,6 +1282,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_domain_blocked.return_value = False @@ -1350,6 +1366,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1438,6 +1455,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1523,6 +1541,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1607,6 +1626,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1688,6 +1708,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1755,6 +1776,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1828,6 +1850,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_verifier.is_active.return_value = True mock_verifier.is_user_allowed.return_value = True @@ -1899,6 +1922,7 @@ class TestKeycloakCallbackRecaptcha: mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) mock_user_store.create_user = AsyncMock(return_value=mock_user) mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() mock_domain_blocker.is_domain_blocked.return_value = False @@ -1918,3 +1942,57 @@ class TestKeycloakCallbackRecaptcha: assert call_kwargs[0][0] == 'recaptcha_blocked_at_callback' assert call_kwargs[1]['extra']['score'] == 0.2 assert call_kwargs[1]['extra']['user_id'] == 'test_user_id' + + +@pytest.mark.asyncio +async def test_keycloak_callback_calls_backfill_user_email_for_existing_user( + mock_request, +): + """When an existing user logs in, backfill_user_email should be called.""" + user_info = { + 'sub': 'test_user_id', + 'preferred_username': 'test_user', + 'identity_provider': 'github', + 'email': 'test@example.com', + 'email_verified': True, + } + + with ( + patch('server.routes.auth.token_manager') as mock_token_manager, + patch('server.routes.auth.user_verifier') as mock_verifier, + patch('server.routes.auth.set_response_cookie'), + patch('server.routes.auth.UserStore') as mock_user_store, + patch('server.routes.auth.posthog'), + ): + mock_user = MagicMock() + mock_user.id = 'test_user_id' + mock_user.current_org_id = 'test_org_id' + mock_user.accepted_tos = '2025-01-01' + + mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user) + mock_user_store.create_user = AsyncMock(return_value=mock_user) + mock_user_store.backfill_contact_name = AsyncMock() + mock_user_store.backfill_user_email = AsyncMock() + + mock_token_manager.get_keycloak_tokens = AsyncMock( + return_value=('test_access_token', 'test_refresh_token') + ) + mock_token_manager.get_user_info = AsyncMock(return_value=user_info) + mock_token_manager.store_idp_tokens = AsyncMock() + mock_token_manager.validate_offline_token = AsyncMock(return_value=True) + mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False) + + mock_verifier.is_active.return_value = True + mock_verifier.is_user_allowed.return_value = True + + result = await keycloak_callback( + code='test_code', state='test_state', request=mock_request + ) + + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + # backfill_user_email should have been called with the user_id and user_info + mock_user_store.backfill_user_email.assert_called_once_with( + 'test_user_id', user_info + ) diff --git a/enterprise/tests/unit/test_user_route_fallback.py b/enterprise/tests/unit/test_user_route_fallback.py index 8efffbadd4..f51b73fbcb 100644 --- a/enterprise/tests/unit/test_user_route_fallback.py +++ b/enterprise/tests/unit/test_user_route_fallback.py @@ -5,7 +5,7 @@ the endpoint constructs a User from OIDC claims. These tests verify that name an fields are correctly populated from Keycloak claims in this fallback path. """ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import SecretStr @@ -33,9 +33,20 @@ def mock_check_idp(): yield mock_fn +@pytest.fixture +def mock_user_store(): + """Mock UserStore.get_user_by_id_async to return None by default.""" + with patch( + 'server.routes.user.UserStore.get_user_by_id_async', + new_callable=AsyncMock, + return_value=None, + ) as mock_fn: + yield mock_fn + + @pytest.mark.asyncio async def test_fallback_user_includes_name_from_name_claim( - mock_token_manager, mock_check_idp + mock_token_manager, mock_check_idp, mock_user_store ): """When Keycloak provides a 'name' claim, the fallback User should include it.""" from server.routes.user import saas_get_user @@ -62,7 +73,7 @@ async def test_fallback_user_includes_name_from_name_claim( @pytest.mark.asyncio async def test_fallback_user_combines_given_and_family_name( - mock_token_manager, mock_check_idp + mock_token_manager, mock_check_idp, mock_user_store ): """When 'name' is absent, combine given_name + family_name.""" from server.routes.user import saas_get_user @@ -89,7 +100,7 @@ async def test_fallback_user_combines_given_and_family_name( @pytest.mark.asyncio async def test_fallback_user_name_is_none_when_no_name_claims( - mock_token_manager, mock_check_idp + mock_token_manager, mock_check_idp, mock_user_store ): """When no name claims exist, name should be None.""" from server.routes.user import saas_get_user @@ -113,7 +124,9 @@ async def test_fallback_user_name_is_none_when_no_name_claims( @pytest.mark.asyncio -async def test_fallback_user_includes_company_claim(mock_token_manager, mock_check_idp): +async def test_fallback_user_includes_company_claim( + mock_token_manager, mock_check_idp, mock_user_store +): """When Keycloak provides a 'company' claim, include it in the User.""" from server.routes.user import saas_get_user @@ -139,7 +152,7 @@ async def test_fallback_user_includes_company_claim(mock_token_manager, mock_che @pytest.mark.asyncio async def test_fallback_user_company_is_none_when_absent( - mock_token_manager, mock_check_idp + mock_token_manager, mock_check_idp, mock_user_store ): """When 'company' is not in Keycloak claims, company should be None.""" from server.routes.user import saas_get_user @@ -161,3 +174,88 @@ async def test_fallback_user_company_is_none_when_absent( assert isinstance(result, User) assert result.company is None + + +@pytest.mark.asyncio +async def test_fallback_user_email_from_db_when_available( + mock_token_manager, mock_check_idp, mock_user_store +): + """When User.email is stored in DB, use it instead of Keycloak's live email.""" + from server.routes.user import saas_get_user + + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': '248289761001', + 'preferred_username': 'j.doe', + 'email': 'keycloak@example.com', + } + ) + + mock_db_user = MagicMock() + mock_db_user.email = 'db@example.com' + mock_user_store.return_value = mock_db_user + + result = await saas_get_user( + provider_tokens=None, + access_token=SecretStr('test-token'), + user_id='248289761001', + ) + + assert isinstance(result, User) + assert result.email == 'db@example.com' + + +@pytest.mark.asyncio +async def test_fallback_user_email_falls_back_to_keycloak_when_db_null( + mock_token_manager, mock_check_idp, mock_user_store +): + """When User.email is NULL in DB, fall back to Keycloak's email.""" + from server.routes.user import saas_get_user + + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': '248289761001', + 'preferred_username': 'j.doe', + 'email': 'keycloak@example.com', + } + ) + + mock_db_user = MagicMock() + mock_db_user.email = None + mock_user_store.return_value = mock_db_user + + result = await saas_get_user( + provider_tokens=None, + access_token=SecretStr('test-token'), + user_id='248289761001', + ) + + assert isinstance(result, User) + assert result.email == 'keycloak@example.com' + + +@pytest.mark.asyncio +async def test_fallback_user_email_falls_back_to_keycloak_when_no_db_user( + mock_token_manager, mock_check_idp, mock_user_store +): + """When DB user doesn't exist, fall back to Keycloak's email.""" + from server.routes.user import saas_get_user + + mock_token_manager.get_user_info = AsyncMock( + return_value={ + 'sub': '248289761001', + 'preferred_username': 'j.doe', + 'email': 'keycloak@example.com', + } + ) + + # mock_user_store already returns None by default + + result = await saas_get_user( + provider_tokens=None, + access_token=SecretStr('test-token'), + user_id='248289761001', + ) + + assert isinstance(result, User) + assert result.email == 'keycloak@example.com' diff --git a/enterprise/tests/unit/test_user_store.py b/enterprise/tests/unit/test_user_store.py index db6f1bb6ed..504d00066c 100644 --- a/enterprise/tests/unit/test_user_store.py +++ b/enterprise/tests/unit/test_user_store.py @@ -639,6 +639,204 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker): assert org.contact_name == 'Custom Corp Name' +# --- Tests for backfill_user_email on login --- +# Existing users created before the email capture fix may have NULL +# email in the User table. The backfill sets User.email from the IDP +# when the user next logs in, but preserves manual changes (non-NULL). + + +@pytest.mark.asyncio +async def test_backfill_user_email_sets_email_when_null(session_maker): + """When User.email is NULL, backfill_user_email should set it from user_info.""" + user_id = str(uuid.uuid4()) + with session_maker() as session: + org = Org( + id=uuid.UUID(user_id), + name=f'user_{user_id}_org', + contact_email='jdoe@example.com', + ) + session.add(org) + user = User( + id=uuid.UUID(user_id), + current_org_id=org.id, + email=None, + email_verified=None, + ) + session.add(user) + session.commit() + + user_info = { + 'email': 'jdoe@example.com', + 'email_verified': True, + } + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.backfill_user_email(user_id, user_info) + + with session_maker() as session: + user = session.query(User).filter(User.id == uuid.UUID(user_id)).first() + assert user.email == 'jdoe@example.com' + assert user.email_verified is True + + +@pytest.mark.asyncio +async def test_backfill_user_email_does_not_overwrite_existing(session_maker): + """When User.email is already set, backfill_user_email should NOT overwrite it.""" + user_id = str(uuid.uuid4()) + with session_maker() as session: + org = Org( + id=uuid.UUID(user_id), + name=f'user_{user_id}_org', + contact_email='original@example.com', + ) + session.add(org) + user = User( + id=uuid.UUID(user_id), + current_org_id=org.id, + email='custom@example.com', + email_verified=True, + ) + session.add(user) + session.commit() + + user_info = { + 'email': 'different@example.com', + 'email_verified': False, + } + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.backfill_user_email(user_id, user_info) + + with session_maker() as session: + user = session.query(User).filter(User.id == uuid.UUID(user_id)).first() + assert user.email == 'custom@example.com' + assert user.email_verified is True + + +@pytest.mark.asyncio +async def test_backfill_user_email_sets_verified_when_null(session_maker): + """When User.email is set but email_verified is NULL, backfill should set email_verified.""" + user_id = str(uuid.uuid4()) + with session_maker() as session: + org = Org( + id=uuid.UUID(user_id), + name=f'user_{user_id}_org', + contact_email='jdoe@example.com', + ) + session.add(org) + user = User( + id=uuid.UUID(user_id), + current_org_id=org.id, + email='jdoe@example.com', + email_verified=None, + ) + session.add(user) + session.commit() + + user_info = { + 'email': 'different@example.com', + 'email_verified': True, + } + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.backfill_user_email(user_id, user_info) + + with session_maker() as session: + user = session.query(User).filter(User.id == uuid.UUID(user_id)).first() + # email should NOT be overwritten since it's non-NULL + assert user.email == 'jdoe@example.com' + # email_verified should be set since it was NULL + assert user.email_verified is True + + +@pytest.mark.asyncio +async def test_create_user_sets_email_verified_false_from_user_info(): + """When user_info has email_verified=False, create_user() should set User.email_verified=False.""" + user_id = str(uuid.uuid4()) + user_info = { + 'preferred_username': 'jsmith', + 'email': 'jsmith@example.com', + 'email_verified': False, + } + + mock_session = MagicMock() + mock_sm = MagicMock() + mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_sm.return_value.__exit__ = MagicMock(return_value=False) + + mock_settings = Settings( + language='en', + llm_api_key=SecretStr('test-key'), + llm_base_url='http://test.url', + ) + + mock_role = MagicMock() + mock_role.id = 1 + + with ( + patch('storage.user_store.session_maker', mock_sm), + patch.object( + UserStore, + 'create_default_settings', + new_callable=AsyncMock, + return_value=mock_settings, + ), + patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role), + patch( + 'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings', + return_value={'llm_model': None, 'llm_base_url': None}, + ), + ): + mock_session.commit.side_effect = _StopAfterUserCreation + with pytest.raises(_StopAfterUserCreation): + await UserStore.create_user(user_id, user_info) + + user = mock_session.add.call_args_list[1][0][0] + assert isinstance(user, User) + assert user.email == 'jsmith@example.com' + assert user.email_verified is False + + +@pytest.mark.asyncio +async def test_create_user_preserves_org_contact_email(): + """create_user() must still set Org.contact_email (no regression).""" + user_id = str(uuid.uuid4()) + user_info = { + 'preferred_username': 'jdoe', + 'email': 'jdoe@example.com', + 'email_verified': True, + } + + mock_session = MagicMock() + mock_sm = MagicMock() + mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_sm.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch('storage.user_store.session_maker', mock_sm), + patch.object( + UserStore, + 'create_default_settings', + new_callable=AsyncMock, + return_value=None, + ), + ): + await UserStore.create_user(user_id, user_info) + + org = mock_session.add.call_args_list[0][0][0] + assert isinstance(org, Org) + assert org.contact_email == 'jdoe@example.com' + + def test_update_current_org_success(session_maker): """ GIVEN: User exists in database @@ -680,3 +878,100 @@ def test_update_current_org_user_not_found(session_maker): # Assert assert result is None + + +# --- Tests for update_user_email --- +# update_user_email() should unconditionally overwrite User.email and/or email_verified. +# Unlike backfill_user_email(), it does not check for NULL before writing. + + +@pytest.mark.asyncio +async def test_update_user_email_overwrites_existing(session_maker): + """update_user_email() should overwrite existing email and email_verified values.""" + user_id = str(uuid.uuid4()) + with session_maker() as session: + org = Org( + id=uuid.UUID(user_id), + name=f'user_{user_id}_org', + contact_email='old@example.com', + ) + session.add(org) + user = User( + id=uuid.UUID(user_id), + current_org_id=org.id, + email='old@example.com', + email_verified=True, + ) + session.add(user) + session.commit() + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.update_user_email( + user_id, email='new@example.com', email_verified=False + ) + + with session_maker() as session: + user = session.query(User).filter(User.id == uuid.UUID(user_id)).first() + assert user.email == 'new@example.com' + assert user.email_verified is False + + +@pytest.mark.asyncio +async def test_update_user_email_updates_only_email_verified(session_maker): + """update_user_email() with email=None should only update email_verified.""" + user_id = str(uuid.uuid4()) + with session_maker() as session: + org = Org( + id=uuid.UUID(user_id), + name=f'user_{user_id}_org', + contact_email='keep@example.com', + ) + session.add(org) + user = User( + id=uuid.UUID(user_id), + current_org_id=org.id, + email='keep@example.com', + email_verified=False, + ) + session.add(user) + session.commit() + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.update_user_email(user_id, email_verified=True) + + with session_maker() as session: + user = session.query(User).filter(User.id == uuid.UUID(user_id)).first() + assert user.email == 'keep@example.com' + assert user.email_verified is True + + +@pytest.mark.asyncio +async def test_update_user_email_noop_when_both_none(): + """update_user_email() with both args None should not open a session.""" + user_id = str(uuid.uuid4()) + mock_session_maker = MagicMock() + + with patch('storage.user_store.a_session_maker', mock_session_maker): + await UserStore.update_user_email(user_id, email=None, email_verified=None) + + mock_session_maker.assert_not_called() + + +@pytest.mark.asyncio +async def test_update_user_email_missing_user_returns_without_error(session_maker): + """update_user_email() with a non-existent user_id should return without error.""" + user_id = str(uuid.uuid4()) + + with patch( + 'storage.user_store.a_session_maker', + _wrap_sync_as_async_session_maker(session_maker), + ): + await UserStore.update_user_email( + user_id, email='new@example.com', email_verified=False + )