diff --git a/openhands/runtime/impl/local/local_runtime.py b/openhands/runtime/impl/local/local_runtime.py index 1741d01c12..82ca70324a 100644 --- a/openhands/runtime/impl/local/local_runtime.py +++ b/openhands/runtime/impl/local/local_runtime.py @@ -200,6 +200,11 @@ class LocalRuntime(ActionExecutionClient): headless_mode, ) + #If there is an API key in the environment we use this in requests to the runtime + session_api_key = os.getenv("SESSION_API_KEY") + if session_api_key: + self.session.headers['X-Session-API-Key'] = session_api_key + @property def action_execution_server_url(self) -> str: return self.api_url diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index b3c726058e..3a98528c32 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -480,7 +480,7 @@ class StandaloneConversationManager(ConversationManager): return AgentLoopInfo( conversation_id=session.sid, url=self._get_conversation_url(session.sid), - api_key=None, + session_api_key=None, event_store=session.agent_session.event_stream, ) diff --git a/openhands/server/data_models/agent_loop_info.py b/openhands/server/data_models/agent_loop_info.py index 1582c310e0..34eedfa1e0 100644 --- a/openhands/server/data_models/agent_loop_info.py +++ b/openhands/server/data_models/agent_loop_info.py @@ -10,5 +10,5 @@ class AgentLoopInfo: """ conversation_id: str url: str | None - api_key: str | None + session_api_key: str | None event_store: EventStoreABC diff --git a/openhands/server/data_models/conversation_info.py b/openhands/server/data_models/conversation_info.py index 67b17e75a6..c7f495fe88 100644 --- a/openhands/server/data_models/conversation_info.py +++ b/openhands/server/data_models/conversation_info.py @@ -20,5 +20,5 @@ class ConversationInfo: trigger: ConversationTrigger | None = None num_connections: int = 0 url: str | None = None - api_key: str | None = None + session_api_key: str | None = None created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/openhands/server/listen.py b/openhands/server/listen.py index 4caf6279ec..fcdce359c3 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -10,6 +10,7 @@ from openhands.server.middleware import ( InMemoryRateLimiter, LocalhostCORSMiddleware, RateLimitMiddleware, + SessionApiKeyMiddleware, ) from openhands.server.static import SPAStaticFiles @@ -32,4 +33,8 @@ base_app.add_middleware( ) base_app.middleware('http')(AttachConversationMiddleware(base_app)) +session_api_key = os.getenv('SESSION_API_KEY') +if session_api_key: + base_app.middleware('http')(SessionApiKeyMiddleware(session_api_key)) + app = socketio.ASGIApp(sio, other_asgi_app=base_app) diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index 8614dfa47a..4f6e0bbcc1 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -1,4 +1,5 @@ import asyncio +import os from types import MappingProxyType from typing import Any from urllib.parse import parse_qs @@ -72,6 +73,9 @@ async def connect(connection_id: str, environ: dict) -> None: logger.error('No conversation_id in query params') raise ConnectionRefusedError('No conversation_id in query params') + if _invalid_session_api_key(query_params): + raise ConnectionRefusedError('invalid_session_api_key') + cookies_str = environ.get('HTTP_COOKIE', '') # Get Authorization header from the environment # Headers in WSGI/ASGI are prefixed with 'HTTP_' and have dashes replaced with underscores @@ -160,3 +164,13 @@ async def oh_action(connection_id: str, data: dict[str, Any]) -> None: async def disconnect(connection_id: str) -> None: logger.info(f'sio:disconnect:{connection_id}') await conversation_manager.disconnect_from_session(connection_id) + + +def _invalid_session_api_key(query_params: dict[str, list[Any]]): + session_api_key = os.getenv('SESSION_API_KEY') + if not session_api_key: + return False + query_api_keys = query_params['session_api_key'] + if not query_api_keys: + return True + return query_api_keys[0] != session_api_key diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 4ce91593ac..1e5d391751 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -1,5 +1,6 @@ import asyncio from collections import defaultdict +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any from urllib.parse import urlparse @@ -192,3 +193,26 @@ class AttachConversationMiddleware(SessionMiddlewareInterface): await self._detach_session(request) return response + + +@dataclass +class SessionApiKeyMiddleware: + """Middleware which ensures that all requests contain a header with the token given""" + + session_api_key: str + + async def __call__( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + if ( + request.method != 'OPTIONS' + and request.url.path != '/alive' + and request.url.path != '/server_info' + ): + if self.session_api_key != request.headers.get('X-Session-API-Key'): + return JSONResponse( + {'code': 'invalid_session_api_key'}, + status_code=status.HTTP_401_UNAUTHORIZED, + ) + response = await call_next(request) + return response diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 2b9ee602bc..73376f21b6 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -69,7 +69,7 @@ class InitSessionResponse(BaseModel): status: str conversation_id: str conversation_url: str - api_key: str | None + session_api_key: str | None message: str | None = None @@ -228,7 +228,7 @@ async def new_conversation( status='ok', conversation_id=agent_loop_info.conversation_id, conversation_url=agent_loop_info.url, - api_key=agent_loop_info.api_key, + session_api_key=agent_loop_info.session_api_key, ) except MissingSettingsError as e: return JSONResponse( @@ -289,7 +289,7 @@ async def search_conversations( ) connection_ids_to_conversation_ids = await conversation_manager.get_connections(filter_to_sids=conversation_ids) agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids=conversation_ids) - urls_by_conversation_id = {info.conversation_id: info.url for info in agent_loop_info} + agent_loop_info_by_conversation_id = {info.conversation_id: info for info in agent_loop_info} result = ConversationInfoResultSet( results=await wait_all( _get_conversation_info( @@ -299,7 +299,8 @@ async def search_conversations( 1 for conversation_id in connection_ids_to_conversation_ids.values() if conversation_id == conversation.conversation_id ), - url=urls_by_conversation_id.get(conversation.conversation_id), + agent_loop_info=agent_loop_info_by_conversation_id.get(conversation.conversation_id), + ) for conversation in filtered_results ), @@ -317,9 +318,9 @@ async def get_conversation( metadata = await conversation_store.get_metadata(conversation_id) is_running = await conversation_manager.is_agent_loop_running(conversation_id) num_connections = len(await conversation_manager.get_connections(filter_to_sids={conversation_id})) - agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id}) - url = agent_loop_info[0].url if agent_loop_info else None - conversation_info = await _get_conversation_info(metadata, is_running, num_connections, url) + agent_loop_infos = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id}) + agent_loop_info = agent_loop_infos[0] if agent_loop_infos else None + conversation_info = await _get_conversation_info(metadata, is_running, num_connections, agent_loop_info) return conversation_info except FileNotFoundError: return None @@ -348,7 +349,7 @@ async def _get_conversation_info( conversation: ConversationMetadata, is_running: bool, num_connections: int, - url: str | None, + agent_loop_info: AgentLoopInfo | None, ) -> ConversationInfo | None: try: title = conversation.title @@ -365,7 +366,8 @@ async def _get_conversation_info( ConversationStatus.RUNNING if is_running else ConversationStatus.STOPPED ), num_connections=num_connections, - url=url, + url=agent_loop_info.url if agent_loop_info else None, + session_api_key=agent_loop_info.session_api_key if agent_loop_info else None, ) except Exception as e: logger.error( diff --git a/tests/unit/test_conversation.py b/tests/unit/test_conversation.py index 85a9a5958a..f79f0008aa 100644 --- a/tests/unit/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -250,7 +250,7 @@ async def test_new_conversation_success(provider_handler_mock): mock_create_conversation.return_value = MagicMock( conversation_id='test_conversation_id', url='https://my-conversation.com', - api_key=None, + session_api_key=None, ) test_request = InitSessionRequest( @@ -292,7 +292,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock): mock_create_conversation.return_value = MagicMock( conversation_id='test_conversation_id', url='https://my-conversation.com', - api_key=None, + session_api_key=None, ) # Mock SuggestedTask.get_prompt_for_task @@ -375,7 +375,7 @@ async def test_new_conversation_missing_settings(provider_handler_mock): @pytest.mark.asyncio -async def test_new_conversation_invalid_api_key(provider_handler_mock): +async def test_new_conversation_invalid_session_api_key(provider_handler_mock): """Test creating a new conversation with an invalid API key.""" with _patch_store(): # Mock the _create_new_conversation function to raise LLMAuthenticationError @@ -477,7 +477,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock): mock_create_conversation.return_value = MagicMock( conversation_id='test_conversation_id', url='https://my-conversation.com', - api_key=None, + session_api_key=None, ) # Create the request object @@ -514,7 +514,7 @@ async def test_new_conversation_with_null_repository(): mock_create_conversation.return_value = MagicMock( conversation_id='test_conversation_id', url='https://my-conversation.com', - api_key=None, + session_api_key=None, ) # Create the request object with null repository