mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor enterprise database.py to use DbSessionInjector (#12446)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -5,7 +5,7 @@ from threading import Thread
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.database import a_session_maker, get_engine, session_maker
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -47,6 +47,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
- checked_out: Number of connections currently in use
|
||||
- overflow: Number of overflow connections created beyond pool_size
|
||||
"""
|
||||
engine = get_engine()
|
||||
return {
|
||||
'checked_in': engine.pool.checkedin(),
|
||||
'checked_out': engine.pool.checkedout(),
|
||||
@@ -129,6 +130,7 @@ def _db_check(delay: int):
|
||||
with session_maker() as session:
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
engine = get_engine()
|
||||
logger.info(
|
||||
'check',
|
||||
extra={
|
||||
|
||||
@@ -1,126 +1,38 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
"""
|
||||
Database connection module for enterprise storage.
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
This is for backwards compatibility with V0.
|
||||
|
||||
# Check if we're running in a test environment
|
||||
IS_TESTING = 'pytest' in sys.modules
|
||||
This module provides database engines and session makers by delegating to the
|
||||
centralized DbSessionInjector from app_server/config.py. This ensures a single
|
||||
source of truth for database connection configuration.
|
||||
"""
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
|
||||
DB_NAME = os.environ.get('DB_NAME', 'openhands')
|
||||
|
||||
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
|
||||
GCP_PROJECT = os.environ.get('GCP_PROJECT')
|
||||
GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
import contextlib
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
def _get_db_session_injector():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return _connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_engine(
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
_config = get_global_config()
|
||||
return _config.db_session
|
||||
|
||||
|
||||
async def async_creator():
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
async with Connector(loop=loop) as connector:
|
||||
conn = await connector.connect_async(
|
||||
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name"
|
||||
'asyncpg',
|
||||
user=DB_USER,
|
||||
password=DB_PASS,
|
||||
db=DB_NAME,
|
||||
)
|
||||
return conn
|
||||
def session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
session_maker = db_session_injector.get_session_maker()
|
||||
return session_maker()
|
||||
|
||||
|
||||
def _get_async_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def adapted_creator():
|
||||
dbapi = engine.dialect.dbapi
|
||||
from sqlalchemy.dialects.postgresql.asyncpg import (
|
||||
AsyncAdapt_asyncpg_connection,
|
||||
)
|
||||
|
||||
return AsyncAdapt_asyncpg_connection(
|
||||
dbapi,
|
||||
await_only(async_creator()),
|
||||
prepared_statement_cache_size=100,
|
||||
)
|
||||
|
||||
# create async connection pool with wrapped creator
|
||||
return create_async_engine(
|
||||
'postgresql+asyncpg://',
|
||||
creator=adapted_creator,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_async_engine(
|
||||
host_string,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
@contextlib.asynccontextmanager
|
||||
async def a_session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
a_session_maker = await db_session_injector.get_async_session_maker()
|
||||
async with a_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
engine = _get_db_engine()
|
||||
session_maker = sessionmaker(bind=engine)
|
||||
|
||||
a_engine = _get_async_db_engine()
|
||||
a_session_maker = sessionmaker(
|
||||
bind=a_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
# Configure the session to use the same connection for all operations in a transaction
|
||||
# This helps prevent the "Task got Future attached to a different loop" error
|
||||
future=True,
|
||||
)
|
||||
def get_engine():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
engine = db_session_injector.get_db_engine()
|
||||
return engine
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.database import engine
|
||||
from storage.database import get_engine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -85,7 +85,7 @@ def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
|
||||
created_at DESC
|
||||
""")
|
||||
|
||||
with engine.connect() as connection:
|
||||
with get_engine().connect() as connection:
|
||||
result = connection.execute(query, {'minutes': minutes})
|
||||
conversations = [
|
||||
{
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from storage.database import engine
|
||||
from storage.database import get_engine
|
||||
|
||||
|
||||
def test_conversation_count_query():
|
||||
@@ -29,6 +29,8 @@ def test_conversation_count_query():
|
||||
user_id
|
||||
""")
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
with engine.connect() as connection:
|
||||
count_result = connection.execute(count_query)
|
||||
user_counts = [
|
||||
|
||||
Reference in New Issue
Block a user