From 0677c035ff467c352baa8427a81ea48f1ea0f716 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Tue, 24 Feb 2026 13:55:21 +0000 Subject: [PATCH] Optimize get_sandbox_by_session_api_key with hash lookup (#13019) Co-authored-by: openhands --- ...ssion_api_key_hash_to_v1_remote_sandbox.py | 41 ++++ .../app_lifespan/alembic/versions/006.py | 38 +++ .../sandbox/remote_sandbox_service.py | 65 +++++- .../app_server/test_remote_sandbox_service.py | 220 ++++++++++++++++++ 4 files changed, 360 insertions(+), 4 deletions(-) create mode 100644 enterprise/migrations/versions/097_add_session_api_key_hash_to_v1_remote_sandbox.py create mode 100644 openhands/app_server/app_lifespan/alembic/versions/006.py diff --git a/enterprise/migrations/versions/097_add_session_api_key_hash_to_v1_remote_sandbox.py b/enterprise/migrations/versions/097_add_session_api_key_hash_to_v1_remote_sandbox.py new file mode 100644 index 0000000000..8e2a54442b --- /dev/null +++ b/enterprise/migrations/versions/097_add_session_api_key_hash_to_v1_remote_sandbox.py @@ -0,0 +1,41 @@ +"""Add session_api_key_hash to v1_remote_sandbox table + +Revision ID: 097 +Revises: 096 +Create Date: 2025-02-24 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '097' +down_revision: Union[str, None] = '096' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add session_api_key_hash column to v1_remote_sandbox table.""" + op.add_column( + 'v1_remote_sandbox', + sa.Column('session_api_key_hash', sa.String(), nullable=True), + ) + op.create_index( + op.f('ix_v1_remote_sandbox_session_api_key_hash'), + 'v1_remote_sandbox', + ['session_api_key_hash'], + unique=False, + ) + + +def downgrade() -> None: + """Remove session_api_key_hash column from v1_remote_sandbox table.""" + op.drop_index( + op.f('ix_v1_remote_sandbox_session_api_key_hash'), + table_name='v1_remote_sandbox', + ) + op.drop_column('v1_remote_sandbox', 'session_api_key_hash') diff --git a/openhands/app_server/app_lifespan/alembic/versions/006.py b/openhands/app_server/app_lifespan/alembic/versions/006.py new file mode 100644 index 0000000000..a0e3f5debd --- /dev/null +++ b/openhands/app_server/app_lifespan/alembic/versions/006.py @@ -0,0 +1,38 @@ +"""Add session_api_key_hash to v1_remote_sandbox table + +Revision ID: 006 +Revises: 005 +Create Date: 2025-02-24 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '006' +down_revision: Union[str, None] = '005' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add session_api_key_hash column to v1_remote_sandbox table.""" + with op.batch_alter_table('v1_remote_sandbox') as batch_op: + batch_op.add_column( + sa.Column('session_api_key_hash', sa.String(), nullable=True) + ) + batch_op.create_index( + 'ix_v1_remote_sandbox_session_api_key_hash', + ['session_api_key_hash'], + unique=False, + ) + + +def downgrade() -> None: + """Remove session_api_key_hash column from v1_remote_sandbox table.""" + with op.batch_alter_table('v1_remote_sandbox') as batch_op: + batch_op.drop_index('ix_v1_remote_sandbox_session_api_key_hash') + batch_op.drop_column('session_api_key_hash') diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index 12284bf115..48fc5560da 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import logging import os from dataclasses import dataclass @@ -72,6 +73,11 @@ WORKER_1_PORT = 12000 WORKER_2_PORT = 12001 +def _hash_session_api_key(session_api_key: str) -> str: + """Hash a session API key using SHA-256.""" + return hashlib.sha256(session_api_key.encode()).hexdigest() + + class StoredRemoteSandbox(Base): # type: ignore """Local storage for remote sandbox info. @@ -84,6 +90,7 @@ class StoredRemoteSandbox(Base): # type: ignore id = Column(String, primary_key=True) created_by_user_id = Column(String, nullable=True, index=True) sandbox_spec_id = Column(String, index=True) # shadows runtime['image'] + session_api_key_hash = Column(String, nullable=True, index=True) created_at = Column(UtcDateTime, server_default=func.now(), index=True) @@ -343,12 +350,14 @@ class RemoteSandboxService(SandboxService): return self._to_sandbox_info(stored_sandbox, runtime) - async def get_sandbox_by_session_api_key( + async def _get_sandbox_by_session_api_key_legacy( self, session_api_key: str ) -> Union[SandboxInfo, None]: - """Get a single sandbox by session API key.""" - # TODO: We should definitely refactor this and store the session_api_key in - # the v1_remote_sandbox table + """Legacy method to get sandbox by session API key via runtime API. + + This is the fallback for sandboxes created before the session_api_key_hash + column was added. It calls the remote runtime API which is less efficient. + """ try: response = await self._send_runtime_api_request( 'GET', @@ -366,6 +375,10 @@ class RemoteSandboxService(SandboxService): sandbox = result.scalar_one_or_none() if sandbox is None: raise ValueError('sandbox_not_found') + # Backfill the hash for future lookups (Auto committed at end of request) + sandbox.session_api_key_hash = _hash_session_api_key( + session_api_key + ) return self._to_sandbox_info(sandbox, runtime) except Exception: _logger.exception( @@ -382,6 +395,10 @@ class RemoteSandboxService(SandboxService): try: runtime = await self._get_runtime(stored_sandbox.id) if runtime and runtime.get('session_api_key') == session_api_key: + # Backfill the hash for future lookups (Auto committed at end of request) + stored_sandbox.session_api_key_hash = _hash_session_api_key( + session_api_key + ) return self._to_sandbox_info(stored_sandbox, runtime) except Exception: # Continue checking other sandboxes if one fails @@ -389,6 +406,39 @@ class RemoteSandboxService(SandboxService): return None + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> Union[SandboxInfo, None]: + """Get a single sandbox by session API key. + + Uses the stored session_api_key_hash for efficient database lookup instead + of calling the remote runtime API. Falls back to legacy API-based lookup + for sandboxes created before the hash column was added. + """ + session_api_key_hash = _hash_session_api_key(session_api_key) + + # First try to find sandbox by hash in the database + stmt = await self._secure_select() + stmt = stmt.where( + StoredRemoteSandbox.session_api_key_hash == session_api_key_hash + ) + result = await self.db_session.execute(stmt) + stored_sandbox = result.scalar_one_or_none() + + if stored_sandbox: + try: + runtime = await self._get_runtime(stored_sandbox.id) + return self._to_sandbox_info(stored_sandbox, runtime) + except Exception: + _logger.exception( + f'Error getting runtime for sandbox {stored_sandbox.id}', + stack_info=True, + ) + return self._to_sandbox_info(stored_sandbox, None) + + # Fallback for sandboxes created before the hash column was added + return await self._get_sandbox_by_session_api_key_legacy(session_api_key) + async def start_sandbox( self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None ) -> SandboxInfo: @@ -455,6 +505,13 @@ class RemoteSandboxService(SandboxService): response.raise_for_status() runtime_data = response.json() + # Store the session_api_key hash for efficient lookups + session_api_key = runtime_data.get('session_api_key') + if session_api_key: + stored_sandbox.session_api_key_hash = _hash_session_api_key( + session_api_key + ) + # Hack - result doesn't contain this runtime_data['pod_status'] = 'pending' diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index dcec8e390f..113c1ab6e5 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -119,6 +119,7 @@ def create_stored_sandbox( user_id: str = 'test-user-123', spec_id: str = 'test-image:latest', created_at: datetime | None = None, + session_api_key_hash: str | None = None, ) -> StoredRemoteSandbox: """Helper function to create StoredRemoteSandbox for testing.""" if created_at is None: @@ -128,6 +129,7 @@ def create_stored_sandbox( id=sandbox_id, created_by_user_id=user_id, sandbox_spec_id=spec_id, + session_api_key_hash=session_api_key_hash, created_at=created_at, ) @@ -994,6 +996,203 @@ class TestErrorHandling: assert result is False +class TestGetSandboxBySessionApiKey: + """Test cases for get_sandbox_by_session_api_key functionality.""" + + @pytest.mark.asyncio + async def test_get_sandbox_by_session_api_key_with_hash( + self, remote_sandbox_service + ): + """Test finding sandbox by session API key using stored hash.""" + from openhands.app_server.sandbox.remote_sandbox_service import ( + _hash_session_api_key, + ) + + # Setup + session_api_key = 'test-session-key' + expected_hash = _hash_session_api_key(session_api_key) + stored_sandbox = create_stored_sandbox(session_api_key_hash=expected_hash) + runtime_data = create_runtime_data(session_api_key=session_api_key) + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = stored_sandbox + remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result) + remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data) + remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123' + + # Execute + result = await remote_sandbox_service.get_sandbox_by_session_api_key( + session_api_key + ) + + # Verify + assert result is not None + assert result.id == 'test-sandbox-123' + assert result.session_api_key == session_api_key + + @pytest.mark.asyncio + async def test_get_sandbox_by_session_api_key_not_found( + self, remote_sandbox_service + ): + """Test finding sandbox when no matching hash exists and legacy fallback fails.""" + # Setup - no hash match + mock_result_no_hash = MagicMock() + mock_result_no_hash.scalar_one_or_none.return_value = None + + # Setup - legacy fallback: /list API fails, then no stored sandboxes + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = Exception('API error') + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_response + ) + + mock_result_legacy = MagicMock() + mock_result_legacy.scalars.return_value.all.return_value = [] + + remote_sandbox_service.db_session.execute = AsyncMock( + side_effect=[mock_result_no_hash, mock_result_legacy] + ) + remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123' + + # Execute + result = await remote_sandbox_service.get_sandbox_by_session_api_key( + 'unknown-key' + ) + + # Verify + assert result is None + + @pytest.mark.asyncio + async def test_get_sandbox_by_session_api_key_legacy_via_list_api( + self, remote_sandbox_service + ): + """Test legacy fallback finding sandbox via /list API and backfilling hash.""" + from openhands.app_server.sandbox.remote_sandbox_service import ( + _hash_session_api_key, + ) + + # Setup + session_api_key = 'test-session-key' + stored_sandbox = create_stored_sandbox( + session_api_key_hash=None + ) # Legacy sandbox + runtime_data = create_runtime_data(session_api_key=session_api_key) + + # First call returns None (no hash match) + mock_result_no_match = MagicMock() + mock_result_no_match.scalar_one_or_none.return_value = None + + # Legacy fallback: /list API returns the runtime + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {'runtimes': [runtime_data]} + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_response + ) + + # Query for sandbox by session_id returns the stored sandbox + mock_result_sandbox = MagicMock() + mock_result_sandbox.scalar_one_or_none.return_value = stored_sandbox + + remote_sandbox_service.db_session.execute = AsyncMock( + side_effect=[mock_result_no_match, mock_result_sandbox] + ) + remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123' + + # Execute + result = await remote_sandbox_service.get_sandbox_by_session_api_key( + session_api_key + ) + + # Verify + assert result is not None + assert result.id == 'test-sandbox-123' + # Verify the hash was backfilled + expected_hash = _hash_session_api_key(session_api_key) + assert stored_sandbox.session_api_key_hash == expected_hash + + @pytest.mark.asyncio + async def test_get_sandbox_by_session_api_key_legacy_via_runtime_check( + self, remote_sandbox_service + ): + """Test legacy fallback checking each sandbox's runtime when /list API fails.""" + from openhands.app_server.sandbox.remote_sandbox_service import ( + _hash_session_api_key, + ) + + # Setup + session_api_key = 'test-session-key' + stored_sandbox = create_stored_sandbox( + session_api_key_hash=None + ) # Legacy sandbox + runtime_data = create_runtime_data(session_api_key=session_api_key) + + # First call returns None (no hash match) + mock_result_no_match = MagicMock() + mock_result_no_match.scalar_one_or_none.return_value = None + + # Legacy fallback: /list API fails + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = Exception('API error') + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_response + ) + + # Get all stored sandboxes returns the legacy sandbox + mock_result_all = MagicMock() + mock_result_all.scalars.return_value.all.return_value = [stored_sandbox] + + remote_sandbox_service.db_session.execute = AsyncMock( + side_effect=[mock_result_no_match, mock_result_all] + ) + remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data) + remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123' + + # Execute + result = await remote_sandbox_service.get_sandbox_by_session_api_key( + session_api_key + ) + + # Verify + assert result is not None + assert result.id == 'test-sandbox-123' + # Verify the hash was backfilled + expected_hash = _hash_session_api_key(session_api_key) + assert stored_sandbox.session_api_key_hash == expected_hash + + @pytest.mark.asyncio + async def test_get_sandbox_by_session_api_key_runtime_error( + self, remote_sandbox_service + ): + """Test handling runtime error when getting sandbox.""" + from openhands.app_server.sandbox.remote_sandbox_service import ( + _hash_session_api_key, + ) + + # Setup + session_api_key = 'test-session-key' + expected_hash = _hash_session_api_key(session_api_key) + stored_sandbox = create_stored_sandbox(session_api_key_hash=expected_hash) + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = stored_sandbox + remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result) + remote_sandbox_service._get_runtime = AsyncMock( + side_effect=Exception('Runtime error') + ) + remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123' + + # Execute + result = await remote_sandbox_service.get_sandbox_by_session_api_key( + session_api_key + ) + + # Verify - should still return sandbox info, just with None runtime + assert result is not None + assert result.id == 'test-sandbox-123' + assert result.status == SandboxStatus.MISSING # No runtime means MISSING + + class TestUtilityFunctions: """Test cases for utility functions.""" @@ -1011,6 +1210,27 @@ class TestUtilityFunctions: result = _build_service_url('http://localhost:8000', 'work-1') assert result == 'http://work-1-localhost:8000' + def test_hash_session_api_key(self): + """Test _hash_session_api_key function.""" + from openhands.app_server.sandbox.remote_sandbox_service import ( + _hash_session_api_key, + ) + + # Test that same input always produces same hash + key = 'test-session-api-key' + hash1 = _hash_session_api_key(key) + hash2 = _hash_session_api_key(key) + assert hash1 == hash2 + + # Test that different inputs produce different hashes + key2 = 'another-session-api-key' + hash3 = _hash_session_api_key(key2) + assert hash1 != hash3 + + # Test that hash is a 64-character hex string (SHA-256) + assert len(hash1) == 64 + assert all(c in '0123456789abcdef' for c in hash1) + class TestConstants: """Test cases for constants and mappings."""