From 9a5e5956fc5bf0a9a66ec8307d9863054b8aa838 Mon Sep 17 00:00:00 2001 From: tofarr Date: Fri, 23 May 2025 19:18:20 -0600 Subject: [PATCH] Added ability to read specify permitted origins in env (#8675) Co-authored-by: openhands --- openhands/server/listen.py | 8 +-- openhands/server/middleware.py | 21 ++++-- tests/unit/test_middleware.py | 120 +++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_middleware.py diff --git a/openhands/server/listen.py b/openhands/server/listen.py index fcdce359c3..d195ef5fbe 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -19,13 +19,7 @@ if os.getenv('SERVE_FRONTEND', 'true').lower() == 'true': '/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist' ) -base_app.add_middleware( - LocalhostCORSMiddleware, - allow_credentials=True, - allow_methods=['*'], - allow_headers=['*'], -) - +base_app.add_middleware(LocalhostCORSMiddleware) base_app.add_middleware(CacheControlMiddleware) base_app.add_middleware( RateLimitMiddleware, diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 32994fc0ec..c1a1300ad4 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -1,8 +1,8 @@ import asyncio +import os from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any from urllib.parse import urlparse from fastapi import Request, status @@ -24,11 +24,24 @@ class LocalhostCORSMiddleware(CORSMiddleware): while using standard CORS rules for other origins. """ - def __init__(self, app: ASGIApp, **kwargs: Any) -> None: - super().__init__(app, **kwargs) + def __init__(self, app: ASGIApp) -> None: + allow_origins_str = os.getenv('PERMITTED_CORS_ORIGINS') + if allow_origins_str: + allow_origins = tuple( + origin.strip() for origin in allow_origins_str.split(',') + ) + else: + allow_origins = () + super().__init__( + app, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) def is_allowed_origin(self, origin: str) -> bool: - if origin: + if origin and not self.allow_origins and not self.allow_origin_regex: parsed = urlparse(origin) hostname = parsed.hostname or '' diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py new file mode 100644 index 0000000000..d83004e54f --- /dev/null +++ b/tests/unit/test_middleware.py @@ -0,0 +1,120 @@ +import os +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.middleware.cors import CORSMiddleware + +from openhands.server.middleware import LocalhostCORSMiddleware + + +@pytest.fixture +def app(): + """Create a test FastAPI application.""" + app = FastAPI() + + @app.get('/test') + def test_endpoint(): + return {'message': 'Test endpoint'} + + return app + + +def test_localhost_cors_middleware_init_with_env_var(): + """Test that the middleware correctly parses PERMITTED_CORS_ORIGINS environment variable.""" + with patch.dict( + os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com,https://test.com'} + ): + app = FastAPI() + middleware = LocalhostCORSMiddleware(app) + + # Check that the origins were correctly parsed from the environment variable + assert 'https://example.com' in middleware.allow_origins + assert 'https://test.com' in middleware.allow_origins + assert len(middleware.allow_origins) == 2 + + +def test_localhost_cors_middleware_init_without_env_var(): + """Test that the middleware works correctly without PERMITTED_CORS_ORIGINS environment variable.""" + with patch.dict(os.environ, {}, clear=True): + app = FastAPI() + middleware = LocalhostCORSMiddleware(app) + + # Check that allow_origins is empty when no environment variable is set + assert middleware.allow_origins == () + + +def test_localhost_cors_middleware_is_allowed_origin_localhost(app): + """Test that localhost origins are allowed regardless of port.""" + app.add_middleware(LocalhostCORSMiddleware) + client = TestClient(app) + + # Test with localhost + response = client.get('/test', headers={'Origin': 'http://localhost:8000'}) + assert response.status_code == 200 + assert response.headers['access-control-allow-origin'] == 'http://localhost:8000' + + # Test with different port + response = client.get('/test', headers={'Origin': 'http://localhost:3000'}) + assert response.status_code == 200 + assert response.headers['access-control-allow-origin'] == 'http://localhost:3000' + + # Test with 127.0.0.1 + response = client.get('/test', headers={'Origin': 'http://127.0.0.1:8000'}) + assert response.status_code == 200 + assert response.headers['access-control-allow-origin'] == 'http://127.0.0.1:8000' + + +def test_localhost_cors_middleware_is_allowed_origin_non_localhost(app): + """Test that non-localhost origins follow the standard CORS rules.""" + # Set up the middleware with specific allowed origins + with patch.dict(os.environ, {'PERMITTED_CORS_ORIGINS': 'https://example.com'}): + app.add_middleware(LocalhostCORSMiddleware) + client = TestClient(app) + + # Test with allowed origin + response = client.get('/test', headers={'Origin': 'https://example.com'}) + assert response.status_code == 200 + assert response.headers['access-control-allow-origin'] == 'https://example.com' + + # Test with disallowed origin + response = client.get('/test', headers={'Origin': 'https://disallowed.com'}) + assert response.status_code == 200 + # The disallowed origin should not be in the response headers + assert 'access-control-allow-origin' not in response.headers + + +def test_localhost_cors_middleware_missing_origin(app): + """Test behavior when Origin header is missing.""" + app.add_middleware(LocalhostCORSMiddleware) + client = TestClient(app) + + # Test without Origin header + response = client.get('/test') + assert response.status_code == 200 + # There should be no access-control-allow-origin header + assert 'access-control-allow-origin' not in response.headers + + +def test_localhost_cors_middleware_inheritance(): + """Test that LocalhostCORSMiddleware correctly inherits from CORSMiddleware.""" + assert issubclass(LocalhostCORSMiddleware, CORSMiddleware) + + +def test_localhost_cors_middleware_cors_parameters(): + """Test that CORS parameters are set correctly in the middleware.""" + # We need to inspect the initialization parameters rather than attributes + # since CORSMiddleware doesn't expose these as attributes + with patch('fastapi.middleware.cors.CORSMiddleware.__init__') as mock_init: + mock_init.return_value = None + app = FastAPI() + LocalhostCORSMiddleware(app) + + # Check that the parent class was initialized with the correct parameters + mock_init.assert_called_once() + _, kwargs = mock_init.call_args + + assert kwargs['allow_credentials'] is True + assert kwargs['allow_methods'] == ['*'] + assert kwargs['allow_headers'] == ['*']