mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
749 lines
25 KiB
Python
749 lines
25 KiB
Python
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import jwt
|
|
import pytest
|
|
from fastapi import Request
|
|
from pydantic import SecretStr
|
|
from server.auth.auth_error import (
|
|
AuthError,
|
|
BearerTokenError,
|
|
CookieError,
|
|
NoCredentialsError,
|
|
)
|
|
from server.auth.saas_user_auth import (
|
|
SaasUserAuth,
|
|
get_api_key_from_header,
|
|
saas_user_auth_from_bearer,
|
|
saas_user_auth_from_cookie,
|
|
saas_user_auth_from_signed_token,
|
|
)
|
|
|
|
from openhands.integrations.provider import ProviderToken, ProviderType
|
|
|
|
|
|
@pytest.fixture
def mock_request():
    """Provide a ``Request``-spec'd mock carrying no headers and no cookies."""
    fake_request = MagicMock(spec=Request)
    # Plain dicts so tests can add or look up entries directly.
    fake_request.cookies = {}
    fake_request.headers = {}
    return fake_request
|
|
|
|
|
|
@pytest.fixture
def mock_token_manager():
    """Patch ``token_manager`` in ``saas_user_auth`` with canned async replies.

    Yields the patched mock; the patch is active for the duration of the test.
    """
    refreshed_tokens = {
        'access_token': 'new_access_token',
        'refresh_token': 'new_refresh_token',
    }
    github_identity = {
        'identityProvider': 'github',
        'userId': 'github_user_id',
    }
    user_info = {'federatedIdentities': [github_identity]}

    with patch('server.auth.saas_user_auth.token_manager') as manager:
        # Each stub is an AsyncMock so the code under test can await it.
        manager.refresh = AsyncMock(return_value=refreshed_tokens)
        manager.get_user_info_from_user_id = AsyncMock(return_value=user_info)
        manager.get_idp_token = AsyncMock(return_value='github_token')
        yield manager
|
|
|
|
|
|
@pytest.fixture
def mock_config():
    """Patch ``get_config`` so the returned config reports ``'test_secret'`` as its JWT secret.

    Yields the mocked config object (the patched function's return value).
    """
    config_patch = patch('server.auth.saas_user_auth.get_config')
    with config_patch as get_config_stub:
        fake_config = get_config_stub.return_value
        fake_config.jwt_secret.get_secret_value.return_value = 'test_secret'
        yield fake_config
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_id():
    """get_user_id returns the user_id the instance was built with."""
    auth = SaasUserAuth(
        refresh_token=SecretStr('refresh_token'),
        user_id='test_user_id',
    )

    assert await auth.get_user_id() == 'test_user_id'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_email():
    """get_user_email returns the email the instance was built with."""
    auth = SaasUserAuth(
        email='test@example.com',
        refresh_token=SecretStr('refresh_token'),
        user_id='test_user_id',
    )

    assert await auth.get_user_email() == 'test@example.com'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh(mock_token_manager):
    """refresh() stores the token pair returned by the token manager."""
    claims = {'sub': 'test_user_id', 'exp': int(time.time()) + 3600}
    encoded_refresh = jwt.encode(claims, 'secret', algorithm='HS256')
    auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(encoded_refresh),
    )

    await auth.refresh()

    # The manager must have been asked exactly once, with the raw token string.
    mock_token_manager.refresh.assert_called_once_with(encoded_refresh)
    assert auth.access_token.get_secret_value() == 'new_access_token'
    assert auth.refresh_token.get_secret_value() == 'new_refresh_token'
    assert auth.refreshed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_access_token_with_existing_valid_token(mock_token_manager):
    """An unexpired access token is returned unchanged and no refresh happens."""
    one_hour_ahead = int(time.time()) + 3600
    valid_token = jwt.encode(
        {'sub': 'test_user_id', 'exp': one_hour_ahead},
        'secret',
        algorithm='HS256',
    )
    auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        access_token=SecretStr(valid_token),
    )

    returned = await auth.get_access_token()

    assert returned.get_secret_value() == valid_token
    mock_token_manager.refresh.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_access_token_with_expired_token(mock_token_manager):
    """An expired access token triggers a refresh and the new token is returned."""

    def _token_expiring_in(offset_seconds):
        # Same claims and key for both tokens; only the expiry offset differs.
        claims = {
            'sub': 'test_user_id',
            'exp': int(time.time()) + offset_seconds,
        }
        return jwt.encode(claims, 'secret', algorithm='HS256')

    expired_access = _token_expiring_in(-3600)  # expired an hour ago
    valid_refresh = _token_expiring_in(3600)  # expires in an hour

    auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(valid_refresh),
        access_token=SecretStr(expired_access),
    )

    returned = await auth.get_access_token()

    assert returned.get_secret_value() == 'new_access_token'
    mock_token_manager.refresh.assert_called_once_with(valid_refresh)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_access_token_with_no_token(mock_token_manager):
    """With no access token set, get_access_token refreshes and returns the new one."""
    expiry = int(time.time()) + 3600
    encoded_refresh = jwt.encode(
        {'sub': 'test_user_id', 'exp': expiry},
        'secret',
        algorithm='HS256',
    )
    # No access_token argument: the instance starts without one.
    auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr(encoded_refresh),
    )

    returned = await auth.get_access_token()

    assert returned.get_secret_value() == 'new_access_token'
    mock_token_manager.refresh.assert_called_once_with(encoded_refresh)
