Refactor enterprise database.py to use DbSessionInjector (#12446)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-01-16 14:38:05 -07:00
committed by GitHub
parent a987387353
commit 0d5f97c8c7
4 changed files with 34 additions and 118 deletions

View File

@@ -5,7 +5,7 @@ from threading import Thread
from fastapi import APIRouter, FastAPI
from sqlalchemy import func, select
from storage.database import a_session_maker, engine, session_maker
from storage.database import a_session_maker, get_engine, session_maker
from storage.user import User
from openhands.core.logger import openhands_logger as logger
@@ -47,6 +47,7 @@ def add_debugging_routes(api: FastAPI):
- checked_out: Number of connections currently in use
- overflow: Number of overflow connections created beyond pool_size
"""
engine = get_engine()
return {
'checked_in': engine.pool.checkedin(),
'checked_out': engine.pool.checkedout(),
@@ -129,6 +130,7 @@ def _db_check(delay: int):
with session_maker() as session:
num_users = session.query(User).count()
time.sleep(delay)
engine = get_engine()
logger.info(
'check',
extra={

View File

@@ -1,126 +1,38 @@
import asyncio
import os
import sys
"""
Database connection module for enterprise storage.
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.util import await_only
This is for backwards compatibility with V0.
# Check if we're running in a test environment
IS_TESTING = 'pytest' in sys.modules
This module provides database engines and session makers by delegating to the
centralized DbSessionInjector from app_server/config.py. This ensures a single
source of truth for database connection configuration.
"""
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
DB_USER = os.environ.get('DB_USER', 'postgres')
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
DB_NAME = os.environ.get('DB_NAME', 'openhands')
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
GCP_PROJECT = os.environ.get('GCP_PROJECT')
GCP_REGION = os.environ.get('GCP_REGION')
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
# Initialize Cloud SQL Connector once at module level for GCP environments.
_connector = None
import contextlib
def _get_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def _get_db_session_injector():
from openhands.app_server.config import get_global_config
def get_db_connection():
global _connector
from google.cloud.sql.connector import Connector
if not _connector:
_connector = Connector()
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
return _connector.connect(
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
)
return create_engine(
'postgresql+pg8000://',
creator=get_db_connection,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
else:
host_string = (
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
)
return create_engine(
host_string,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
_config = get_global_config()
return _config.db_session
async def async_creator():
from google.cloud.sql.connector import Connector
loop = asyncio.get_running_loop()
async with Connector(loop=loop) as connector:
conn = await connector.connect_async(
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name"
'asyncpg',
user=DB_USER,
password=DB_PASS,
db=DB_NAME,
)
return conn
def session_maker():
    """Create a new synchronous SQLAlchemy session.

    Delegates to the centralized DbSessionInjector so there is a single
    source of truth for connection configuration.
    """
    injector = _get_db_session_injector()
    make_session = injector.get_session_maker()
    return make_session()
def _get_async_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def adapted_creator():
dbapi = engine.dialect.dbapi
from sqlalchemy.dialects.postgresql.asyncpg import (
AsyncAdapt_asyncpg_connection,
)
return AsyncAdapt_asyncpg_connection(
dbapi,
await_only(async_creator()),
prepared_statement_cache_size=100,
)
# create async connection pool with wrapped creator
return create_async_engine(
'postgresql+asyncpg://',
creator=adapted_creator,
# Use NullPool to disable connection pooling and avoid event loop issues
poolclass=NullPool,
)
else:
host_string = (
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
)
return create_async_engine(
host_string,
# Use NullPool to disable connection pooling and avoid event loop issues
poolclass=NullPool,
)
@contextlib.asynccontextmanager
async def a_session_maker():
    """Async context manager yielding an async database session.

    Resolves the async session maker from the centralized
    DbSessionInjector, then opens a session for the duration of the
    ``async with`` block.
    """
    injector = _get_db_session_injector()
    make_async_session = await injector.get_async_session_maker()
    async with make_async_session() as session:
        yield session
engine = _get_db_engine()
session_maker = sessionmaker(bind=engine)
a_engine = _get_async_db_engine()
a_session_maker = sessionmaker(
bind=a_engine,
class_=AsyncSession,
expire_on_commit=False,
# Configure the session to use the same connection for all operations in a transaction
# This helps prevent the "Task got Future attached to a different loop" error
future=True,
)
def get_engine():
    """Return the synchronous database engine from the centralized injector."""
    return _get_db_session_injector().get_db_engine()

View File

@@ -21,7 +21,7 @@ from sqlalchemy import text
# Add the parent directory to the path so we can import from storage
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.auth.token_manager import get_keycloak_admin
from storage.database import engine
from storage.database import get_engine
# Configure logging
logging.basicConfig(
@@ -85,7 +85,7 @@ def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
created_at DESC
""")
with engine.connect() as connection:
with get_engine().connect() as connection:
result = connection.execute(query, {'minutes': minutes})
conversations = [
{

View File

@@ -13,7 +13,7 @@ from sqlalchemy import text
# Add the parent directory to the path so we can import from storage
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from storage.database import engine
from storage.database import get_engine
def test_conversation_count_query():
@@ -29,6 +29,8 @@ def test_conversation_count_query():
user_id
""")
engine = get_engine()
with engine.connect() as connection:
count_result = connection.execute(count_query)
user_counts = [