Optimize get_sandbox_by_session_api_key with hash lookup (#13019)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-02-24 13:55:21 +00:00
committed by GitHub
parent 68165b52d9
commit 0677c035ff
4 changed files with 360 additions and 4 deletions

View File

@@ -0,0 +1,41 @@
"""Add session_api_key_hash to v1_remote_sandbox table
Revision ID: 097
Revises: 096
Create Date: 2025-02-24 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '097'
down_revision: Union[str, None] = '096'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add session_api_key_hash column to v1_remote_sandbox table."""
op.add_column(
'v1_remote_sandbox',
sa.Column('session_api_key_hash', sa.String(), nullable=True),
)
op.create_index(
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
'v1_remote_sandbox',
['session_api_key_hash'],
unique=False,
)
def downgrade() -> None:
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
op.drop_index(
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
table_name='v1_remote_sandbox',
)
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')

View File

@@ -0,0 +1,38 @@
"""Add session_api_key_hash to v1_remote_sandbox table
Revision ID: 006
Revises: 005
Create Date: 2025-02-24 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '006'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add session_api_key_hash column to v1_remote_sandbox table."""
with op.batch_alter_table('v1_remote_sandbox') as batch_op:
batch_op.add_column(
sa.Column('session_api_key_hash', sa.String(), nullable=True)
)
batch_op.create_index(
'ix_v1_remote_sandbox_session_api_key_hash',
['session_api_key_hash'],
unique=False,
)
def downgrade() -> None:
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
with op.batch_alter_table('v1_remote_sandbox') as batch_op:
batch_op.drop_index('ix_v1_remote_sandbox_session_api_key_hash')
batch_op.drop_column('session_api_key_hash')

View File

@@ -1,4 +1,5 @@
import asyncio
import hashlib
import logging
import os
from dataclasses import dataclass
@@ -72,6 +73,11 @@ WORKER_1_PORT = 12000
WORKER_2_PORT = 12001
def _hash_session_api_key(session_api_key: str) -> str:
"""Hash a session API key using SHA-256."""
return hashlib.sha256(session_api_key.encode()).hexdigest()
class StoredRemoteSandbox(Base): # type: ignore
"""Local storage for remote sandbox info.
@@ -84,6 +90,7 @@ class StoredRemoteSandbox(Base): # type: ignore
id = Column(String, primary_key=True)
created_by_user_id = Column(String, nullable=True, index=True)
sandbox_spec_id = Column(String, index=True) # shadows runtime['image']
session_api_key_hash = Column(String, nullable=True, index=True)
created_at = Column(UtcDateTime, server_default=func.now(), index=True)
@@ -343,12 +350,14 @@ class RemoteSandboxService(SandboxService):
return self._to_sandbox_info(stored_sandbox, runtime)
async def get_sandbox_by_session_api_key(
async def _get_sandbox_by_session_api_key_legacy(
self, session_api_key: str
) -> Union[SandboxInfo, None]:
"""Get a single sandbox by session API key."""
# TODO: We should definitely refactor this and store the session_api_key in
# the v1_remote_sandbox table
"""Legacy method to get sandbox by session API key via runtime API.
This is the fallback for sandboxes created before the session_api_key_hash
column was added. It calls the remote runtime API which is less efficient.
"""
try:
response = await self._send_runtime_api_request(
'GET',
@@ -366,6 +375,10 @@ class RemoteSandboxService(SandboxService):
sandbox = result.scalar_one_or_none()
if sandbox is None:
raise ValueError('sandbox_not_found')
# Backfill the hash for future lookups (Auto committed at end of request)
sandbox.session_api_key_hash = _hash_session_api_key(
session_api_key
)
return self._to_sandbox_info(sandbox, runtime)
except Exception:
_logger.exception(
@@ -382,6 +395,10 @@ class RemoteSandboxService(SandboxService):
try:
runtime = await self._get_runtime(stored_sandbox.id)
if runtime and runtime.get('session_api_key') == session_api_key:
# Backfill the hash for future lookups (Auto committed at end of request)
stored_sandbox.session_api_key_hash = _hash_session_api_key(
session_api_key
)
return self._to_sandbox_info(stored_sandbox, runtime)
except Exception:
# Continue checking other sandboxes if one fails
@@ -389,6 +406,39 @@ class RemoteSandboxService(SandboxService):
return None
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> Union[SandboxInfo, None]:
"""Get a single sandbox by session API key.
Uses the stored session_api_key_hash for efficient database lookup instead
of calling the remote runtime API. Falls back to legacy API-based lookup
for sandboxes created before the hash column was added.
"""
session_api_key_hash = _hash_session_api_key(session_api_key)
# First try to find sandbox by hash in the database
stmt = await self._secure_select()
stmt = stmt.where(
StoredRemoteSandbox.session_api_key_hash == session_api_key_hash
)
result = await self.db_session.execute(stmt)
stored_sandbox = result.scalar_one_or_none()
if stored_sandbox:
try:
runtime = await self._get_runtime(stored_sandbox.id)
return self._to_sandbox_info(stored_sandbox, runtime)
except Exception:
_logger.exception(
f'Error getting runtime for sandbox {stored_sandbox.id}',
stack_info=True,
)
return self._to_sandbox_info(stored_sandbox, None)
# Fallback for sandboxes created before the hash column was added
return await self._get_sandbox_by_session_api_key_legacy(session_api_key)
async def start_sandbox(
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
) -> SandboxInfo:
@@ -455,6 +505,13 @@ class RemoteSandboxService(SandboxService):
response.raise_for_status()
runtime_data = response.json()
# Store the session_api_key hash for efficient lookups
session_api_key = runtime_data.get('session_api_key')
if session_api_key:
stored_sandbox.session_api_key_hash = _hash_session_api_key(
session_api_key
)
# Hack - result doesn't contain this
runtime_data['pod_status'] = 'pending'

View File

@@ -119,6 +119,7 @@ def create_stored_sandbox(
user_id: str = 'test-user-123',
spec_id: str = 'test-image:latest',
created_at: datetime | None = None,
session_api_key_hash: str | None = None,
) -> StoredRemoteSandbox:
"""Helper function to create StoredRemoteSandbox for testing."""
if created_at is None:
@@ -128,6 +129,7 @@ def create_stored_sandbox(
id=sandbox_id,
created_by_user_id=user_id,
sandbox_spec_id=spec_id,
session_api_key_hash=session_api_key_hash,
created_at=created_at,
)
@@ -994,6 +996,203 @@ class TestErrorHandling:
assert result is False
class TestGetSandboxBySessionApiKey:
"""Test cases for get_sandbox_by_session_api_key functionality."""
@pytest.mark.asyncio
async def test_get_sandbox_by_session_api_key_with_hash(
self, remote_sandbox_service
):
"""Test finding sandbox by session API key using stored hash."""
from openhands.app_server.sandbox.remote_sandbox_service import (
_hash_session_api_key,
)
# Setup
session_api_key = 'test-session-key'
expected_hash = _hash_session_api_key(session_api_key)
stored_sandbox = create_stored_sandbox(session_api_key_hash=expected_hash)
runtime_data = create_runtime_data(session_api_key=session_api_key)
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = stored_sandbox
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
# Execute
result = await remote_sandbox_service.get_sandbox_by_session_api_key(
session_api_key
)
# Verify
assert result is not None
assert result.id == 'test-sandbox-123'
assert result.session_api_key == session_api_key
@pytest.mark.asyncio
async def test_get_sandbox_by_session_api_key_not_found(
self, remote_sandbox_service
):
"""Test finding sandbox when no matching hash exists and legacy fallback fails."""
# Setup - no hash match
mock_result_no_hash = MagicMock()
mock_result_no_hash.scalar_one_or_none.return_value = None
# Setup - legacy fallback: /list API fails, then no stored sandboxes
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = Exception('API error')
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_response
)
mock_result_legacy = MagicMock()
mock_result_legacy.scalars.return_value.all.return_value = []
remote_sandbox_service.db_session.execute = AsyncMock(
side_effect=[mock_result_no_hash, mock_result_legacy]
)
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
# Execute
result = await remote_sandbox_service.get_sandbox_by_session_api_key(
'unknown-key'
)
# Verify
assert result is None
@pytest.mark.asyncio
async def test_get_sandbox_by_session_api_key_legacy_via_list_api(
self, remote_sandbox_service
):
"""Test legacy fallback finding sandbox via /list API and backfilling hash."""
from openhands.app_server.sandbox.remote_sandbox_service import (
_hash_session_api_key,
)
# Setup
session_api_key = 'test-session-key'
stored_sandbox = create_stored_sandbox(
session_api_key_hash=None
) # Legacy sandbox
runtime_data = create_runtime_data(session_api_key=session_api_key)
# First call returns None (no hash match)
mock_result_no_match = MagicMock()
mock_result_no_match.scalar_one_or_none.return_value = None
# Legacy fallback: /list API returns the runtime
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {'runtimes': [runtime_data]}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_response
)
# Query for sandbox by session_id returns the stored sandbox
mock_result_sandbox = MagicMock()
mock_result_sandbox.scalar_one_or_none.return_value = stored_sandbox
remote_sandbox_service.db_session.execute = AsyncMock(
side_effect=[mock_result_no_match, mock_result_sandbox]
)
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
# Execute
result = await remote_sandbox_service.get_sandbox_by_session_api_key(
session_api_key
)
# Verify
assert result is not None
assert result.id == 'test-sandbox-123'
# Verify the hash was backfilled
expected_hash = _hash_session_api_key(session_api_key)
assert stored_sandbox.session_api_key_hash == expected_hash
@pytest.mark.asyncio
async def test_get_sandbox_by_session_api_key_legacy_via_runtime_check(
self, remote_sandbox_service
):
"""Test legacy fallback checking each sandbox's runtime when /list API fails."""
from openhands.app_server.sandbox.remote_sandbox_service import (
_hash_session_api_key,
)
# Setup
session_api_key = 'test-session-key'
stored_sandbox = create_stored_sandbox(
session_api_key_hash=None
) # Legacy sandbox
runtime_data = create_runtime_data(session_api_key=session_api_key)
# First call returns None (no hash match)
mock_result_no_match = MagicMock()
mock_result_no_match.scalar_one_or_none.return_value = None
# Legacy fallback: /list API fails
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = Exception('API error')
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_response
)
# Get all stored sandboxes returns the legacy sandbox
mock_result_all = MagicMock()
mock_result_all.scalars.return_value.all.return_value = [stored_sandbox]
remote_sandbox_service.db_session.execute = AsyncMock(
side_effect=[mock_result_no_match, mock_result_all]
)
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
# Execute
result = await remote_sandbox_service.get_sandbox_by_session_api_key(
session_api_key
)
# Verify
assert result is not None
assert result.id == 'test-sandbox-123'
# Verify the hash was backfilled
expected_hash = _hash_session_api_key(session_api_key)
assert stored_sandbox.session_api_key_hash == expected_hash
@pytest.mark.asyncio
async def test_get_sandbox_by_session_api_key_runtime_error(
self, remote_sandbox_service
):
"""Test handling runtime error when getting sandbox."""
from openhands.app_server.sandbox.remote_sandbox_service import (
_hash_session_api_key,
)
# Setup
session_api_key = 'test-session-key'
expected_hash = _hash_session_api_key(session_api_key)
stored_sandbox = create_stored_sandbox(session_api_key_hash=expected_hash)
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = stored_sandbox
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
remote_sandbox_service._get_runtime = AsyncMock(
side_effect=Exception('Runtime error')
)
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
# Execute
result = await remote_sandbox_service.get_sandbox_by_session_api_key(
session_api_key
)
# Verify - should still return sandbox info, just with None runtime
assert result is not None
assert result.id == 'test-sandbox-123'
assert result.status == SandboxStatus.MISSING # No runtime means MISSING
class TestUtilityFunctions:
"""Test cases for utility functions."""
@@ -1011,6 +1210,27 @@ class TestUtilityFunctions:
result = _build_service_url('http://localhost:8000', 'work-1')
assert result == 'http://work-1-localhost:8000'
def test_hash_session_api_key(self):
"""Test _hash_session_api_key function."""
from openhands.app_server.sandbox.remote_sandbox_service import (
_hash_session_api_key,
)
# Test that same input always produces same hash
key = 'test-session-api-key'
hash1 = _hash_session_api_key(key)
hash2 = _hash_session_api_key(key)
assert hash1 == hash2
# Test that different inputs produce different hashes
key2 = 'another-session-api-key'
hash3 = _hash_session_api_key(key2)
assert hash1 != hash3
# Test that hash is a 64-character hex string (SHA-256)
assert len(hash1) == 64
assert all(c in '0123456789abcdef' for c in hash1)
class TestConstants:
"""Test cases for constants and mappings."""