diff --git a/enterprise/storage/database.py b/enterprise/storage/database.py index 57e0c833a7..ec06550e03 100644 --- a/enterprise/storage/database.py +++ b/enterprise/storage/database.py @@ -21,16 +21,21 @@ POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25')) MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10')) POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800')) +# Initialize Cloud SQL Connector once at module level for GCP environments. +_connector = None + def _get_db_engine(): if GCP_DB_INSTANCE: # GCP environments def get_db_connection(): + global _connector from google.cloud.sql.connector import Connector - connector = Connector() + if not _connector: + _connector = Connector() instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}' - return connector.connect( + return _connector.connect( instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME ) diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 532602dbca..f5d2b996c7 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -210,11 +210,17 @@ async def start_app_conversation( set_db_session_keep_open(request.state, True) set_httpx_client_keep_open(request.state, True) - """Start an app conversation start task and return it.""" - async_iter = app_conversation_service.start_app_conversation(start_request) - result = await anext(async_iter) - asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client)) - return result + try: + """Start an app conversation start task and return it.""" + async_iter = app_conversation_service.start_app_conversation(start_request) + result = await anext(async_iter) + asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client)) + return result + except Exception: + await db_session.close() + await httpx_client.aclose() + raise + @router.post('/stream-start') diff --git a/openhands/app_server/event_callback/webhook_router.py b/openhands/app_server/event_callback/webhook_router.py index 37ae9d89b2..62dd7bec16 100644 --- a/openhands/app_server/event_callback/webhook_router.py +++ b/openhands/app_server/event_callback/webhook_router.py @@ -21,12 +21,10 @@ from openhands.app_server.app_conversation.app_conversation_models import ( ) from openhands.app_server.config import ( depends_app_conversation_info_service, - depends_db_session, depends_event_service, depends_jwt_service, depends_sandbox_service, get_event_callback_service, - get_global_config, ) from openhands.app_server.errors import AuthError from openhands.app_server.event.event_service import EventService @@ -54,8 +52,6 @@ sandbox_service_dependency = depends_sandbox_service() event_service_dependency = depends_event_service() app_conversation_info_service_dependency = depends_app_conversation_info_service() jwt_dependency = depends_jwt_service() -config = get_global_config() -db_session_dependency = depends_db_session() _logger = logging.getLogger(__name__) diff --git a/openhands/app_server/services/db_session_injector.py b/openhands/app_server/services/db_session_injector.py index b7fd404f00..737e1ff879 100644 --- a/openhands/app_server/services/db_session_injector.py +++ b/openhands/app_server/services/db_session_injector.py @@ -4,7 +4,7 @@ import asyncio import logging import os from pathlib import Path -from typing import AsyncGenerator +from typing import Any, AsyncGenerator import asyncpg from fastapi import Request @@ -44,6 +44,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): _async_engine: AsyncEngine | None = PrivateAttr(default=None) _session_maker: sessionmaker | None = PrivateAttr(default=None) _async_session_maker: async_sessionmaker | None = PrivateAttr(default=None) + _gcp_connector: Any = PrivateAttr(default=None) @model_validator(mode='after') def fill_empty_fields(self): @@ -67,14 +68,18 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): return self def _create_gcp_db_connection(self): - # Lazy import because lib does not import if user does not have posgres installed - from google.cloud.sql.connector import Connector + gcp_connector = self._gcp_connector + if gcp_connector is None: + # Lazy import because lib does not import if user does not have posgres installed + from google.cloud.sql.connector import Connector + + gcp_connector = Connector() + self._gcp_connector = gcp_connector - connector = Connector() instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}' password = self.password assert password is not None - return connector.connect( + return gcp_connector.connect( instance_string, 'pg8000', user=self.user, @@ -83,21 +88,25 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): ) async def _create_async_gcp_db_connection(self): - # Lazy import because lib does not import if user does not have posgres installed - from google.cloud.sql.connector import Connector + gcp_connector = self._gcp_connector + if gcp_connector is None: + # Lazy import because lib does not import if user does not have posgres installed + from google.cloud.sql.connector import Connector - loop = asyncio.get_running_loop() - async with Connector(loop=loop) as connector: - password = self.password - assert password is not None - conn = await connector.connect_async( - f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}', - 'asyncpg', - user=self.user, - password=password.get_secret_value(), - db=self.name, - ) - return conn + loop = asyncio.get_running_loop() + gcp_connector = Connector(loop=loop) + self._gcp_connector = gcp_connector + + password = self.password + assert password is not None + conn = await gcp_connector.connect_async( + f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}', + 'asyncpg', + user=self.user, + password=password.get_secret_value(), + db=self.name, + ) + return conn def _create_gcp_engine(self): engine = create_engine(