"""Unit tests for ``server.auth.saas_user_auth``.

Covers token refresh/caching on ``SaasUserAuth``, the bearer/cookie/signed-token
constructors, API-key header extraction priority, and email-domain blocking.
"""

import time
from unittest.mock import AsyncMock, MagicMock, patch

import jwt
import pytest
from fastapi import Request
from pydantic import SecretStr
from server.auth.auth_error import (
    AuthError,
    BearerTokenError,
    CookieError,
    NoCredentialsError,
)
from server.auth.saas_user_auth import (
    SaasUserAuth,
    get_api_key_from_header,
    saas_user_auth_from_bearer,
    saas_user_auth_from_cookie,
    saas_user_auth_from_signed_token,
)

from openhands.integrations.provider import ProviderToken, ProviderType


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_jwt(exp_offset: int = 3600, secret: str = 'secret', **claims) -> str:
    """Encode an HS256 JWT whose ``sub`` is ``'test_user_id'``.

    Args:
        exp_offset: Seconds from now for the ``exp`` claim; a negative value
            yields an already-expired token.
        secret: HMAC signing key.
        **claims: Extra claims merged into the payload (e.g. ``email``).

    Returns:
        The encoded token string.
    """
    payload = {
        'sub': 'test_user_id',
        'exp': int(time.time()) + exp_offset,
        **claims,
    }
    return jwt.encode(payload, secret, algorithm='HS256')


def _make_signed_token(
    access_token: str, refresh_token: str = 'test_refresh_token'
) -> str:
    """Wrap an access/refresh token pair in an HS256 JWT signed with 'test_secret'.

    'test_secret' is the value the ``mock_config`` fixture returns from
    ``jwt_secret.get_secret_value()``.
    """
    return jwt.encode(
        {'access_token': access_token, 'refresh_token': refresh_token},
        'test_secret',
        algorithm='HS256',
    )


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def mock_request():
    request = MagicMock(spec=Request)
    request.headers = {}
    request.cookies = {}
    return request


@pytest.fixture
def mock_token_manager():
    with patch('server.auth.saas_user_auth.token_manager') as mock_tm:
        mock_tm.refresh = AsyncMock(
            return_value={
                'access_token': 'new_access_token',
                'refresh_token': 'new_refresh_token',
            }
        )
        mock_tm.get_user_info_from_user_id = AsyncMock(
            return_value={
                'federatedIdentities': [
                    {
                        'identityProvider': 'github',
                        'userId': 'github_user_id',
                    }
                ]
            }
        )
        mock_tm.get_idp_token = AsyncMock(return_value='github_token')
        yield mock_tm


@pytest.fixture
def mock_config():
    with patch('server.auth.saas_user_auth.get_config') as mock_get_config:
        mock_cfg = mock_get_config.return_value
        mock_cfg.jwt_secret.get_secret_value.return_value = 'test_secret'
        yield mock_cfg


# ---------------------------------------------------------------------------
# SaasUserAuth instance behaviour
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_user_id():
    """Test that get_user_id returns the user_id."""
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
    )
    user_id = await user_auth.get_user_id()
    assert user_id == 'test_user_id'


@pytest.mark.asyncio
async def test_get_user_email():
    """Test that get_user_email returns the email."""
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        email='test@example.com',
    )
    email = await user_auth.get_user_email()
    assert email == 'test@example.com'


@pytest.mark.asyncio
async def test_refresh(mock_token_manager):
    """Test that refresh updates the tokens."""
    refresh_token = _make_jwt()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(refresh_token),
    )

    await user_auth.refresh()

    mock_token_manager.refresh.assert_called_once_with(refresh_token)
    assert user_auth.access_token.get_secret_value() == 'new_access_token'
    assert user_auth.refresh_token.get_secret_value() == 'new_refresh_token'
    assert user_auth.refreshed is True


@pytest.mark.asyncio
async def test_get_access_token_with_existing_valid_token(mock_token_manager):
    """Test that get_access_token returns the existing token if it's valid."""
    # Valid JWT that expires in 1 hour.
    access_token = _make_jwt()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        access_token=SecretStr(access_token),
    )

    result = await user_auth.get_access_token()

    assert result.get_secret_value() == access_token
    mock_token_manager.refresh.assert_not_called()


@pytest.mark.asyncio
async def test_get_access_token_with_expired_token(mock_token_manager):
    """Test that get_access_token refreshes the token if it's expired."""
    # Access token expired an hour ago; refresh token is still valid.
    access_token = _make_jwt(exp_offset=-3600)
    refresh_token = _make_jwt()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(refresh_token),
        access_token=SecretStr(access_token),
    )

    result = await user_auth.get_access_token()

    assert result.get_secret_value() == 'new_access_token'
    mock_token_manager.refresh.assert_called_once_with(refresh_token)


