diff --git a/enterprise/storage/device_code_store.py b/enterprise/storage/device_code_store.py index d8c421dd6f..af57e7e64e 100644 --- a/enterprise/storage/device_code_store.py +++ b/enterprise/storage/device_code_store.py @@ -4,7 +4,7 @@ import secrets import string from datetime import datetime, timedelta, timezone -from sqlalchemy import delete, select +from sqlalchemy import delete from sqlalchemy.exc import IntegrityError from storage.device_code import DeviceCode @@ -182,14 +182,13 @@ class DeviceCodeStore: """ with self.session_maker() as session: # Get expired device codes, ordered by oldest first (using ID as proxy for creation order) - query = ( - select(DeviceCode) - .where(DeviceCode.expires_at < datetime.now(timezone.utc)) + expired_codes = ( + session.query(DeviceCode) + .filter(DeviceCode.expires_at < datetime.now(timezone.utc)) .order_by(DeviceCode.id.asc()) .limit(limit) + .all() ) - result = session.execute(query) - expired_codes = result.scalars().all() if not expired_codes: logger.info('No expired device codes found') diff --git a/enterprise/tests/unit/storage/test_device_code_store.py b/enterprise/tests/unit/storage/test_device_code_store.py index 75adc0d402..1e82760eb9 100644 --- a/enterprise/tests/unit/storage/test_device_code_store.py +++ b/enterprise/tests/unit/storage/test_device_code_store.py @@ -194,17 +194,13 @@ class TestDeviceCodeStore: def test_cleanup_stale_device_codes_empty(self, device_code_store, mock_session): """Test cleanup when no expired device codes exist.""" - # Mock empty query result using select() pattern - mock_scalars = MagicMock() - mock_scalars.all.return_value = [] - mock_result = MagicMock() - mock_result.scalars.return_value = mock_scalars - mock_session.execute.return_value = mock_result + # Mock empty query result + mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = [] result = device_code_store.cleanup_stale_device_codes(limit=50) assert result == 0 - mock_session.execute.assert_called_once() + mock_session.query.assert_called_once_with(DeviceCode) def test_cleanup_stale_device_codes_with_data( self, device_code_store, mock_session @@ -216,23 +212,21 @@ class TestDeviceCodeStore: mock_device2 = MagicMock() mock_device2.id = 2 - # Mock query result with 2 expired codes using select() pattern - mock_scalars = MagicMock() - mock_scalars.all.return_value = [mock_device1, mock_device2] - mock_select_result = MagicMock() - mock_select_result.scalars.return_value = mock_scalars + # Mock query result with 2 expired codes + mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = [ + mock_device1, + mock_device2, + ] # Mock the delete execution result - mock_delete_result = MagicMock() - mock_delete_result.rowcount = 2 - - # First execute call returns select result, second returns delete result - mock_session.execute.side_effect = [mock_select_result, mock_delete_result] + mock_result = MagicMock() + mock_result.rowcount = 2 + mock_session.execute.return_value = mock_result result = device_code_store.cleanup_stale_device_codes(limit=50) assert result == 2 - assert mock_session.execute.call_count == 2 + mock_session.execute.assert_called_once() mock_session.commit.assert_called_once() def test_cleanup_stale_device_codes_with_limit( @@ -242,21 +236,20 @@ class TestDeviceCodeStore: # Create mock device codes mock_devices = [MagicMock(id=i) for i in range(1, 4)] # 3 codes - # Mock query result with 3 expired codes using select() pattern - mock_scalars = MagicMock() - mock_scalars.all.return_value = mock_devices - mock_select_result = MagicMock() - mock_select_result.scalars.return_value = mock_scalars + # Mock query result with 3 expired codes + mock_session.query.return_value.filter.return_value.order_by.return_value.limit.return_value.all.return_value = mock_devices # Mock the delete execution result - mock_delete_result = MagicMock() - mock_delete_result.rowcount = 3 - - # First execute call returns select result, second returns delete result - mock_session.execute.side_effect = [mock_select_result, mock_delete_result] + mock_result = MagicMock() + mock_result.rowcount = 3 + mock_session.execute.return_value = mock_result result = device_code_store.cleanup_stale_device_codes(limit=3) assert result == 3 - assert mock_session.execute.call_count == 2 + mock_session.execute.assert_called_once() mock_session.commit.assert_called_once() + # Verify the limit was applied in the query + mock_session.query.return_value.filter.return_value.order_by.return_value.limit.assert_called_with( + 3 + )