mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 13:52:43 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
531 lines
21 KiB
Python
531 lines
21 KiB
Python
"""Tests for DbSessionInjector.
|
|
|
|
This module tests the database service implementation, focusing on:
|
|
- Session management and reuse within request contexts
|
|
- Configuration processing from environment variables
|
|
- Connection string generation for different database types (GCP, PostgreSQL, SQLite)
|
|
- Engine creation and caching behavior
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from pydantic import SecretStr
|
|
from sqlalchemy import Engine
|
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Mock the storage.database module to avoid import-time engine creation
|
|
mock_storage_database = MagicMock()
|
|
mock_storage_database.sessionmaker = sessionmaker
|
|
sys.modules['storage.database'] = mock_storage_database
|
|
|
|
# Mock database drivers to avoid import errors
|
|
sys.modules['pg8000'] = MagicMock()
|
|
sys.modules['asyncpg'] = MagicMock()
|
|
sys.modules['google.cloud.sql.connector'] = MagicMock()
|
|
|
|
# Import after mocking to avoid import-time issues
|
|
from openhands.app_server.services.db_session_injector import ( # noqa: E402
|
|
DbSessionInjector,
|
|
)
|
|
|
|
|
|
class MockRequest:
|
|
"""Mock FastAPI Request object for testing."""
|
|
|
|
def __init__(self):
|
|
self.state = MagicMock()
|
|
|
|
|
|
@pytest.fixture
|
|
def temp_persistence_dir():
|
|
"""Create a temporary directory for testing."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
yield Path(temp_dir)
|
|
|
|
|
|
@pytest.fixture
|
|
def basic_db_session_injector(temp_persistence_dir):
|
|
"""Create a basic DbSessionInjector instance for testing."""
|
|
return DbSessionInjector(persistence_dir=temp_persistence_dir)
|
|
|
|
|
|
@pytest.fixture
|
|
def postgres_db_session_injector(temp_persistence_dir):
|
|
"""Create a DbSessionInjector instance configured for PostgreSQL."""
|
|
return DbSessionInjector(
|
|
persistence_dir=temp_persistence_dir,
|
|
host='localhost',
|
|
port=5432,
|
|
name='test_db',
|
|
user='test_user',
|
|
password=SecretStr('test_password'),
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def gcp_db_session_injector(temp_persistence_dir):
|
|
"""Create a DbSessionInjector instance configured for GCP Cloud SQL."""
|
|
return DbSessionInjector(
|
|
persistence_dir=temp_persistence_dir,
|
|
gcp_db_instance='test-instance',
|
|
gcp_project='test-project',
|
|
gcp_region='us-central1',
|
|
name='test_db',
|
|
user='test_user',
|
|
password=SecretStr('test_password'),
|
|
)
|
|
|
|
|
|
class TestDbSessionInjectorConfiguration:
|
|
"""Test configuration processing and environment variable handling."""
|
|
|
|
def test_default_configuration(self, temp_persistence_dir):
|
|
"""Test default configuration values."""
|
|
service = DbSessionInjector(persistence_dir=temp_persistence_dir)
|
|
|
|
assert service.persistence_dir == temp_persistence_dir
|
|
assert service.host is None
|
|
assert service.port == 5432 # Default from env var processing
|
|
assert service.name == 'openhands' # Default from env var processing
|
|
assert service.user == 'postgres' # Default from env var processing
|
|
assert (
|
|
service.password.get_secret_value() == 'postgres'
|
|
) # Default from env var processing
|
|
assert service.echo is False
|
|
assert service.pool_size == 25
|
|
assert service.max_overflow == 10
|
|
assert service.gcp_db_instance is None
|
|
assert service.gcp_project is None
|
|
assert service.gcp_region is None
|
|
|
|
def test_environment_variable_processing(self, temp_persistence_dir):
|
|
"""Test that environment variables are properly processed."""
|
|
env_vars = {
|
|
'DB_HOST': 'env_host',
|
|
'DB_PORT': '3306',
|
|
'DB_NAME': 'env_db',
|
|
'DB_USER': 'env_user',
|
|
'DB_PASS': 'env_password',
|
|
'GCP_DB_INSTANCE': 'env_instance',
|
|
'GCP_PROJECT': 'env_project',
|
|
'GCP_REGION': 'env_region',
|
|
}
|
|
|
|
with patch.dict(os.environ, env_vars):
|
|
service = DbSessionInjector(persistence_dir=temp_persistence_dir)
|
|
|
|
assert service.host == 'env_host'
|
|
assert service.port == 3306
|
|
assert service.name == 'env_db'
|
|
assert service.user == 'env_user'
|
|
assert service.password.get_secret_value() == 'env_password'
|
|
assert service.gcp_db_instance == 'env_instance'
|
|
assert service.gcp_project == 'env_project'
|
|
assert service.gcp_region == 'env_region'
|
|
|
|
def test_explicit_values_override_env_vars(self, temp_persistence_dir):
|
|
"""Test that explicitly provided values override environment variables."""
|
|
env_vars = {
|
|
'DB_HOST': 'env_host',
|
|
'DB_PORT': '3306',
|
|
'DB_NAME': 'env_db',
|
|
'DB_USER': 'env_user',
|
|
'DB_PASS': 'env_password',
|
|
}
|
|
|
|
with patch.dict(os.environ, env_vars):
|
|
service = DbSessionInjector(
|
|
persistence_dir=temp_persistence_dir,
|
|
host='explicit_host',
|
|
port=5432,
|
|
name='explicit_db',
|
|
user='explicit_user',
|
|
password=SecretStr('explicit_password'),
|
|
)
|
|
|
|
assert service.host == 'explicit_host'
|
|
assert service.port == 5432
|
|
assert service.name == 'explicit_db'
|
|
assert service.user == 'explicit_user'
|
|
assert service.password.get_secret_value() == 'explicit_password'
|
|
|
|
|
|
class TestDbSessionInjectorConnections:
|
|
"""Test database connection string generation and engine creation."""
|
|
|
|
def test_sqlite_connection_fallback(self, basic_db_session_injector):
|
|
"""Test SQLite connection when no host is defined."""
|
|
engine = basic_db_session_injector.get_db_engine()
|
|
|
|
assert isinstance(engine, Engine)
|
|
expected_url = (
|
|
f'sqlite:///{basic_db_session_injector.persistence_dir}/openhands.db'
|
|
)
|
|
assert str(engine.url) == expected_url
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sqlite_async_connection_fallback(self, basic_db_session_injector):
|
|
"""Test SQLite async connection when no host is defined."""
|
|
engine = await basic_db_session_injector.get_async_db_engine()
|
|
|
|
assert isinstance(engine, AsyncEngine)
|
|
expected_url = f'sqlite+aiosqlite:///{basic_db_session_injector.persistence_dir}/openhands.db'
|
|
assert str(engine.url) == expected_url
|
|
|
|
def test_postgres_connection_with_host(self, postgres_db_session_injector):
|
|
"""Test PostgreSQL connection when host is defined."""
|
|
with patch(
|
|
'openhands.app_server.services.db_session_injector.create_engine'
|
|
) as mock_create_engine:
|
|
mock_engine = MagicMock()
|
|
mock_create_engine.return_value = mock_engine
|
|
|
|
engine = postgres_db_session_injector.get_db_engine()
|
|
|
|
assert engine == mock_engine
|
|
# Check that create_engine was called with the right parameters
|
|
assert mock_create_engine.call_count == 1
|
|
call_args = mock_create_engine.call_args
|
|
|
|
# Verify the URL contains the expected components
|
|
url_str = str(call_args[0][0])
|
|
assert 'postgresql+pg8000://' in url_str
|
|
assert 'test_user' in url_str
|
|
# Password may be masked in URL string representation
|
|
assert 'test_password' in url_str or '***' in url_str
|
|
assert 'localhost:5432' in url_str
|
|
assert 'test_db' in url_str
|
|
|
|
# Verify other parameters
|
|
assert call_args[1]['pool_size'] == 25
|
|
assert call_args[1]['max_overflow'] == 10
|
|
assert call_args[1]['pool_pre_ping']
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_postgres_async_connection_with_host(
|
|
self, postgres_db_session_injector
|
|
):
|
|
"""Test PostgreSQL async connection when host is defined."""
|
|
with patch(
|
|
'openhands.app_server.services.db_session_injector.create_async_engine'
|
|
) as mock_create_async_engine:
|
|
mock_engine = MagicMock()
|
|
mock_create_async_engine.return_value = mock_engine
|
|
|
|
engine = await postgres_db_session_injector.get_async_db_engine()
|
|
|
|
assert engine == mock_engine
|
|
# Check that create_async_engine was called with the right parameters
|
|
assert mock_create_async_engine.call_count == 1
|
|
call_args = mock_create_async_engine.call_args
|
|
|
|
# Verify the URL contains the expected components
|
|
url_str = str(call_args[0][0])
|
|
assert 'postgresql+asyncpg://' in url_str
|
|
assert 'test_user' in url_str
|
|
# Password may be masked in URL string representation
|
|
assert 'test_password' in url_str or '***' in url_str
|
|
assert 'localhost:5432' in url_str
|
|
assert 'test_db' in url_str
|
|
|
|
# Verify other parameters
|
|
assert call_args[1]['pool_size'] == 25
|
|
assert call_args[1]['max_overflow'] == 10
|
|
assert call_args[1]['pool_pre_ping']
|
|
|
|
@patch(
|
|
'openhands.app_server.services.db_session_injector.DbSessionInjector._create_gcp_engine'
|
|
)
|
|
def test_gcp_connection_configuration(
|
|
self, mock_create_gcp_engine, gcp_db_session_injector
|
|
):
|
|
"""Test GCP Cloud SQL connection configuration."""
|
|
mock_engine = MagicMock()
|
|
mock_create_gcp_engine.return_value = mock_engine
|
|
|
|
engine = gcp_db_session_injector.get_db_engine()
|
|
|
|
assert engine == mock_engine
|
|
mock_create_gcp_engine.assert_called_once()
|
|
|
|
@patch(
|
|
'openhands.app_server.services.db_session_injector.DbSessionInjector._create_async_gcp_engine'
|
|
)
|
|
@pytest.mark.asyncio
|
|
async def test_gcp_async_connection_configuration(
|
|
self, mock_create_async_gcp_engine, gcp_db_session_injector
|
|
):
|
|
"""Test GCP Cloud SQL async connection configuration."""
|
|
mock_engine = AsyncMock()
|
|
mock_create_async_gcp_engine.return_value = mock_engine
|
|
|
|
engine = await gcp_db_session_injector.get_async_db_engine()
|
|
|
|
assert engine == mock_engine
|
|
mock_create_async_gcp_engine.assert_called_once()
|
|
|
|
|
|
class TestDbSessionInjectorEngineReuse:
|
|
"""Test engine creation and caching behavior."""
|
|
|
|
def test_sync_engine_reuse(self, basic_db_session_injector):
|
|
"""Test that sync engines are cached and reused."""
|
|
engine1 = basic_db_session_injector.get_db_engine()
|
|
engine2 = basic_db_session_injector.get_db_engine()
|
|
|
|
assert engine1 is engine2
|
|
assert basic_db_session_injector._engine is engine1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_engine_reuse(self, basic_db_session_injector):
|
|
"""Test that async engines are cached and reused."""
|
|
engine1 = await basic_db_session_injector.get_async_db_engine()
|
|
engine2 = await basic_db_session_injector.get_async_db_engine()
|
|
|
|
assert engine1 is engine2
|
|
assert basic_db_session_injector._async_engine is engine1
|
|
|
|
def test_session_maker_reuse(self, basic_db_session_injector):
|
|
"""Test that session makers are cached and reused."""
|
|
session_maker1 = basic_db_session_injector.get_session_maker()
|
|
session_maker2 = basic_db_session_injector.get_session_maker()
|
|
|
|
assert session_maker1 is session_maker2
|
|
assert basic_db_session_injector._session_maker is session_maker1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_session_maker_reuse(self, basic_db_session_injector):
|
|
"""Test that async session makers are cached and reused."""
|
|
session_maker1 = await basic_db_session_injector.get_async_session_maker()
|
|
session_maker2 = await basic_db_session_injector.get_async_session_maker()
|
|
|
|
assert session_maker1 is session_maker2
|
|
assert basic_db_session_injector._async_session_maker is session_maker1
|
|
|
|
|
|
class TestDbSessionInjectorSessionManagement:
|
|
"""Test session management and reuse within request contexts."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_depends_reuse_within_request(self, basic_db_session_injector):
|
|
"""Test that managed sessions are reused within the same request context."""
|
|
request = MockRequest()
|
|
|
|
# First call should create a new session and store it in request state
|
|
session_generator1 = basic_db_session_injector.depends(request)
|
|
session1 = await session_generator1.__anext__()
|
|
|
|
# Verify session is stored in request state
|
|
assert hasattr(request.state, 'db_session')
|
|
assert request.state.db_session is session1
|
|
|
|
# Second call should return the same session from request state
|
|
session_generator2 = basic_db_session_injector.depends(request)
|
|
session2 = await session_generator2.__anext__()
|
|
|
|
assert session1 is session2
|
|
|
|
# Clean up generators
|
|
try:
|
|
await session_generator1.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
try:
|
|
await session_generator2.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_depends_cleanup_on_completion(self, basic_db_session_injector):
|
|
"""Test that managed sessions are properly cleaned up after request completion."""
|
|
request = MockRequest()
|
|
|
|
# Mock the async session maker and session
|
|
with patch(
|
|
'openhands.app_server.services.db_session_injector.async_sessionmaker'
|
|
) as mock_sessionmaker_class:
|
|
mock_session = AsyncMock()
|
|
mock_session_context = AsyncMock()
|
|
mock_session_context.__aenter__.return_value = mock_session
|
|
mock_session_context.__aexit__.return_value = None
|
|
mock_sessionmaker = MagicMock()
|
|
mock_sessionmaker.return_value = mock_session_context
|
|
mock_sessionmaker_class.return_value = mock_sessionmaker
|
|
|
|
# Use the managed session dependency
|
|
session_gen = basic_db_session_injector.depends(request)
|
|
session = await session_gen.__anext__()
|
|
|
|
assert hasattr(request.state, 'db_session')
|
|
assert request.state.db_session is session
|
|
|
|
# Simulate completion by exhausting the generator
|
|
try:
|
|
await session_gen.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
|
|
# After completion, session should be cleaned up from request state
|
|
# Note: cleanup only happens when a new session is created, not when reusing
|
|
# Since we're mocking the session maker, the cleanup behavior depends on the mock setup
|
|
# For this test, we verify that the session was created and stored properly
|
|
assert session is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_depends_rollback_on_exception(self, basic_db_session_injector):
|
|
"""Test that managed sessions are rolled back on exceptions."""
|
|
request = MockRequest()
|
|
|
|
# Mock the async session maker and session
|
|
with patch(
|
|
'openhands.app_server.services.db_session_injector.async_sessionmaker'
|
|
) as mock_sessionmaker_class:
|
|
mock_session = AsyncMock()
|
|
mock_session_context = AsyncMock()
|
|
mock_session_context.__aenter__.return_value = mock_session
|
|
mock_session_context.__aexit__.return_value = None
|
|
mock_sessionmaker = MagicMock()
|
|
mock_sessionmaker.return_value = mock_session_context
|
|
mock_sessionmaker_class.return_value = mock_sessionmaker
|
|
|
|
session_gen = basic_db_session_injector.depends(request)
|
|
session = await session_gen.__anext__()
|
|
|
|
# The actual rollback testing would require more complex mocking
|
|
# For now, just verify the session was created
|
|
assert session is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_session_dependency_creates_new_sessions(
|
|
self, basic_db_session_injector
|
|
):
|
|
"""Test that async_session dependency creates new sessions each time."""
|
|
session_generator1 = basic_db_session_injector.async_session()
|
|
session1 = await session_generator1.__anext__()
|
|
|
|
session_generator2 = basic_db_session_injector.async_session()
|
|
session2 = await session_generator2.__anext__()
|
|
|
|
# These should be different sessions since async_session doesn't use request state
|
|
assert session1 is not session2
|
|
|
|
# Clean up generators
|
|
try:
|
|
await session_generator1.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
try:
|
|
await session_generator2.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
|
|
|
|
class TestDbSessionInjectorGCPIntegration:
|
|
"""Test GCP-specific functionality."""
|
|
|
|
def test_gcp_connection_creation(self, gcp_db_session_injector):
|
|
"""Test GCP database connection creation."""
|
|
# Mock the google.cloud.sql.connector module
|
|
with patch.dict('sys.modules', {'google.cloud.sql.connector': MagicMock()}):
|
|
mock_connector_module = sys.modules['google.cloud.sql.connector']
|
|
mock_connector = MagicMock()
|
|
mock_connector_module.Connector.return_value = mock_connector
|
|
mock_connection = MagicMock()
|
|
mock_connector.connect.return_value = mock_connection
|
|
|
|
connection = gcp_db_session_injector._create_gcp_db_connection()
|
|
|
|
assert connection == mock_connection
|
|
mock_connector.connect.assert_called_once_with(
|
|
'test-project:us-central1:test-instance',
|
|
'pg8000',
|
|
user='test_user',
|
|
password='test_password',
|
|
db='test_db',
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_gcp_async_connection_creation(self, gcp_db_session_injector):
|
|
"""Test GCP async database connection creation."""
|
|
# Mock the google.cloud.sql.connector module
|
|
with patch.dict('sys.modules', {'google.cloud.sql.connector': MagicMock()}):
|
|
mock_connector_module = sys.modules['google.cloud.sql.connector']
|
|
mock_connector = AsyncMock()
|
|
mock_connector_module.Connector.return_value.__aenter__.return_value = (
|
|
mock_connector
|
|
)
|
|
mock_connector_module.Connector.return_value.__aexit__.return_value = None
|
|
mock_connection = AsyncMock()
|
|
mock_connector.connect_async.return_value = mock_connection
|
|
|
|
connection = await gcp_db_session_injector._create_async_gcp_db_connection()
|
|
|
|
assert connection == mock_connection
|
|
mock_connector.connect_async.assert_called_once_with(
|
|
'test-project:us-central1:test-instance',
|
|
'asyncpg',
|
|
user='test_user',
|
|
password='test_password',
|
|
db='test_db',
|
|
)
|
|
|
|
|
|
class TestDbSessionInjectorEdgeCases:
|
|
"""Test edge cases and error conditions."""
|
|
|
|
def test_none_password_handling(self, temp_persistence_dir):
|
|
"""Test handling of None password values."""
|
|
with patch(
|
|
'openhands.app_server.services.db_session_injector.create_engine'
|
|
) as mock_create_engine:
|
|
mock_engine = MagicMock()
|
|
mock_create_engine.return_value = mock_engine
|
|
|
|
service = DbSessionInjector(
|
|
persistence_dir=temp_persistence_dir, host='localhost', password=None
|
|
)
|
|
|
|
# Should not raise an exception
|
|
engine = service.get_db_engine()
|
|
assert engine == mock_engine
|
|
|
|
def test_empty_string_password_from_env(self, temp_persistence_dir):
|
|
"""Test handling of empty string password from environment."""
|
|
with patch.dict(os.environ, {'DB_PASS': ''}):
|
|
service = DbSessionInjector(persistence_dir=temp_persistence_dir)
|
|
assert service.password.get_secret_value() == ''
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_request_contexts_isolated(self, basic_db_session_injector):
|
|
"""Test that different request contexts have isolated sessions."""
|
|
request1 = MockRequest()
|
|
request2 = MockRequest()
|
|
|
|
# Create sessions for different requests
|
|
session_gen1 = basic_db_session_injector.depends(request1)
|
|
session1 = await session_gen1.__anext__()
|
|
|
|
session_gen2 = basic_db_session_injector.depends(request2)
|
|
session2 = await session_gen2.__anext__()
|
|
|
|
# Sessions should be different for different requests
|
|
assert session1 is not session2
|
|
assert request1.state.db_session is session1
|
|
assert request2.state.db_session is session2
|
|
|
|
# Clean up generators
|
|
try:
|
|
await session_gen1.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|
|
try:
|
|
await session_gen2.__anext__()
|
|
except StopAsyncIteration:
|
|
pass
|