@pytest.mark.asyncio
async def test_get_access_token_with_no_token(mock_token_manager):
    """Test that get_access_token refreshes when no token exists."""
    refresh_token = _make_jwt()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(refresh_token),
    )

    result = await user_auth.get_access_token()

    assert result.get_secret_value() == 'new_access_token'
    mock_token_manager.refresh.assert_called_once_with(refresh_token)


# The body of this test was previously hidden inside a string literal followed by
# `pass`, so it always reported success without exercising anything. It is kept as
# real code under an explicit skip so it shows up in reports until re-enabled.
@pytest.mark.skip(
    reason='Temporarily disabled (body was previously commented out); '
    're-enable once it passes against the current get_provider_tokens.'
)
@pytest.mark.asyncio
async def test_get_provider_tokens(mock_token_manager):
    """Test that get_provider_tokens fetches provider tokens."""
    access_token = _make_jwt()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        access_token=SecretStr(access_token),
    )

    result = await user_auth.get_provider_tokens()

    assert ProviderType.GITHUB in result
    assert result[ProviderType.GITHUB].token.get_secret_value() == 'github_token'
    assert result[ProviderType.GITHUB].user_id == 'github_user_id'
    mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
        'test_user_id'
    )
    mock_token_manager.get_idp_token.assert_called_once_with(
        access_token, idp=ProviderType.GITHUB
    )


@pytest.mark.asyncio
async def test_get_provider_tokens_cached(mock_token_manager):
    """Test that get_provider_tokens returns cached tokens if available."""
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        provider_tokens={
            ProviderType.GITHUB: ProviderToken(
                token=SecretStr('cached_github_token'),
                user_id='github_user_id',
            )
        },
    )

    result = await user_auth.get_provider_tokens()

    assert ProviderType.GITHUB in result
    assert result[ProviderType.GITHUB].token.get_secret_value() == 'cached_github_token'
    mock_token_manager.get_user_info_from_user_id.assert_not_called()
    mock_token_manager.get_idp_token.assert_not_called()


@pytest.mark.asyncio
async def test_get_user_settings_store():
    """Test that get_user_settings_store returns a settings store."""
    with patch('server.auth.saas_user_auth.SaasSettingsStore') as mock_store_cls:
        mock_store = MagicMock()
        mock_store_cls.return_value = mock_store
        user_auth = SaasUserAuth(
            user_id='test_user_id',
            refresh_token=SecretStr('refresh_token'),
        )

        result = await user_auth.get_user_settings_store()

        assert result == mock_store
        mock_store_cls.assert_called_once()
        assert user_auth.settings_store == mock_store


@pytest.mark.asyncio
async def test_get_user_settings_store_cached():
    """Test that get_user_settings_store returns cached store if available."""
    mock_store = MagicMock()
    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        settings_store=mock_store,
    )

    result = await user_auth.get_user_settings_store()

    assert result == mock_store


# ---------------------------------------------------------------------------
# SaasUserAuth.get_instance
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_instance_from_bearer(mock_request):
    """Test that get_instance returns auth from bearer token."""
    with patch(
        'server.auth.saas_user_auth.saas_user_auth_from_bearer'
    ) as mock_from_bearer:
        mock_auth = MagicMock()
        mock_from_bearer.return_value = mock_auth

        result = await SaasUserAuth.get_instance(mock_request)

        assert result == mock_auth
        mock_from_bearer.assert_called_once_with(mock_request)


@pytest.mark.asyncio
async def test_get_instance_from_cookie(mock_request):
    """Test that get_instance returns auth from cookie if bearer fails."""
    with (
        patch(
            'server.auth.saas_user_auth.saas_user_auth_from_bearer'
        ) as mock_from_bearer,
        patch(
            'server.auth.saas_user_auth.saas_user_auth_from_cookie'
        ) as mock_from_cookie,
    ):
        mock_from_bearer.return_value = None
        mock_auth = MagicMock()
        mock_from_cookie.return_value = mock_auth

        result = await SaasUserAuth.get_instance(mock_request)

        assert result == mock_auth
        mock_from_bearer.assert_called_once_with(mock_request)
        mock_from_cookie.assert_called_once_with(mock_request)


@pytest.mark.asyncio
async def test_get_instance_no_auth(mock_request):
    """Test that get_instance raises NoCredentialsError if no auth is found."""
    with (
        patch(
            'server.auth.saas_user_auth.saas_user_auth_from_bearer'
        ) as mock_from_bearer,
        patch(
            'server.auth.saas_user_auth.saas_user_auth_from_cookie'
        ) as mock_from_cookie,
    ):
        mock_from_bearer.return_value = None
        mock_from_cookie.return_value = None

        with pytest.raises(NoCredentialsError):
            await SaasUserAuth.get_instance(mock_request)

        mock_from_bearer.assert_called_once_with(mock_request)
        mock_from_cookie.assert_called_once_with(mock_request)


