OpenHands/enterprise/tests/unit/test_auth_middleware.py
2025-09-04 15:44:54 -04:00

237 lines
8.9 KiB
Python

from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from server.auth.auth_error import (
AuthError,
CookieError,
ExpiredError,
NoCredentialsError,
)
from server.auth.saas_user_auth import SaasUserAuth
from server.middleware import SetAuthCookieMiddleware
from openhands.server.user_auth.user_auth import AuthType
@pytest.fixture
def middleware():
    """Return a new ``SetAuthCookieMiddleware`` built with no arguments."""
    instance = SetAuthCookieMiddleware()
    return instance
@pytest.fixture
def mock_request():
    """Return a ``MagicMock`` spec'd on ``Request`` whose ``cookies`` is an empty dict."""
    fake_request = MagicMock(spec=Request, cookies={})
    return fake_request
@pytest.fixture
def mock_response():
    """Return a ``MagicMock`` spec'd on ``Response`` to stand in for the downstream reply."""
    fake_response = MagicMock(spec=Response)
    return fake_response
@pytest.mark.asyncio
async def test_middleware_no_cookie(middleware, mock_request, mock_response):
    """Without an auth cookie the request is forwarded once and its response returned."""
    # Request URL: hostname 'localhost' and a path that does not start with /api.
    mock_request.url = MagicMock(hostname='localhost', path='/some/non-api/path')
    mock_request.cookies = {}
    downstream = AsyncMock(return_value=mock_response)

    outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
@pytest.mark.asyncio
async def test_middleware_with_cookie_no_refresh(
    middleware, mock_request, mock_response
):
    """With an auth cookie and ``refreshed`` False, no cookie is set on the response."""
    # Stand-in user auth that reports cookie auth and no token refresh.
    user_auth = MagicMock(spec=SaasUserAuth)
    user_auth.auth_type = AuthType.COOKIE
    user_auth.refreshed = False

    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked below.
    mock_request.cookies = {'keycloak_auth': 'test_cookie'}
    downstream = AsyncMock(return_value=mock_response)

    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch(
            'server.middleware.SetAuthCookieMiddleware._get_user_auth',
            return_value=user_auth,
        ),
    ):
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
        mock_decode.return_value = {'accepted_tos': True}

        outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
    mock_response.set_cookie.assert_not_called()
@pytest.mark.asyncio
async def test_middleware_with_cookie_and_refresh(
    middleware, mock_request, mock_response
):
    """With an auth cookie and ``refreshed`` True, the new tokens are written to the cookie."""
    # Stand-in user auth that reports a refresh and carries the replacement tokens.
    user_auth = MagicMock(spec=SaasUserAuth)
    user_auth.auth_type = AuthType.COOKIE
    user_auth.refreshed = True
    user_auth.accepted_tos = True
    user_auth.access_token = SecretStr('new_access_token')
    user_auth.refresh_token = SecretStr('new_refresh_token')

    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked below.
    mock_request.cookies = {'keycloak_auth': 'test_cookie'}
    downstream = AsyncMock(return_value=mock_response)

    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch(
            'server.middleware.SetAuthCookieMiddleware._get_user_auth',
            return_value=user_auth,
        ),
        patch('server.middleware.set_response_cookie') as mock_set_cookie,
    ):
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
        mock_decode.return_value = {'accepted_tos': True}

        outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
    # The helper must receive the unwrapped (plain string) token values.
    expected_kwargs = {
        'request': mock_request,
        'response': mock_response,
        'keycloak_access_token': 'new_access_token',
        'keycloak_refresh_token': 'new_refresh_token',
        'secure': True,
        'accepted_tos': True,
    }
    mock_set_cookie.assert_called_once_with(**expected_kwargs)
def decode_body(body: bytes | memoryview):
if isinstance(body, memoryview):
return body.tobytes().decode()
else:
return body.decode()
@pytest.mark.asyncio
async def test_middleware_with_no_auth_provided_error(middleware, mock_request):
    """NoCredentialsError from downstream yields a 401 JSON response without Set-Cookie."""
    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=NoCredentialsError())

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED

    # Decode once; use `in` rather than `find() > 0`, which would wrongly
    # reject a match located at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'NoCredentialsError' in body_text

    # Cookie should not be deleted for NoCredentialsError
    assert 'set-cookie' not in result.headers
@pytest.mark.asyncio
async def test_middleware_with_expired_auth_cookie(middleware, mock_request):
    """ExpiredError from downstream yields a 401 JSON response, Set-Cookie and one warning."""
    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch('server.middleware.logger') as mock_logger,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(
            side_effect=ExpiredError('Authentication token has expired')
        )

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED

    # Decode once; use `in` rather than `find() > 0`, which would wrongly
    # reject a match located at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'Authentication token has expired' in body_text

    # Cookie should be deleted for ExpiredError as it's now handled as a general AuthError
    assert 'set-cookie' in result.headers
    # Logger should be called for ExpiredError
    mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_with_cookie_error(middleware, mock_request):
    """CookieError from downstream yields a 401 JSON response that carries Set-Cookie."""
    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=CookieError('Invalid cookie'))

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED

    # Decode once; use `in` rather than `find() > 0`, which would wrongly
    # reject a match located at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'Invalid cookie' in body_text

    # Cookie should be deleted for CookieError
    assert 'set-cookie' in result.headers
@pytest.mark.asyncio
async def test_middleware_with_other_auth_error(middleware, mock_request):
    """A plain AuthError from downstream yields a 401, Set-Cookie and one warning."""
    # 'test_cookie' is only a placeholder; jwt.decode and config are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch('server.middleware.logger') as mock_logger,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=AuthError('General auth error'))

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED

    # Decode once; use `in` rather than `find() > 0`, which would wrongly
    # reject a match located at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'General auth error' in body_text

    # Cookie should be deleted for any AuthError
    assert 'set-cookie' in result.headers
    # Logger should be called for non-NoCredentialsError
    mock_logger.warning.assert_called_once()