diff --git a/enterprise/run_maintenance_tasks.py b/enterprise/run_maintenance_tasks.py index 89eb91b821..b4fb75218b 100644 --- a/enterprise/run_maintenance_tasks.py +++ b/enterprise/run_maintenance_tasks.py @@ -21,11 +21,12 @@ async def main(): def set_stale_task_error(): + # started_at is naive UTC; strip tzinfo before comparing. + cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1) with session_maker() as session: session.query(MaintenanceTask).filter( MaintenanceTask.status == MaintenanceTaskStatus.WORKING, - MaintenanceTask.started_at - < datetime.now(timezone.utc) - timedelta(hours=1), + MaintenanceTask.started_at < cutoff, ).update({MaintenanceTask.status: MaintenanceTaskStatus.ERROR}) session.commit() @@ -37,9 +38,10 @@ async def run_tasks(): if not task: return - # Update the status + # started_at/updated_at are naive UTC; strip tzinfo. + now_utc = datetime.now(timezone.utc).replace(tzinfo=None) task.status = MaintenanceTaskStatus.WORKING - task.updated_at = task.started_at = datetime.now(timezone.utc) + task.updated_at = task.started_at = now_utc session.commit() try: diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index d514b70693..74a2d3d73e 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -31,7 +31,8 @@ class ApiKeyStore: Args: user_id: The ID of the user to create the key for name: Optional name for the key - expires_at: Optional expiration date for the key + expires_at: Expiration datetime in UTC. Timezone info is stripped before + writing to the TIMESTAMP WITHOUT TIME ZONE column. Returns: The generated API key @@ -42,6 +43,10 @@ class ApiKeyStore: raise ValueError(f'User not found: {user_id}') org_id = user.current_org_id + # Column is TIMESTAMP WITHOUT TIME ZONE; strip tzinfo before writing. + if expires_at is not None and expires_at.tzinfo is not None: + expires_at = expires_at.replace(tzinfo=None) + async with a_session_maker() as session: key_record = ApiKey( key=api_key, @@ -66,9 +71,8 @@ class ApiKeyStore: if not key_record: return None - # Check if the key has expired + # expires_at is stored as naive UTC; re-attach tzinfo for comparison. if key_record.expires_at: - # Handle timezone-naive datetime from database by assuming it's UTC expires_at = key_record.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index 26f96d3f03..d3a2d13d1e 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -55,6 +55,28 @@ def test_generate_api_key(api_key_store): assert len(key) == len('sk-oh-') + 32 +@pytest.mark.asyncio +@patch('storage.api_key_store.UserStore.get_user_by_id') +async def test_create_api_key_strips_timezone_from_expires_at( + mock_get_user, api_key_store, async_session_maker, mock_user +): + """Timezone-aware expires_at must be stored as naive UTC without shifting the value.""" + user_id = str(uuid.uuid4()) + aware_expiry = datetime.now(UTC) + timedelta(days=30) + mock_get_user.return_value = mock_user + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + key = await api_key_store.create_api_key(user_id, expires_at=aware_expiry) + + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.key == key)) + record = result.scalars().first() + + assert record.expires_at is not None + assert record.expires_at.tzinfo is None + assert record.expires_at == aware_expiry.replace(tzinfo=None) + + @pytest.mark.asyncio @patch('storage.api_key_store.UserStore.get_user_by_id') async def test_create_api_key( diff --git a/enterprise/tests/unit/test_run_maintenance_tasks.py b/enterprise/tests/unit/test_run_maintenance_tasks.py index d7456ff2a8..8efba13442 100644 --- a/enterprise/tests/unit/test_run_maintenance_tasks.py +++ b/enterprise/tests/unit/test_run_maintenance_tasks.py @@ -304,6 +304,76 @@ class TestRunMaintenanceTasks: assert 'error' in updated_task.info assert updated_task.info['error'] == 'Test error' + def test_set_stale_task_error_uses_naive_utc_cutoff(self, session_maker): + """set_stale_task_error must compare against a naive UTC cutoff.""" + # Create tasks using naive UTC started_at values (matching what run_tasks writes). + with session_maker() as session: + stale = MaintenanceTask( + status=MaintenanceTaskStatus.WORKING, + processor_type='test.processor', + processor_json='{}', + started_at=datetime.now(timezone.utc).replace(tzinfo=None) + - timedelta(hours=2), + ) + recent = MaintenanceTask( + status=MaintenanceTaskStatus.WORKING, + processor_type='test.processor', + processor_json='{}', + started_at=datetime.now(timezone.utc).replace(tzinfo=None) + - timedelta(minutes=30), + ) + session.add_all([stale, recent]) + session.commit() + stale_id, recent_id = stale.id, recent.id + + with patch('run_maintenance_tasks.session_maker', return_value=session_maker()): + set_stale_task_error() + + with session_maker() as session: + assert ( + session.get(MaintenanceTask, stale_id).status + == MaintenanceTaskStatus.ERROR + ) + assert ( + session.get(MaintenanceTask, recent_id).status + == MaintenanceTaskStatus.WORKING + ) + + @pytest.mark.asyncio + async def test_run_tasks_stores_naive_utc_started_at(self, session_maker): + """run_tasks must write naive UTC datetimes to started_at and updated_at.""" + processor = AsyncMock(return_value={}) + + with session_maker() as session: + task = MaintenanceTask( + status=MaintenanceTaskStatus.PENDING, + processor_type='test.processor', + processor_json='{}', + ) + session.add(task) + session.commit() + task_id = task.id + + with patch( + 'storage.maintenance_task.MaintenanceTask.get_processor', + return_value=processor, + ): + with patch( + 'run_maintenance_tasks.session_maker', return_value=session_maker() + ): + with patch('asyncio.sleep', new_callable=AsyncMock): + try: + await asyncio.wait_for(run_tasks(), timeout=1.0) + except asyncio.TimeoutError: + pass + + with session_maker() as session: + updated = session.get(MaintenanceTask, task_id) + assert updated.started_at is not None + assert updated.started_at.tzinfo is None + assert updated.updated_at is not None + assert updated.updated_at.tzinfo is None + @pytest.mark.asyncio async def test_run_tasks_respects_delay(self, session_maker): """Test that run_tasks respects the delay parameter.""" diff --git a/openhands/app_server/services/db_session_injector.py b/openhands/app_server/services/db_session_injector.py index 063797a63d..2d3f90be36 100644 --- a/openhands/app_server/services/db_session_injector.py +++ b/openhands/app_server/services/db_session_injector.py @@ -126,10 +126,11 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): async def _create_async_gcp_creator(self): from sqlalchemy.dialects.postgresql.asyncpg import ( AsyncAdapt_asyncpg_connection, + AsyncAdapt_asyncpg_dbapi, ) return AsyncAdapt_asyncpg_connection( - asyncpg, + AsyncAdapt_asyncpg_dbapi(asyncpg), await self._create_async_gcp_db_connection(), prepared_statement_cache_size=100, ) @@ -137,11 +138,14 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): async def _create_async_gcp_engine(self): from sqlalchemy.dialects.postgresql.asyncpg import ( AsyncAdapt_asyncpg_connection, + AsyncAdapt_asyncpg_dbapi, ) + dbapi = AsyncAdapt_asyncpg_dbapi(asyncpg) + def adapted_creator(): return AsyncAdapt_asyncpg_connection( - asyncpg, + dbapi, await_only(self._create_async_gcp_db_connection()), prepared_statement_cache_size=100, )