# ---------------------------------------------------------------------------
# saas_user_auth_from_bearer
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_success():
    """Test successful authentication from bearer token."""
    mock_request = MagicMock()
    mock_request.headers = {'Authorization': 'Bearer test_api_key'}

    with (
        patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls,
        patch('server.auth.saas_user_auth.token_manager') as mock_token_manager,
    ):
        mock_api_key_store = MagicMock()
        mock_api_key_store.validate_api_key.return_value = 'test_user_id'
        mock_api_key_store_cls.get_instance.return_value = mock_api_key_store
        mock_token_manager.load_offline_token = AsyncMock(return_value='offline_token')

        result = await saas_user_auth_from_bearer(mock_request)

        assert isinstance(result, SaasUserAuth)
        assert result.user_id == 'test_user_id'
        assert result.refresh_token.get_secret_value() == 'offline_token'
        mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key')
        mock_token_manager.load_offline_token.assert_called_once_with('test_user_id')


@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_no_auth_header():
    """Test that saas_user_auth_from_bearer returns None if no auth header."""
    mock_request = MagicMock()
    mock_request.headers = {}

    result = await saas_user_auth_from_bearer(mock_request)

    assert result is None


@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_invalid_api_key():
    """Test that saas_user_auth_from_bearer returns None if API key is invalid."""
    mock_request = MagicMock()
    mock_request.headers = {'Authorization': 'Bearer test_api_key'}

    with patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls:
        mock_api_key_store = MagicMock()
        mock_api_key_store.validate_api_key.return_value = None
        mock_api_key_store_cls.get_instance.return_value = mock_api_key_store

        result = await saas_user_auth_from_bearer(mock_request)

        assert result is None
        mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key')


@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_exception():
    """Test that saas_user_auth_from_bearer raises BearerTokenError on exception."""
    mock_request = MagicMock()
    mock_request.headers = {'Authorization': 'Bearer test_api_key'}

    with patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls:
        mock_api_key_store_cls.get_instance.side_effect = Exception('Test error')

        with pytest.raises(BearerTokenError):
            await saas_user_auth_from_bearer(mock_request)


# ---------------------------------------------------------------------------
# saas_user_auth_from_cookie / saas_user_auth_from_signed_token
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_success(mock_config):
    """Test successful authentication from cookie."""
    signed_token = _make_signed_token('test_access_token')
    mock_request = MagicMock()
    mock_request.cookies = {'keycloak_auth': signed_token}

    with patch(
        'server.auth.saas_user_auth.saas_user_auth_from_signed_token'
    ) as mock_from_signed:
        mock_auth = MagicMock()
        mock_from_signed.return_value = mock_auth

        result = await saas_user_auth_from_cookie(mock_request)

        assert result == mock_auth
        mock_from_signed.assert_called_once_with(signed_token)


@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_no_cookie():
    """Test that saas_user_auth_from_cookie returns None if no cookie."""
    mock_request = MagicMock()
    mock_request.cookies = {}

    result = await saas_user_auth_from_cookie(mock_request)

    assert result is None


@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_exception():
    """Test that saas_user_auth_from_cookie raises CookieError on exception."""
    mock_request = MagicMock()
    mock_request.cookies = {'keycloak_auth': 'invalid_token'}

    with pytest.raises(CookieError):
        await saas_user_auth_from_cookie(mock_request)


@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token(mock_config):
    """Test successful creation of SaasUserAuth from signed token."""
    access_token = _make_jwt(
        secret='access_secret', email='test@example.com', email_verified=True
    )
    signed_token = _make_signed_token(access_token)

    result = await saas_user_auth_from_signed_token(signed_token)

    assert isinstance(result, SaasUserAuth)
    assert result.user_id == 'test_user_id'
    assert result.access_token.get_secret_value() == access_token
    assert result.refresh_token.get_secret_value() == 'test_refresh_token'
    assert result.email == 'test@example.com'
    assert result.email_verified is True


# ---------------------------------------------------------------------------
# get_api_key_from_header
# Priority exercised below: Authorization (Bearer) > X-Session-API-Key >
# X-Access-Token.
# ---------------------------------------------------------------------------


def test_get_api_key_from_header_with_authorization_header():
    """Test that get_api_key_from_header extracts API key from Authorization header."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {'Authorization': 'Bearer test_api_key'}

    assert get_api_key_from_header(mock_request) == 'test_api_key'


def test_get_api_key_from_header_with_x_session_api_key():
    """Test that get_api_key_from_header extracts API key from X-Session-API-Key header."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {'X-Session-API-Key': 'session_api_key'}

    assert get_api_key_from_header(mock_request) == 'session_api_key'


