mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix: asyncpg, device key timestamp without timezone, error reporting (#13301)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -21,11 +21,12 @@ async def main():
|
||||
|
||||
|
||||
def set_stale_task_error():
|
||||
# started_at is naive UTC; strip tzinfo before comparing.
|
||||
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)
|
||||
with session_maker() as session:
|
||||
session.query(MaintenanceTask).filter(
|
||||
MaintenanceTask.status == MaintenanceTaskStatus.WORKING,
|
||||
MaintenanceTask.started_at
|
||||
< datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
MaintenanceTask.started_at < cutoff,
|
||||
).update({MaintenanceTask.status: MaintenanceTaskStatus.ERROR})
|
||||
session.commit()
|
||||
|
||||
@@ -37,9 +38,10 @@ async def run_tasks():
|
||||
if not task:
|
||||
return
|
||||
|
||||
# Update the status
|
||||
# started_at/updated_at are naive UTC; strip tzinfo.
|
||||
now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
task.status = MaintenanceTaskStatus.WORKING
|
||||
task.updated_at = task.started_at = datetime.now(timezone.utc)
|
||||
task.updated_at = task.started_at = now_utc
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
|
||||
@@ -31,7 +31,8 @@ class ApiKeyStore:
|
||||
Args:
|
||||
user_id: The ID of the user to create the key for
|
||||
name: Optional name for the key
|
||||
expires_at: Optional expiration date for the key
|
||||
expires_at: Expiration datetime in UTC. Timezone info is stripped before
|
||||
writing to the TIMESTAMP WITHOUT TIME ZONE column.
|
||||
|
||||
Returns:
|
||||
The generated API key
|
||||
@@ -42,6 +43,10 @@ class ApiKeyStore:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
org_id = user.current_org_id
|
||||
|
||||
# Column is TIMESTAMP WITHOUT TIME ZONE; strip tzinfo before writing.
|
||||
if expires_at is not None and expires_at.tzinfo is not None:
|
||||
expires_at = expires_at.replace(tzinfo=None)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
@@ -66,9 +71,8 @@ class ApiKeyStore:
|
||||
if not key_record:
|
||||
return None
|
||||
|
||||
# Check if the key has expired
|
||||
# expires_at is stored as naive UTC; re-attach tzinfo for comparison.
|
||||
if key_record.expires_at:
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = key_record.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
@@ -55,6 +55,28 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_create_api_key_strips_timezone_from_expires_at(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
):
|
||||
"""Timezone-aware expires_at must be stored as naive UTC without shifting the value."""
|
||||
user_id = str(uuid.uuid4())
|
||||
aware_expiry = datetime.now(UTC) + timedelta(days=30)
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
key = await api_key_store.create_api_key(user_id, expires_at=aware_expiry)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == key))
|
||||
record = result.scalars().first()
|
||||
|
||||
assert record.expires_at is not None
|
||||
assert record.expires_at.tzinfo is None
|
||||
assert record.expires_at == aware_expiry.replace(tzinfo=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
async def test_create_api_key(
|
||||
|
||||
@@ -304,6 +304,76 @@ class TestRunMaintenanceTasks:
|
||||
assert 'error' in updated_task.info
|
||||
assert updated_task.info['error'] == 'Test error'
|
||||
|
||||
def test_set_stale_task_error_uses_naive_utc_cutoff(self, session_maker):
|
||||
"""set_stale_task_error must compare against a naive UTC cutoff."""
|
||||
# Create tasks using naive UTC started_at values (matching what run_tasks writes).
|
||||
with session_maker() as session:
|
||||
stale = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.WORKING,
|
||||
processor_type='test.processor',
|
||||
processor_json='{}',
|
||||
started_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
- timedelta(hours=2),
|
||||
)
|
||||
recent = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.WORKING,
|
||||
processor_type='test.processor',
|
||||
processor_json='{}',
|
||||
started_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
- timedelta(minutes=30),
|
||||
)
|
||||
session.add_all([stale, recent])
|
||||
session.commit()
|
||||
stale_id, recent_id = stale.id, recent.id
|
||||
|
||||
with patch('run_maintenance_tasks.session_maker', return_value=session_maker()):
|
||||
set_stale_task_error()
|
||||
|
||||
with session_maker() as session:
|
||||
assert (
|
||||
session.get(MaintenanceTask, stale_id).status
|
||||
== MaintenanceTaskStatus.ERROR
|
||||
)
|
||||
assert (
|
||||
session.get(MaintenanceTask, recent_id).status
|
||||
== MaintenanceTaskStatus.WORKING
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tasks_stores_naive_utc_started_at(self, session_maker):
|
||||
"""run_tasks must write naive UTC datetimes to started_at and updated_at."""
|
||||
processor = AsyncMock(return_value={})
|
||||
|
||||
with session_maker() as session:
|
||||
task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
processor_type='test.processor',
|
||||
processor_json='{}',
|
||||
)
|
||||
session.add(task)
|
||||
session.commit()
|
||||
task_id = task.id
|
||||
|
||||
with patch(
|
||||
'storage.maintenance_task.MaintenanceTask.get_processor',
|
||||
return_value=processor,
|
||||
):
|
||||
with patch(
|
||||
'run_maintenance_tasks.session_maker', return_value=session_maker()
|
||||
):
|
||||
with patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
try:
|
||||
await asyncio.wait_for(run_tasks(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
with session_maker() as session:
|
||||
updated = session.get(MaintenanceTask, task_id)
|
||||
assert updated.started_at is not None
|
||||
assert updated.started_at.tzinfo is None
|
||||
assert updated.updated_at is not None
|
||||
assert updated.updated_at.tzinfo is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tasks_respects_delay(self, session_maker):
|
||||
"""Test that run_tasks respects the delay parameter."""
|
||||
|
||||
Reference in New Issue
Block a user