mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add Session API Key Authentication for Runtime Communication (#8550)
This commit is contained in:
parent
872b97a3c8
commit
38b4d93237
@ -200,6 +200,11 @@ class LocalRuntime(ActionExecutionClient):
|
||||
headless_mode,
|
||||
)
|
||||
|
||||
#If there is an API key in the environment we use this in requests to the runtime
|
||||
session_api_key = os.getenv("SESSION_API_KEY")
|
||||
if session_api_key:
|
||||
self.session.headers['X-Session-API-Key'] = session_api_key
|
||||
|
||||
@property
|
||||
def action_execution_server_url(self) -> str:
|
||||
return self.api_url
|
||||
|
||||
@ -480,7 +480,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
return AgentLoopInfo(
|
||||
conversation_id=session.sid,
|
||||
url=self._get_conversation_url(session.sid),
|
||||
api_key=None,
|
||||
session_api_key=None,
|
||||
event_store=session.agent_session.event_stream,
|
||||
)
|
||||
|
||||
|
||||
@ -10,5 +10,5 @@ class AgentLoopInfo:
|
||||
"""
|
||||
conversation_id: str
|
||||
url: str | None
|
||||
api_key: str | None
|
||||
session_api_key: str | None
|
||||
event_store: EventStoreABC
|
||||
|
||||
@ -20,5 +20,5 @@ class ConversationInfo:
|
||||
trigger: ConversationTrigger | None = None
|
||||
num_connections: int = 0
|
||||
url: str | None = None
|
||||
api_key: str | None = None
|
||||
session_api_key: str | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@ -10,6 +10,7 @@ from openhands.server.middleware import (
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
RateLimitMiddleware,
|
||||
SessionApiKeyMiddleware,
|
||||
)
|
||||
from openhands.server.static import SPAStaticFiles
|
||||
|
||||
@ -32,4 +33,8 @@ base_app.add_middleware(
|
||||
)
|
||||
base_app.middleware('http')(AttachConversationMiddleware(base_app))
|
||||
|
||||
session_api_key = os.getenv('SESSION_API_KEY')
|
||||
if session_api_key:
|
||||
base_app.middleware('http')(SessionApiKeyMiddleware(session_api_key))
|
||||
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
@ -72,6 +73,9 @@ async def connect(connection_id: str, environ: dict) -> None:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
if _invalid_session_api_key(query_params):
|
||||
raise ConnectionRefusedError('invalid_session_api_key')
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
# Get Authorization header from the environment
|
||||
# Headers in WSGI/ASGI are prefixed with 'HTTP_' and have dashes replaced with underscores
|
||||
@ -160,3 +164,13 @@ async def oh_action(connection_id: str, data: dict[str, Any]) -> None:
|
||||
async def disconnect(connection_id: str) -> None:
|
||||
logger.info(f'sio:disconnect:{connection_id}')
|
||||
await conversation_manager.disconnect_from_session(connection_id)
|
||||
|
||||
|
||||
def _invalid_session_api_key(query_params: dict[str, list[Any]]):
|
||||
session_api_key = os.getenv('SESSION_API_KEY')
|
||||
if not session_api_key:
|
||||
return False
|
||||
query_api_keys = query_params['session_api_key']
|
||||
if not query_api_keys:
|
||||
return True
|
||||
return query_api_keys[0] != session_api_key
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
@ -192,3 +193,26 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
await self._detach_session(request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionApiKeyMiddleware:
|
||||
"""Middleware which ensures that all requests contain a header with the token given"""
|
||||
|
||||
session_api_key: str
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if (
|
||||
request.method != 'OPTIONS'
|
||||
and request.url.path != '/alive'
|
||||
and request.url.path != '/server_info'
|
||||
):
|
||||
if self.session_api_key != request.headers.get('X-Session-API-Key'):
|
||||
return JSONResponse(
|
||||
{'code': 'invalid_session_api_key'},
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@ -69,7 +69,7 @@ class InitSessionResponse(BaseModel):
|
||||
status: str
|
||||
conversation_id: str
|
||||
conversation_url: str
|
||||
api_key: str | None
|
||||
session_api_key: str | None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@ -228,7 +228,7 @@ async def new_conversation(
|
||||
status='ok',
|
||||
conversation_id=agent_loop_info.conversation_id,
|
||||
conversation_url=agent_loop_info.url,
|
||||
api_key=agent_loop_info.api_key,
|
||||
session_api_key=agent_loop_info.session_api_key,
|
||||
)
|
||||
except MissingSettingsError as e:
|
||||
return JSONResponse(
|
||||
@ -289,7 +289,7 @@ async def search_conversations(
|
||||
)
|
||||
connection_ids_to_conversation_ids = await conversation_manager.get_connections(filter_to_sids=conversation_ids)
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids=conversation_ids)
|
||||
urls_by_conversation_id = {info.conversation_id: info.url for info in agent_loop_info}
|
||||
agent_loop_info_by_conversation_id = {info.conversation_id: info for info in agent_loop_info}
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
_get_conversation_info(
|
||||
@ -299,7 +299,8 @@ async def search_conversations(
|
||||
1 for conversation_id in connection_ids_to_conversation_ids.values()
|
||||
if conversation_id == conversation.conversation_id
|
||||
),
|
||||
url=urls_by_conversation_id.get(conversation.conversation_id),
|
||||
agent_loop_info=agent_loop_info_by_conversation_id.get(conversation.conversation_id),
|
||||
|
||||
)
|
||||
for conversation in filtered_results
|
||||
),
|
||||
@ -317,9 +318,9 @@ async def get_conversation(
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
num_connections = len(await conversation_manager.get_connections(filter_to_sids={conversation_id}))
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
|
||||
url = agent_loop_info[0].url if agent_loop_info else None
|
||||
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, url)
|
||||
agent_loop_infos = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
|
||||
agent_loop_info = agent_loop_infos[0] if agent_loop_infos else None
|
||||
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, agent_loop_info)
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
@ -348,7 +349,7 @@ async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
num_connections: int,
|
||||
url: str | None,
|
||||
agent_loop_info: AgentLoopInfo | None,
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
title = conversation.title
|
||||
@ -365,7 +366,8 @@ async def _get_conversation_info(
|
||||
ConversationStatus.RUNNING if is_running else ConversationStatus.STOPPED
|
||||
),
|
||||
num_connections=num_connections,
|
||||
url=url,
|
||||
url=agent_loop_info.url if agent_loop_info else None,
|
||||
session_api_key=agent_loop_info.session_api_key if agent_loop_info else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@ -250,7 +250,7 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
session_api_key=None,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
@ -292,7 +292,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
session_api_key=None,
|
||||
)
|
||||
|
||||
# Mock SuggestedTask.get_prompt_for_task
|
||||
@ -375,7 +375,7 @@ async def test_new_conversation_missing_settings(provider_handler_mock):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_invalid_api_key(provider_handler_mock):
|
||||
async def test_new_conversation_invalid_session_api_key(provider_handler_mock):
|
||||
"""Test creating a new conversation with an invalid API key."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise LLMAuthenticationError
|
||||
@ -477,7 +477,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
session_api_key=None,
|
||||
)
|
||||
|
||||
# Create the request object
|
||||
@ -514,7 +514,7 @@ async def test_new_conversation_with_null_repository():
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
session_api_key=None,
|
||||
)
|
||||
|
||||
# Create the request object with null repository
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user