def test_get_api_key_from_header_with_both_headers():
    """Test that get_api_key_from_header prioritizes Authorization header when both are present."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Session-API-Key': 'session_api_key',
    }

    assert get_api_key_from_header(mock_request) == 'auth_api_key'


def test_get_api_key_from_header_with_no_headers():
    """Test that get_api_key_from_header returns None when no relevant headers are present."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {'Other-Header': 'some_value'}

    assert get_api_key_from_header(mock_request) is None


def test_get_api_key_from_header_with_invalid_authorization_format():
    """Test that get_api_key_from_header handles Authorization headers without 'Bearer ' prefix."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {'Authorization': 'InvalidFormat api_key'}

    assert get_api_key_from_header(mock_request) is None


def test_get_api_key_from_header_with_x_access_token():
    """Test that get_api_key_from_header extracts API key from X-Access-Token header."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {'X-Access-Token': 'access_token_key'}

    assert get_api_key_from_header(mock_request) == 'access_token_key'


def test_get_api_key_from_header_priority_authorization_over_x_access_token():
    """Test that Authorization header takes priority over X-Access-Token header."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Access-Token': 'access_token_key',
    }

    assert get_api_key_from_header(mock_request) == 'auth_api_key'


def test_get_api_key_from_header_priority_x_session_over_x_access_token():
    """Test that X-Session-API-Key header takes priority over X-Access-Token header."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'X-Session-API-Key': 'session_api_key',
        'X-Access-Token': 'access_token_key',
    }

    assert get_api_key_from_header(mock_request) == 'session_api_key'


def test_get_api_key_from_header_all_three_headers():
    """Test header priority when all three headers are present."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Session-API-Key': 'session_api_key',
        'X-Access-Token': 'access_token_key',
    }

    # Authorization has the highest priority.
    assert get_api_key_from_header(mock_request) == 'auth_api_key'


def test_get_api_key_from_header_invalid_authorization_fallback_to_x_access_token():
    """Test that invalid Authorization header falls back to X-Access-Token."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': 'InvalidFormat api_key',
        'X-Access-Token': 'access_token_key',
    }

    assert get_api_key_from_header(mock_request) == 'access_token_key'


def test_get_api_key_from_header_empty_headers():
    """Test that empty header values are handled correctly."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': '',
        'X-Session-API-Key': '',
        'X-Access-Token': 'access_token_key',
    }

    # Empty Authorization / X-Session-API-Key values fall through to X-Access-Token.
    assert get_api_key_from_header(mock_request) == 'access_token_key'


def test_get_api_key_from_header_bearer_with_empty_token():
    """Pin current behavior: 'Bearer ' with an empty token returns '' and does NOT
    fall back to X-Access-Token."""
    mock_request = MagicMock(spec=Request)
    mock_request.headers = {
        'Authorization': 'Bearer ',
        'X-Access-Token': 'access_token_key',
    }

    # NOTE(review): returning '' instead of falling back may be unintended; this
    # assertion documents the current implementation, not necessarily the desired one.
    assert get_api_key_from_header(mock_request) == ''


# ---------------------------------------------------------------------------
# Email-domain blocking in saas_user_auth_from_signed_token
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
    """Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
    # Arrange
    access_token = _make_jwt(
        secret='access_secret', email='user@colsch.us', email_verified=True
    )
    signed_token = _make_signed_token(access_token)

    with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
        mock_domain_blocker.is_active.return_value = True
        mock_domain_blocker.is_domain_blocked.return_value = True

        # Act & Assert
        with pytest.raises(AuthError) as exc_info:
            await saas_user_auth_from_signed_token(signed_token)

        assert 'email domain is not allowed' in str(exc_info.value)
        mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')


@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
    """Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
    # Arrange
    access_token = _make_jwt(
        secret='access_secret', email='user@example.com', email_verified=True
    )
    signed_token = _make_signed_token(access_token)

    with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
        mock_domain_blocker.is_active.return_value = True
        mock_domain_blocker.is_domain_blocked.return_value = False

        # Act
        result = await saas_user_auth_from_signed_token(signed_token)

        # Assert
        assert isinstance(result, SaasUserAuth)
        assert result.user_id == 'test_user_id'
        assert result.email == 'user@example.com'
        mock_domain_blocker.is_domain_blocked.assert_called_once_with(
            'user@example.com'
        )


@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
    """Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
    # Arrange
    access_token = _make_jwt(
        secret='access_secret', email='user@colsch.us', email_verified=True
    )
    signed_token = _make_signed_token(access_token)

    with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
        mock_domain_blocker.is_active.return_value = False

        # Act
        result = await saas_user_auth_from_signed_token(signed_token)

        # Assert
        assert isinstance(result, SaasUserAuth)
        assert result.user_id == 'test_user_id'
        mock_domain_blocker.is_domain_blocked.assert_not_called()