mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Optimize get_sandbox_by_session_api_key with hash lookup (#13019)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 097
|
||||
Revises: 096
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '097'
|
||||
down_revision: Union[str, None] = '096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable session_api_key_hash column and its index.

    The column is added to v1_remote_sandbox as a nullable string, then a
    non-unique index is created over it.
    """
    table_name = 'v1_remote_sandbox'
    column_name = 'session_api_key_hash'

    # Nullable so that rows which already exist remain valid without a value.
    hash_column = sa.Column(column_name, sa.String(), nullable=True)
    op.add_column(table_name, hash_column)

    index_name = op.f('ix_v1_remote_sandbox_session_api_key_hash')
    op.create_index(index_name, table_name, [column_name], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the session_api_key_hash index, then the column itself.

    This undoes upgrade() in reverse order: the index is removed before the
    column it covers.
    """
    table_name = 'v1_remote_sandbox'

    index_name = op.f('ix_v1_remote_sandbox_session_api_key_hash')
    op.drop_index(index_name, table_name=table_name)

    op.drop_column(table_name, 'session_api_key_hash')
|
||||
38
openhands/app_server/app_lifespan/alembic/versions/006.py
Normal file
38
openhands/app_server/app_lifespan/alembic/versions/006.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 005
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '006'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable session_api_key_hash column and its index.

    Both operations run inside a single batch_alter_table block on
    v1_remote_sandbox.
    """
    column_name = 'session_api_key_hash'

    with op.batch_alter_table('v1_remote_sandbox') as batch:
        # Nullable so that rows which already exist remain valid without a value.
        hash_column = sa.Column(column_name, sa.String(), nullable=True)
        batch.add_column(hash_column)

        batch.create_index(
            'ix_v1_remote_sandbox_session_api_key_hash',
            [column_name],
            unique=False,
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the session_api_key_hash index, then the column itself.

    Both operations run inside a single batch_alter_table block on
    v1_remote_sandbox, undoing upgrade() in reverse order.
    """
    index_name = 'ix_v1_remote_sandbox_session_api_key_hash'
    column_name = 'session_api_key_hash'

    with op.batch_alter_table('v1_remote_sandbox') as batch:
        batch.drop_index(index_name)
        batch.drop_column(column_name)
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
@@ -72,6 +73,11 @@ WORKER_1_PORT = 12000
|
||||
WORKER_2_PORT = 12001
|
||||
|
||||
|
||||
def _hash_session_api_key(session_api_key: str) -> str:
    """Return the SHA-256 digest of a session API key as a hex string.

    Args:
        session_api_key: The plain-text session API key.

    Returns:
        The 64-character lowercase hexadecimal SHA-256 digest of the key's
        encoded bytes.
    """
    hasher = hashlib.sha256()
    hasher.update(session_api_key.encode())
    return hasher.hexdigest()
|
||||
|
||||
|
||||
class StoredRemoteSandbox(Base): # type: ignore
|
||||
"""Local storage for remote sandbox info.
|
||||
|
||||
@@ -84,6 +90,7 @@ class StoredRemoteSandbox(Base): # type: ignore
|
||||
id = Column(String, primary_key=True)
|
||||
created_by_user_id = Column(String, nullable=True, index=True)
|
||||
sandbox_spec_id = Column(String, index=True) # shadows runtime['image']
|
||||
session_api_key_hash = Column(String, nullable=True, index=True)
|
||||
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
|
||||
|
||||
|
||||
@@ -343,12 +350,14 @@ class RemoteSandboxService(SandboxService):
|
||||
|
||||
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
async def _get_sandbox_by_session_api_key_legacy(
|
||||
self, session_api_key: str
|
||||
) -> Union[SandboxInfo, None]:
|
||||
"""Get a single sandbox by session API key."""
|
||||
# TODO: We should definitely refactor this and store the session_api_key in
|
||||
# the v1_remote_sandbox table
|
||||
"""Legacy method to get sandbox by session API key via runtime API.
|
||||
|
||||
This is the fallback for sandboxes created before the session_api_key_hash
|
||||
column was added. It calls the remote runtime API which is less efficient.
|
||||
"""
|
||||
try:
|
||||
response = await self._send_runtime_api_request(
|
||||
'GET',
|
||||
@@ -366,6 +375,10 @@ class RemoteSandboxService(SandboxService):
|
||||
sandbox = result.scalar_one_or_none()
|
||||
if sandbox is None:
|
||||
raise ValueError('sandbox_not_found')
|
||||
# Backfill the hash for future lookups (Auto committed at end of request)
|
||||
sandbox.session_api_key_hash = _hash_session_api_key(
|
||||
session_api_key
|
||||
)
|
||||
return self._to_sandbox_info(sandbox, runtime)
|
||||
except Exception:
|
||||
_logger.exception(
|
||||
@@ -382,6 +395,10 @@ class RemoteSandboxService(SandboxService):
|
||||
try:
|
||||
runtime = await self._get_runtime(stored_sandbox.id)
|
||||
if runtime and runtime.get('session_api_key') == session_api_key:
|
||||
# Backfill the hash for future lookups (Auto committed at end of request)
|
||||
stored_sandbox.session_api_key_hash = _hash_session_api_key(
|
||||
session_api_key
|
||||
)
|
||||
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||
except Exception:
|
||||
# Continue checking other sandboxes if one fails
|
||||
@@ -389,6 +406,39 @@ class RemoteSandboxService(SandboxService):
|
||||
|
||||
return None
|
||||
|
||||
async def get_sandbox_by_session_api_key(
    self, session_api_key: str
) -> Union[SandboxInfo, None]:
    """Look up the sandbox that owns the given session API key.

    The key is hashed and matched against the stored session_api_key_hash
    column, so the common case is a database query rather than a search via
    the runtime API. When no row matches the hash, the lookup is delegated
    to the legacy API-based method, which serves sandboxes created before
    the hash column existed.

    Args:
        session_api_key: The plain-text session API key to search for.

    Returns:
        The matching sandbox info, or whatever the legacy lookup returns
        when no stored hash matches. If a row matches but building the
        result with its runtime fails, the sandbox info is built without a
        runtime instead.
    """
    key_hash = _hash_session_api_key(session_api_key)

    # Fast path: match on the stored hash.
    base_query = await self._secure_select()
    hash_query = base_query.where(
        StoredRemoteSandbox.session_api_key_hash == key_hash
    )
    rows = await self.db_session.execute(hash_query)
    match = rows.scalar_one_or_none()

    if not match:
        # Fallback for sandboxes created before the hash column was added
        return await self._get_sandbox_by_session_api_key_legacy(session_api_key)

    try:
        runtime = await self._get_runtime(match.id)
        return self._to_sandbox_info(match, runtime)
    except Exception:
        # Best effort: still report the stored sandbox, just without a runtime.
        _logger.exception(
            f'Error getting runtime for sandbox {match.id}',
            stack_info=True,
        )
        return self._to_sandbox_info(match, None)
|
||||
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
@@ -455,6 +505,13 @@ class RemoteSandboxService(SandboxService):
|
||||
response.raise_for_status()
|
||||
runtime_data = response.json()
|
||||
|
||||
# Store the session_api_key hash for efficient lookups
|
||||
session_api_key = runtime_data.get('session_api_key')
|
||||
if session_api_key:
|
||||
stored_sandbox.session_api_key_hash = _hash_session_api_key(
|
||||
session_api_key
|
||||
)
|
||||
|
||||
# Hack - result doesn't contain this
|
||||
runtime_data['pod_status'] = 'pending'
|
||||
|
||||
|
||||
@@ -119,6 +119,7 @@ def create_stored_sandbox(
|
||||
user_id: str = 'test-user-123',
|
||||
spec_id: str = 'test-image:latest',
|
||||
created_at: datetime | None = None,
|
||||
session_api_key_hash: str | None = None,
|
||||
) -> StoredRemoteSandbox:
|
||||
"""Helper function to create StoredRemoteSandbox for testing."""
|
||||
if created_at is None:
|
||||
@@ -128,6 +129,7 @@ def create_stored_sandbox(
|
||||
id=sandbox_id,
|
||||
created_by_user_id=user_id,
|
||||
sandbox_spec_id=spec_id,
|
||||
session_api_key_hash=session_api_key_hash,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
@@ -994,6 +996,203 @@ class TestErrorHandling:
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetSandboxBySessionApiKey:
    """Tests for RemoteSandboxService.get_sandbox_by_session_api_key.

    Covers the hash-based lookup, both legacy fallback paths (the /list API
    and the per-sandbox runtime check) including hash backfilling, and the
    result returned when fetching the runtime raises.
    """

    @staticmethod
    def _scalar_result(value):
        """Build a fake query result whose scalar_one_or_none() returns value."""
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = value
        return query_result

    @staticmethod
    def _failing_list_response():
        """Build a fake HTTP response whose raise_for_status() raises."""
        response = MagicMock()
        response.raise_for_status.side_effect = Exception('API error')
        return response

    @pytest.mark.asyncio
    async def test_get_sandbox_by_session_api_key_with_hash(
        self, remote_sandbox_service
    ):
        """A sandbox whose stored hash matches is returned with its runtime."""
        from openhands.app_server.sandbox.remote_sandbox_service import (
            _hash_session_api_key,
        )

        # Setup: a stored sandbox carrying the hash of the key being queried
        api_key = 'test-session-key'
        sandbox_row = create_stored_sandbox(
            session_api_key_hash=_hash_session_api_key(api_key)
        )
        runtime = create_runtime_data(session_api_key=api_key)

        service = remote_sandbox_service
        service.user_context.get_user_id.return_value = 'test-user-123'
        service._get_runtime = AsyncMock(return_value=runtime)
        service.db_session.execute = AsyncMock(
            return_value=self._scalar_result(sandbox_row)
        )

        # Execute
        found = await service.get_sandbox_by_session_api_key(api_key)

        # Verify
        assert found is not None
        assert found.id == 'test-sandbox-123'
        assert found.session_api_key == api_key

    @pytest.mark.asyncio
    async def test_get_sandbox_by_session_api_key_not_found(
        self, remote_sandbox_service
    ):
        """None is returned when neither the hash nor the legacy lookup matches."""
        service = remote_sandbox_service
        service.user_context.get_user_id.return_value = 'test-user-123'

        # Legacy fallback, step 1: the /list API call fails
        service.httpx_client.request = AsyncMock(
            return_value=self._failing_list_response()
        )

        # Legacy fallback, step 2: there are no stored sandboxes to check
        no_rows = MagicMock()
        no_rows.scalars.return_value.all.return_value = []

        # First query is the hash lookup (no match), second is the legacy scan
        service.db_session.execute = AsyncMock(
            side_effect=[self._scalar_result(None), no_rows]
        )

        # Execute
        found = await service.get_sandbox_by_session_api_key('unknown-key')

        # Verify
        assert found is None

    @pytest.mark.asyncio
    async def test_get_sandbox_by_session_api_key_legacy_via_list_api(
        self, remote_sandbox_service
    ):
        """The legacy /list API path finds the sandbox and backfills its hash."""
        from openhands.app_server.sandbox.remote_sandbox_service import (
            _hash_session_api_key,
        )

        # Setup: a legacy sandbox, i.e. one stored without a hash
        api_key = 'test-session-key'
        legacy_row = create_stored_sandbox(session_api_key_hash=None)
        runtime = create_runtime_data(session_api_key=api_key)

        service = remote_sandbox_service
        service.user_context.get_user_id.return_value = 'test-user-123'

        # Legacy fallback: /list API returns the runtime
        list_response = MagicMock()
        list_response.raise_for_status.return_value = None
        list_response.json.return_value = {'runtimes': [runtime]}
        service.httpx_client.request = AsyncMock(return_value=list_response)

        # First query is the hash lookup (no match); the second one, the query
        # for the sandbox by session_id, returns the stored sandbox
        service.db_session.execute = AsyncMock(
            side_effect=[
                self._scalar_result(None),
                self._scalar_result(legacy_row),
            ]
        )

        # Execute
        found = await service.get_sandbox_by_session_api_key(api_key)

        # Verify
        assert found is not None
        assert found.id == 'test-sandbox-123'
        # Verify the hash was backfilled
        assert legacy_row.session_api_key_hash == _hash_session_api_key(api_key)

    @pytest.mark.asyncio
    async def test_get_sandbox_by_session_api_key_legacy_via_runtime_check(
        self, remote_sandbox_service
    ):
        """When /list fails, each stored sandbox's runtime is checked instead."""
        from openhands.app_server.sandbox.remote_sandbox_service import (
            _hash_session_api_key,
        )

        # Setup: a legacy sandbox, i.e. one stored without a hash
        api_key = 'test-session-key'
        legacy_row = create_stored_sandbox(session_api_key_hash=None)
        runtime = create_runtime_data(session_api_key=api_key)

        service = remote_sandbox_service
        service.user_context.get_user_id.return_value = 'test-user-123'
        service._get_runtime = AsyncMock(return_value=runtime)

        # Legacy fallback: /list API fails
        service.httpx_client.request = AsyncMock(
            return_value=self._failing_list_response()
        )

        # Get all stored sandboxes returns the legacy sandbox
        all_rows = MagicMock()
        all_rows.scalars.return_value.all.return_value = [legacy_row]

        # First query is the hash lookup (no match), second is the full scan
        service.db_session.execute = AsyncMock(
            side_effect=[self._scalar_result(None), all_rows]
        )

        # Execute
        found = await service.get_sandbox_by_session_api_key(api_key)

        # Verify
        assert found is not None
        assert found.id == 'test-sandbox-123'
        # Verify the hash was backfilled
        assert legacy_row.session_api_key_hash == _hash_session_api_key(api_key)

    @pytest.mark.asyncio
    async def test_get_sandbox_by_session_api_key_runtime_error(
        self, remote_sandbox_service
    ):
        """A runtime lookup failure still yields the sandbox, reported as MISSING."""
        from openhands.app_server.sandbox.remote_sandbox_service import (
            _hash_session_api_key,
        )

        # Setup: the hash matches, but fetching the runtime raises
        api_key = 'test-session-key'
        sandbox_row = create_stored_sandbox(
            session_api_key_hash=_hash_session_api_key(api_key)
        )

        service = remote_sandbox_service
        service.user_context.get_user_id.return_value = 'test-user-123'
        service._get_runtime = AsyncMock(side_effect=Exception('Runtime error'))
        service.db_session.execute = AsyncMock(
            return_value=self._scalar_result(sandbox_row)
        )

        # Execute
        found = await service.get_sandbox_by_session_api_key(api_key)

        # Verify - should still return sandbox info, just with None runtime
        assert found is not None
        assert found.id == 'test-sandbox-123'
        assert found.status == SandboxStatus.MISSING  # No runtime means MISSING
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test cases for utility functions."""
|
||||
|
||||
@@ -1011,6 +1210,27 @@ class TestUtilityFunctions:
|
||||
result = _build_service_url('http://localhost:8000', 'work-1')
|
||||
assert result == 'http://work-1-localhost:8000'
|
||||
|
||||
def test_hash_session_api_key(self):
    """Check determinism, distinctness and output format of the key hash."""
    from openhands.app_server.sandbox.remote_sandbox_service import (
        _hash_session_api_key,
    )

    first_key = 'test-session-api-key'
    second_key = 'another-session-api-key'

    digest = _hash_session_api_key(first_key)
    repeated_digest = _hash_session_api_key(first_key)

    # Hashing the same key twice gives the same digest
    assert digest == repeated_digest

    # A different key gives a different digest
    other_digest = _hash_session_api_key(second_key)
    assert digest != other_digest

    # The digest is a 64-character lowercase hex string (SHA-256)
    assert len(digest) == 64
    assert all(ch in '0123456789abcdef' for ch in digest)
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Test cases for constants and mappings."""
|
||||
|
||||
Reference in New Issue
Block a user