Add Session API Key Authentication for Runtime Communication (#8550)

This commit is contained in:
tofarr 2025-05-19 09:59:22 -06:00 committed by GitHub
parent 872b97a3c8
commit 38b4d93237
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 67 additions and 17 deletions

View File

@ -200,6 +200,11 @@ class LocalRuntime(ActionExecutionClient):
headless_mode,
)
#If there is an API key in the environment we use this in requests to the runtime
session_api_key = os.getenv("SESSION_API_KEY")
if session_api_key:
self.session.headers['X-Session-API-Key'] = session_api_key
@property
def action_execution_server_url(self) -> str:
return self.api_url

View File

@ -480,7 +480,7 @@ class StandaloneConversationManager(ConversationManager):
return AgentLoopInfo(
conversation_id=session.sid,
url=self._get_conversation_url(session.sid),
api_key=None,
session_api_key=None,
event_store=session.agent_session.event_stream,
)

View File

@ -10,5 +10,5 @@ class AgentLoopInfo:
"""
conversation_id: str
url: str | None
api_key: str | None
session_api_key: str | None
event_store: EventStoreABC

View File

@ -20,5 +20,5 @@ class ConversationInfo:
trigger: ConversationTrigger | None = None
num_connections: int = 0
url: str | None = None
api_key: str | None = None
session_api_key: str | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

View File

@ -10,6 +10,7 @@ from openhands.server.middleware import (
InMemoryRateLimiter,
LocalhostCORSMiddleware,
RateLimitMiddleware,
SessionApiKeyMiddleware,
)
from openhands.server.static import SPAStaticFiles
@ -32,4 +33,8 @@ base_app.add_middleware(
)
base_app.middleware('http')(AttachConversationMiddleware(base_app))
session_api_key = os.getenv('SESSION_API_KEY')
if session_api_key:
base_app.middleware('http')(SessionApiKeyMiddleware(session_api_key))
app = socketio.ASGIApp(sio, other_asgi_app=base_app)

View File

@ -1,4 +1,5 @@
import asyncio
import os
from types import MappingProxyType
from typing import Any
from urllib.parse import parse_qs
@ -72,6 +73,9 @@ async def connect(connection_id: str, environ: dict) -> None:
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')
if _invalid_session_api_key(query_params):
raise ConnectionRefusedError('invalid_session_api_key')
cookies_str = environ.get('HTTP_COOKIE', '')
# Get Authorization header from the environment
# Headers in WSGI/ASGI are prefixed with 'HTTP_' and have dashes replaced with underscores
@ -160,3 +164,13 @@ async def oh_action(connection_id: str, data: dict[str, Any]) -> None:
async def disconnect(connection_id: str) -> None:
logger.info(f'sio:disconnect:{connection_id}')
await conversation_manager.disconnect_from_session(connection_id)
def _invalid_session_api_key(query_params: dict[str, list[Any]]):
session_api_key = os.getenv('SESSION_API_KEY')
if not session_api_key:
return False
query_api_keys = query_params['session_api_key']
if not query_api_keys:
return True
return query_api_keys[0] != session_api_key

View File

@ -1,5 +1,6 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from urllib.parse import urlparse
@ -192,3 +193,26 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
await self._detach_session(request)
return response
@dataclass
class SessionApiKeyMiddleware:
"""Middleware which ensures that all requests contain a header with the token given"""
session_api_key: str
async def __call__(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if (
request.method != 'OPTIONS'
and request.url.path != '/alive'
and request.url.path != '/server_info'
):
if self.session_api_key != request.headers.get('X-Session-API-Key'):
return JSONResponse(
{'code': 'invalid_session_api_key'},
status_code=status.HTTP_401_UNAUTHORIZED,
)
response = await call_next(request)
return response

View File

@ -69,7 +69,7 @@ class InitSessionResponse(BaseModel):
status: str
conversation_id: str
conversation_url: str
api_key: str | None
session_api_key: str | None
message: str | None = None
@ -228,7 +228,7 @@ async def new_conversation(
status='ok',
conversation_id=agent_loop_info.conversation_id,
conversation_url=agent_loop_info.url,
api_key=agent_loop_info.api_key,
session_api_key=agent_loop_info.session_api_key,
)
except MissingSettingsError as e:
return JSONResponse(
@ -289,7 +289,7 @@ async def search_conversations(
)
connection_ids_to_conversation_ids = await conversation_manager.get_connections(filter_to_sids=conversation_ids)
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids=conversation_ids)
urls_by_conversation_id = {info.conversation_id: info.url for info in agent_loop_info}
agent_loop_info_by_conversation_id = {info.conversation_id: info for info in agent_loop_info}
result = ConversationInfoResultSet(
results=await wait_all(
_get_conversation_info(
@ -299,7 +299,8 @@ async def search_conversations(
1 for conversation_id in connection_ids_to_conversation_ids.values()
if conversation_id == conversation.conversation_id
),
url=urls_by_conversation_id.get(conversation.conversation_id),
agent_loop_info=agent_loop_info_by_conversation_id.get(conversation.conversation_id),
)
for conversation in filtered_results
),
@ -317,9 +318,9 @@ async def get_conversation(
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
num_connections = len(await conversation_manager.get_connections(filter_to_sids={conversation_id}))
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
url = agent_loop_info[0].url if agent_loop_info else None
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, url)
agent_loop_infos = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
agent_loop_info = agent_loop_infos[0] if agent_loop_infos else None
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, agent_loop_info)
return conversation_info
except FileNotFoundError:
return None
@ -348,7 +349,7 @@ async def _get_conversation_info(
conversation: ConversationMetadata,
is_running: bool,
num_connections: int,
url: str | None,
agent_loop_info: AgentLoopInfo | None,
) -> ConversationInfo | None:
try:
title = conversation.title
@ -365,7 +366,8 @@ async def _get_conversation_info(
ConversationStatus.RUNNING if is_running else ConversationStatus.STOPPED
),
num_connections=num_connections,
url=url,
url=agent_loop_info.url if agent_loop_info else None,
session_api_key=agent_loop_info.session_api_key if agent_loop_info else None,
)
except Exception as e:
logger.error(

View File

@ -250,7 +250,7 @@ async def test_new_conversation_success(provider_handler_mock):
mock_create_conversation.return_value = MagicMock(
conversation_id='test_conversation_id',
url='https://my-conversation.com',
api_key=None,
session_api_key=None,
)
test_request = InitSessionRequest(
@ -292,7 +292,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
mock_create_conversation.return_value = MagicMock(
conversation_id='test_conversation_id',
url='https://my-conversation.com',
api_key=None,
session_api_key=None,
)
# Mock SuggestedTask.get_prompt_for_task
@ -375,7 +375,7 @@ async def test_new_conversation_missing_settings(provider_handler_mock):
@pytest.mark.asyncio
async def test_new_conversation_invalid_api_key(provider_handler_mock):
async def test_new_conversation_invalid_session_api_key(provider_handler_mock):
"""Test creating a new conversation with an invalid API key."""
with _patch_store():
# Mock the _create_new_conversation function to raise LLMAuthenticationError
@ -477,7 +477,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
mock_create_conversation.return_value = MagicMock(
conversation_id='test_conversation_id',
url='https://my-conversation.com',
api_key=None,
session_api_key=None,
)
# Create the request object
@ -514,7 +514,7 @@ async def test_new_conversation_with_null_repository():
mock_create_conversation.return_value = MagicMock(
conversation_id='test_conversation_id',
url='https://my-conversation.com',
api_key=None,
session_api_key=None,
)
# Create the request object with null repository