From 53871f206b2c9b5ff489216ca18c149c492889b9 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 18 Dec 2025 20:40:04 +0000 Subject: [PATCH] Refactor cleanup_stale_device_codes to use modern SQLAlchemy 2.0 select() API Replace legacy session.query().filter().limit().all() pattern with the modern select().where().limit() + execute().scalars().all() pattern, which is more idiomatic and consistent with other parts of the codebase (e.g., gitlab_webhook_store.py). Co-authored-by: openhands --- enterprise/storage/device_code_store.py | 11 ++-- .../unit/storage/test_device_code_store.py | 51 +++++++++++-------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/enterprise/storage/device_code_store.py b/enterprise/storage/device_code_store.py index af57e7e64e..d8c421dd6f 100644 --- a/enterprise/storage/device_code_store.py +++ b/enterprise/storage/device_code_store.py @@ -4,7 +4,7 @@ import secrets import string from datetime import datetime, timedelta, timezone -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from storage.device_code import DeviceCode @@ -182,13 +182,14 @@ class DeviceCodeStore: """ with self.session_maker() as session: # Get expired device codes, ordered by oldest first (using ID as proxy for creation order) - expired_codes = ( - session.query(DeviceCode) - .filter(DeviceCode.expires_at < datetime.now(timezone.utc)) + query = ( + select(DeviceCode) + .where(DeviceCode.expires_at < datetime.now(timezone.utc)) .order_by(DeviceCode.id.asc()) .limit(limit) - .all() ) + result = session.execute(query) + expired_codes = result.scalars().all() if not expired_codes: logger.info('No expired device codes found') diff --git a/enterprise/tests/unit/storage/test_device_code_store.py b/enterprise/tests/unit/storage/test_device_code_store.py index 1e82760eb9..75adc0d402 100644 --- a/enterprise/tests/unit/storage/test_device_code_store.py +++ b/enterprise/tests/unit/storage/test_device_code_store.py @@ -194,13 +194,17 @@ class TestDeviceCodeStore: def test_cleanup_stale_device_codes_empty(self, device_code_store, mock_session): """Test cleanup when no expired device codes exist.""" - # Mock empty query result - mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = [] + # Mock empty query result using select() pattern + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result = MagicMock() + mock_result.scalars.return_value = mock_scalars + mock_session.execute.return_value = mock_result result = device_code_store.cleanup_stale_device_codes(limit=50) assert result == 0 - mock_session.query.assert_called_once_with(DeviceCode) + mock_session.execute.assert_called_once() def test_cleanup_stale_device_codes_with_data( self, device_code_store, mock_session @@ -212,21 +216,23 @@ class TestDeviceCodeStore: mock_device2 = MagicMock() mock_device2.id = 2 - # Mock query result with 2 expired codes - mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = [ - mock_device1, - mock_device2, - ] + # Mock query result with 2 expired codes using select() pattern + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_device1, mock_device2] + mock_select_result = MagicMock() + mock_select_result.scalars.return_value = mock_scalars # Mock the delete execution result - mock_result = MagicMock() - mock_result.rowcount = 2 - mock_session.execute.return_value = mock_result + mock_delete_result = MagicMock() + mock_delete_result.rowcount = 2 + + # First execute call returns select result, second returns delete result + mock_session.execute.side_effect = [mock_select_result, mock_delete_result] result = device_code_store.cleanup_stale_device_codes(limit=50) assert result == 2 - mock_session.execute.assert_called_once() + assert mock_session.execute.call_count == 2 mock_session.commit.assert_called_once() def test_cleanup_stale_device_codes_with_limit( @@ -236,20 +242,21 @@ class TestDeviceCodeStore: # Create mock device codes mock_devices = [MagicMock(id=i) for i in range(1, 4)] # 3 codes - # Mock query result with 3 expired codes - mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = mock_devices + # Mock query result with 3 expired codes using select() pattern + mock_scalars = MagicMock() + mock_scalars.all.return_value = mock_devices + mock_select_result = MagicMock() + mock_select_result.scalars.return_value = mock_scalars # Mock the delete execution result - mock_result = MagicMock() - mock_result.rowcount = 3 - mock_session.execute.return_value = mock_result + mock_delete_result = MagicMock() + mock_delete_result.rowcount = 3 + + # First execute call returns select result, second returns delete result + mock_session.execute.side_effect = [mock_select_result, mock_delete_result] result = device_code_store.cleanup_stale_device_codes(limit=3) assert result == 3 - mock_session.execute.assert_called_once() + assert mock_session.execute.call_count == 2 mock_session.commit.assert_called_once() - # Verify the limit was applied in the query - mock_session.query.return_value.filter.return_value.order_by.return_value.limit.assert_called_with( - 3 - )