|
|
|
|
|
|
@pytest.mark.skip(
    reason='Assertions were disabled in-source; verify against get_provider_tokens before re-enabling'
)
@pytest.mark.asyncio
async def test_get_provider_tokens(mock_token_manager):
    """Test that get_provider_tokens fetches provider tokens.

    NOTE(review): the body of this test used to be wrapped in a bare string
    literal followed by ``pass``, so it ran as a no-op and was reported as
    passing. It is marked as skipped instead so the missing coverage shows up
    in the test report. Why it was disabled is not recorded here -- confirm
    the expectations below before removing the skip marker.
    """
    # Create a valid JWT token
    payload = {
        'sub': 'test_user_id',
        'exp': int(time.time()) + 3600,  # Expires in 1 hour
    }
    access_token = jwt.encode(payload, 'secret', algorithm='HS256')

    user_auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        access_token=SecretStr(access_token),
    )

    result = await user_auth.get_provider_tokens()

    # Values below come from the canned replies in the mock_token_manager fixture.
    assert ProviderType.GITHUB in result
    assert result[ProviderType.GITHUB].token.get_secret_value() == 'github_token'
    assert result[ProviderType.GITHUB].user_id == 'github_user_id'
    mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
        'test_user_id'
    )
    mock_token_manager.get_idp_token.assert_called_once_with(
        access_token, idp=ProviderType.GITHUB
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_provider_tokens_cached(mock_token_manager):
    """Provider tokens supplied at construction are returned without any token-manager calls."""
    cached_token = ProviderToken(
        token=SecretStr('cached_github_token'),
        user_id='github_user_id',
    )
    auth = SaasUserAuth(
        user_id='test_user_id',
        refresh_token=SecretStr('refresh_token'),
        provider_tokens={ProviderType.GITHUB: cached_token},
    )

    tokens = await auth.get_provider_tokens()

    assert ProviderType.GITHUB in tokens
    github_entry = tokens[ProviderType.GITHUB]
    assert github_entry.token.get_secret_value() == 'cached_github_token'
    # Neither lookup on the token manager may have been used.
    mock_token_manager.get_user_info_from_user_id.assert_not_called()
    mock_token_manager.get_idp_token.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_settings_store():
    """get_user_settings_store builds a SaasSettingsStore once and keeps it on the instance."""
    expected_store = MagicMock()
    store_patch = patch(
        'server.auth.saas_user_auth.SaasSettingsStore',
        return_value=expected_store,
    )

    with store_patch as store_factory:
        auth = SaasUserAuth(
            user_id='test_user_id',
            refresh_token=SecretStr('refresh_token'),
        )

        obtained = await auth.get_user_settings_store()

        assert obtained == expected_store
        store_factory.assert_called_once()
        assert auth.settings_store == expected_store
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_settings_store_cached():
    """A settings store passed at construction is what get_user_settings_store returns."""
    preset_store = MagicMock()
    auth = SaasUserAuth(
        settings_store=preset_store,
        refresh_token=SecretStr('refresh_token'),
        user_id='test_user_id',
    )

    assert await auth.get_user_settings_store() == preset_store
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_instance_from_bearer(mock_request):
    """get_instance returns whatever the bearer lookup produces."""
    bearer_auth = MagicMock()
    bearer_patch = patch(
        'server.auth.saas_user_auth.saas_user_auth_from_bearer',
        return_value=bearer_auth,
    )

    with bearer_patch as bearer_lookup:
        obtained = await SaasUserAuth.get_instance(mock_request)

        assert obtained == bearer_auth
        bearer_lookup.assert_called_once_with(mock_request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_instance_from_cookie(mock_request):
    """When the bearer lookup yields None, get_instance uses the cookie lookup's result."""
    bearer_patch = patch('server.auth.saas_user_auth.saas_user_auth_from_bearer')
    cookie_patch = patch('server.auth.saas_user_auth.saas_user_auth_from_cookie')

    with bearer_patch as bearer_lookup, cookie_patch as cookie_lookup:
        cookie_auth = MagicMock()
        bearer_lookup.return_value = None
        cookie_lookup.return_value = cookie_auth

        obtained = await SaasUserAuth.get_instance(mock_request)

        assert obtained == cookie_auth
        # Both lookups must have been consulted, each with the request.
        bearer_lookup.assert_called_once_with(mock_request)
        cookie_lookup.assert_called_once_with(mock_request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_instance_no_auth(mock_request):
    """get_instance raises NoCredentialsError when both lookups yield None."""
    bearer_patch = patch('server.auth.saas_user_auth.saas_user_auth_from_bearer')
    cookie_patch = patch('server.auth.saas_user_auth.saas_user_auth_from_cookie')

    with bearer_patch as bearer_lookup, cookie_patch as cookie_lookup:
        bearer_lookup.return_value = None
        cookie_lookup.return_value = None

        with pytest.raises(NoCredentialsError):
            await SaasUserAuth.get_instance(mock_request)

        bearer_lookup.assert_called_once_with(mock_request)
        cookie_lookup.assert_called_once_with(mock_request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_success():
    """A valid bearer API key yields a SaasUserAuth backed by the user's offline token."""
    request = MagicMock()
    request.headers = {'Authorization': 'Bearer test_api_key'}

    store_patch = patch('server.auth.saas_user_auth.ApiKeyStore')
    manager_patch = patch('server.auth.saas_user_auth.token_manager')

    with store_patch as store_cls, manager_patch as manager:
        # The store maps the API key to a user id ...
        key_store = MagicMock()
        key_store.validate_api_key.return_value = 'test_user_id'
        store_cls.get_instance.return_value = key_store
        # ... and the token manager supplies that user's offline token.
        manager.load_offline_token = AsyncMock(return_value='offline_token')

        auth = await saas_user_auth_from_bearer(request)

        assert isinstance(auth, SaasUserAuth)
        assert auth.user_id == 'test_user_id'
        assert auth.refresh_token.get_secret_value() == 'offline_token'
        key_store.validate_api_key.assert_called_once_with('test_api_key')
        manager.load_offline_token.assert_called_once_with('test_user_id')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_no_auth_header():
    """With no headers on the request, saas_user_auth_from_bearer returns None."""
    request = MagicMock()
    request.headers = {}

    assert await saas_user_auth_from_bearer(request) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_invalid_api_key():
    """When the key store rejects the API key (returns None), the result is None."""
    request = MagicMock()
    request.headers = {'Authorization': 'Bearer test_api_key'}

    store_patch = patch('server.auth.saas_user_auth.ApiKeyStore')
    with store_patch as store_cls:
        key_store = MagicMock()
        key_store.validate_api_key.return_value = None
        store_cls.get_instance.return_value = key_store

        outcome = await saas_user_auth_from_bearer(request)

        assert outcome is None
        key_store.validate_api_key.assert_called_once_with('test_api_key')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_exception():
    """An exception while obtaining the key store surfaces as BearerTokenError."""
    request = MagicMock()
    request.headers = {'Authorization': 'Bearer test_api_key'}

    store_patch = patch('server.auth.saas_user_auth.ApiKeyStore')
    with store_patch as store_cls:
        # Make ApiKeyStore.get_instance() itself blow up.
        store_cls.get_instance.side_effect = Exception('Test error')

        with pytest.raises(BearerTokenError):
            await saas_user_auth_from_bearer(request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_success(mock_config):
    """The ``keycloak_auth`` cookie value is handed to saas_user_auth_from_signed_token."""
    cookie_value = jwt.encode(
        {
            'access_token': 'test_access_token',
            'refresh_token': 'test_refresh_token',
        },
        'test_secret',
        algorithm='HS256',
    )
    request = MagicMock()
    request.cookies = {'keycloak_auth': cookie_value}

    expected_auth = MagicMock()
    signed_patch = patch(
        'server.auth.saas_user_auth.saas_user_auth_from_signed_token',
        return_value=expected_auth,
    )
    with signed_patch as from_signed:
        obtained = await saas_user_auth_from_cookie(request)

        assert obtained == expected_auth
        from_signed.assert_called_once_with(cookie_value)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_no_cookie():
    """With no cookies on the request, saas_user_auth_from_cookie returns None."""
    request = MagicMock()
    request.cookies = {}

    assert await saas_user_auth_from_cookie(request) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_cookie_exception():
    """A ``keycloak_auth`` cookie holding 'invalid_token' raises CookieError."""
    request = MagicMock()
    request.cookies = {'keycloak_auth': 'invalid_token'}

    with pytest.raises(CookieError):
        await saas_user_auth_from_cookie(request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token(mock_config):
    """A signed token wrapping access/refresh tokens is turned into a populated SaasUserAuth."""
    # Inner access token: carries the identity claims the result is checked against.
    inner_access = jwt.encode(
        {
            'sub': 'test_user_id',
            'exp': int(time.time()) + 3600,
            'email': 'test@example.com',
            'email_verified': True,
        },
        'access_secret',
        algorithm='HS256',
    )
    # Outer token: signed with the secret that mock_config reports.
    outer_signed = jwt.encode(
        {'access_token': inner_access, 'refresh_token': 'test_refresh_token'},
        'test_secret',
        algorithm='HS256',
    )

    auth = await saas_user_auth_from_signed_token(outer_signed)

    assert isinstance(auth, SaasUserAuth)
    assert auth.user_id == 'test_user_id'
    assert auth.access_token.get_secret_value() == inner_access
    assert auth.refresh_token.get_secret_value() == 'test_refresh_token'
    assert auth.email == 'test@example.com'
    assert auth.email_verified is True
|
|
|
|
|
|
def test_get_api_key_from_header_with_authorization_header():
    """An ``Authorization: Bearer <key>`` header yields ``<key>``."""
    request = MagicMock(spec=Request)
    request.headers = {'Authorization': 'Bearer test_api_key'}

    assert get_api_key_from_header(request) == 'test_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_with_x_session_api_key():
    """An ``X-Session-API-Key`` header on its own supplies the API key."""
    request = MagicMock(spec=Request)
    request.headers = {'X-Session-API-Key': 'session_api_key'}

    extracted = get_api_key_from_header(request)

    assert extracted == 'session_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_with_both_headers():
    """With Authorization and X-Session-API-Key both set, the Authorization key wins."""
    headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Session-API-Key': 'session_api_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    assert get_api_key_from_header(request) == 'auth_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_with_no_headers():
    """A request carrying only an unrelated header yields None."""
    request = MagicMock(spec=Request)
    request.headers = {'Other-Header': 'some_value'}

    extracted = get_api_key_from_header(request)

    assert extracted is None
|
|
|
|
|
|
def test_get_api_key_from_header_with_invalid_authorization_format():
    """An Authorization header without the ``Bearer `` prefix yields None."""
    request = MagicMock(spec=Request)
    request.headers = {'Authorization': 'InvalidFormat api_key'}

    assert get_api_key_from_header(request) is None
|
|
|
|
|
|
def test_get_api_key_from_header_with_x_access_token():
    """An ``X-Access-Token`` header on its own supplies the API key."""
    request = MagicMock(spec=Request)
    request.headers = {'X-Access-Token': 'access_token_key'}

    extracted = get_api_key_from_header(request)

    assert extracted == 'access_token_key'
|
|
|
|
|
|
def test_get_api_key_from_header_priority_authorization_over_x_access_token():
    """With Authorization and X-Access-Token both set, the Authorization key wins."""
    headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    assert get_api_key_from_header(request) == 'auth_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_priority_x_session_over_x_access_token():
    """With X-Session-API-Key and X-Access-Token both set, the session key wins."""
    headers = {
        'X-Session-API-Key': 'session_api_key',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    extracted = get_api_key_from_header(request)

    assert extracted == 'session_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_all_three_headers():
    """With all three key-bearing headers set, the Authorization key is returned."""
    headers = {
        'Authorization': 'Bearer auth_api_key',
        'X-Session-API-Key': 'session_api_key',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    # Authorization outranks both of the other headers.
    assert get_api_key_from_header(request) == 'auth_api_key'
|
|
|
|
|
|
def test_get_api_key_from_header_invalid_authorization_fallback_to_x_access_token():
    """A non-Bearer Authorization header is ignored and X-Access-Token is used instead."""
    headers = {
        'Authorization': 'InvalidFormat api_key',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    extracted = get_api_key_from_header(request)

    assert extracted == 'access_token_key'
|
|
|
|
|
|
def test_get_api_key_from_header_empty_headers():
    """Empty Authorization and X-Session-API-Key values are passed over for X-Access-Token."""
    headers = {
        'Authorization': '',
        'X-Session-API-Key': '',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    assert get_api_key_from_header(request) == 'access_token_key'
|
|
|
|
|
|
def test_get_api_key_from_header_bearer_with_empty_token():
    """``Authorization: 'Bearer '`` (empty token) yields ``''``, not the X-Access-Token value.

    This pins the current implementation behavior: there is no fallback to
    the other header when the Bearer token is empty.
    """
    headers = {
        'Authorization': 'Bearer ',
        'X-Access-Token': 'access_token_key',
    }
    request = MagicMock(spec=Request)
    request.headers = headers

    extracted = get_api_key_from_header(request)

    assert extracted == ''
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
    """An active domain blocker that flags the email makes the call raise AuthError."""
    # Arrange: access token for the flagged address, wrapped in the signed token.
    inner_access = jwt.encode(
        {
            'sub': 'test_user_id',
            'exp': int(time.time()) + 3600,
            'email': 'user@colsch.us',
            'email_verified': True,
        },
        'access_secret',
        algorithm='HS256',
    )
    outer_signed = jwt.encode(
        {'access_token': inner_access, 'refresh_token': 'test_refresh_token'},
        'test_secret',
        algorithm='HS256',
    )

    blocker_patch = patch('server.auth.saas_user_auth.domain_blocker')
    with blocker_patch as blocker:
        blocker.is_active.return_value = True
        blocker.is_domain_blocked.return_value = True

        # Act & Assert
        with pytest.raises(AuthError) as raised:
            await saas_user_auth_from_signed_token(outer_signed)

        assert 'email domain is not allowed' in str(raised.value)
        # The blocker is queried with the full email address.
        blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
    """An active domain blocker that does not flag the email lets authentication succeed."""
    # Arrange
    inner_access = jwt.encode(
        {
            'sub': 'test_user_id',
            'exp': int(time.time()) + 3600,
            'email': 'user@example.com',
            'email_verified': True,
        },
        'access_secret',
        algorithm='HS256',
    )
    outer_signed = jwt.encode(
        {'access_token': inner_access, 'refresh_token': 'test_refresh_token'},
        'test_secret',
        algorithm='HS256',
    )

    blocker_patch = patch('server.auth.saas_user_auth.domain_blocker')
    with blocker_patch as blocker:
        blocker.is_active.return_value = True
        blocker.is_domain_blocked.return_value = False

        # Act
        auth = await saas_user_auth_from_signed_token(outer_signed)

        # Assert
        assert isinstance(auth, SaasUserAuth)
        assert auth.user_id == 'test_user_id'
        assert auth.email == 'user@example.com'
        blocker.is_domain_blocked.assert_called_once_with('user@example.com')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
    """With the domain blocker inactive, authentication succeeds and no domain check is made."""
    # Arrange
    inner_access = jwt.encode(
        {
            'sub': 'test_user_id',
            'exp': int(time.time()) + 3600,
            'email': 'user@colsch.us',
            'email_verified': True,
        },
        'access_secret',
        algorithm='HS256',
    )
    outer_signed = jwt.encode(
        {'access_token': inner_access, 'refresh_token': 'test_refresh_token'},
        'test_secret',
        algorithm='HS256',
    )

    blocker_patch = patch('server.auth.saas_user_auth.domain_blocker')
    with blocker_patch as blocker:
        blocker.is_active.return_value = False

        # Act
        auth = await saas_user_auth_from_signed_token(outer_signed)

        # Assert
        assert isinstance(auth, SaasUserAuth)
        assert auth.user_id == 'test_user_id'
        blocker.is_domain_blocked.assert_not_called()
|