# Mirror of https://github.com/OpenHands/OpenHands.git
# Synced 2025-12-26 13:52:43 +08:00 (237 lines, 8.9 KiB, Python)
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import Request, Response, status
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import SecretStr
|
|
from server.auth.auth_error import (
|
|
AuthError,
|
|
CookieError,
|
|
ExpiredError,
|
|
NoCredentialsError,
|
|
)
|
|
from server.auth.saas_user_auth import SaasUserAuth
|
|
from server.middleware import SetAuthCookieMiddleware
|
|
|
|
from openhands.server.user_auth.user_auth import AuthType
|
|
|
|
|
|
@pytest.fixture
def middleware():
    """Build the ``SetAuthCookieMiddleware`` instance exercised by the tests."""
    instance = SetAuthCookieMiddleware()
    return instance
|
|
|
|
|
|
@pytest.fixture
def mock_request():
    """Build a ``MagicMock`` spec'd on ``Request`` whose ``cookies`` is an empty dict."""
    # Keyword arguments other than ``spec`` are set as attributes on the mock.
    return MagicMock(spec=Request, cookies={})
|
|
|
|
|
|
@pytest.fixture
def mock_response():
    """Build a ``MagicMock`` spec'd on ``Response`` to stand in for the downstream reply."""
    response = MagicMock(spec=Response)
    return response
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_no_cookie(middleware, mock_request, mock_response):
    """Without an auth cookie the middleware returns what ``call_next`` produced."""
    # Request carries no cookies and targets a non-/api path on localhost.
    mock_request.cookies = {}
    mock_request.url = MagicMock(hostname='localhost', path='/some/non-api/path')
    downstream = AsyncMock(return_value=mock_response)

    outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_cookie_no_refresh(
    middleware, mock_request, mock_response
):
    """An auth cookie without a token refresh must not cause ``set_cookie`` to be called."""
    mock_request.cookies = {'keycloak_auth': 'test_cookie'}
    downstream = AsyncMock(return_value=mock_response)

    # Stand-in user auth: cookie-based, tokens not refreshed.
    user_auth = MagicMock(spec=SaasUserAuth)
    user_auth.refreshed = False
    user_auth.auth_type = AuthType.COOKIE

    # No real JWT is built: jwt.decode, config and the user-auth lookup in
    # server.middleware are all replaced by mocks.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch(
            'server.middleware.SetAuthCookieMiddleware._get_user_auth',
            return_value=user_auth,
        ),
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
    mock_response.set_cookie.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_cookie_and_refresh(
    middleware, mock_request, mock_response
):
    """A refreshed user auth must lead to one ``set_response_cookie`` call with the new tokens."""
    mock_request.cookies = {'keycloak_auth': 'test_cookie'}
    downstream = AsyncMock(return_value=mock_response)

    # Stand-in user auth: cookie-based, tokens refreshed, terms of service accepted.
    user_auth = MagicMock(spec=SaasUserAuth)
    user_auth.refreshed = True
    user_auth.access_token = SecretStr('new_access_token')
    user_auth.refresh_token = SecretStr('new_refresh_token')
    user_auth.accepted_tos = True
    user_auth.auth_type = AuthType.COOKIE

    # No real JWT is built: jwt.decode, config, the user-auth lookup and the
    # cookie writer in server.middleware are all replaced by mocks.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
        patch(
            'server.middleware.SetAuthCookieMiddleware._get_user_auth',
            return_value=user_auth,
        ),
        patch('server.middleware.set_response_cookie') as mock_set_cookie,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        outcome = await middleware(mock_request, downstream)

    assert outcome == mock_response
    downstream.assert_called_once_with(mock_request)
    # The secret values are expected to be passed unwrapped, as plain strings.
    mock_set_cookie.assert_called_once_with(
        request=mock_request,
        response=mock_response,
        keycloak_access_token='new_access_token',
        keycloak_refresh_token='new_refresh_token',
        secure=True,
        accepted_tos=True,
    )
|
|
|
|
|
|
def decode_body(body: bytes | memoryview):
|
|
if isinstance(body, memoryview):
|
|
return body.tobytes().decode()
|
|
else:
|
|
return body.decode()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_no_auth_provided_error(middleware, mock_request):
    """Test middleware when NoCredentialsError is raised.

    The downstream handler raises ``NoCredentialsError``; the middleware is
    expected to reply with a 401 JSON error and not emit a ``set-cookie`` header.
    """
    # No real JWT is built: jwt.decode and config in server.middleware are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=NoCredentialsError())

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED
    # Decode once; use ``in`` rather than ``find() > 0``, which would miss a
    # match at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'NoCredentialsError' in body_text
    # Cookie should not be deleted for NoCredentialsError
    assert 'set-cookie' not in result.headers
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_expired_auth_cookie(middleware, mock_request):
    """Test middleware when ExpiredError is raised due to an expired authentication cookie.

    The downstream handler raises ``ExpiredError``; the middleware is expected
    to reply with a 401 JSON error, emit a ``set-cookie`` header and log one
    warning.
    """
    # No real JWT is built: jwt.decode, config and logger in server.middleware
    # are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(
            side_effect=ExpiredError('Authentication token has expired')
        )

        with patch('server.middleware.logger') as mock_logger:
            result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED
    # Decode once; use ``in`` rather than ``find() > 0``, which would miss a
    # match at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'Authentication token has expired' in body_text
    # Cookie should be deleted for ExpiredError as it's now handled as a general AuthError
    assert 'set-cookie' in result.headers
    # Logger should be called for ExpiredError
    mock_logger.warning.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_cookie_error(middleware, mock_request):
    """Test middleware when CookieError is raised.

    The downstream handler raises ``CookieError``; the middleware is expected
    to reply with a 401 JSON error and emit a ``set-cookie`` header.
    """
    # No real JWT is built: jwt.decode and config in server.middleware are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=CookieError('Invalid cookie'))

        result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED
    # Decode once; use ``in`` rather than ``find() > 0``, which would miss a
    # match at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'Invalid cookie' in body_text
    # Cookie should be deleted for CookieError
    assert 'set-cookie' in result.headers
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_middleware_with_other_auth_error(middleware, mock_request):
    """Test middleware when another AuthError is raised.

    The downstream handler raises a plain ``AuthError``; the middleware is
    expected to reply with a 401 JSON error, emit a ``set-cookie`` header and
    log one warning.
    """
    # No real JWT is built: jwt.decode, config and logger in server.middleware
    # are mocked.
    with (
        patch('server.middleware.jwt.decode') as mock_decode,
        patch('server.middleware.config') as mock_config,
    ):
        mock_decode.return_value = {'accepted_tos': True}
        mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

        mock_request.cookies = {'keycloak_auth': 'test_cookie'}
        mock_call_next = AsyncMock(side_effect=AuthError('General auth error'))

        with patch('server.middleware.logger') as mock_logger:
            result = await middleware(mock_request, mock_call_next)

    assert isinstance(result, JSONResponse)
    assert result.status_code == status.HTTP_401_UNAUTHORIZED
    # Decode once; use ``in`` rather than ``find() > 0``, which would miss a
    # match at index 0.
    body_text = decode_body(result.body)
    assert 'error' in body_text
    assert 'General auth error' in body_text
    # Cookie should be deleted for any AuthError
    assert 'set-cookie' in result.headers
    # Logger should be called for non-NoCredentialsError
    mock_logger.warning.assert_called_once()
|