Fix FastAPI Query parameter validation: lte -> le (#13502)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2026-03-19 20:27:10 -04:00
committed by GitHub
parent f75141af3e
commit 63956c3292
8 changed files with 243 additions and 33 deletions

View File

@@ -68,7 +68,7 @@ async def list_user_orgs(
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
Query(title='The max number of results in the page', gt=0, le=100),
] = 100,
user_id: str = Depends(get_user_id),
) -> OrgPage:
@@ -734,7 +734,7 @@ async def get_org_members(
Query(
title='The max number of results in the page',
gt=0,
lte=100,
le=100,
),
] = 10,
email: Annotated[

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
@@ -60,7 +60,7 @@ async def search_shared_conversations(
Query(
title='The max number of results in the page',
gt=0,
lte=100,
le=100,
),
] = 100,
include_sub_conversations: Annotated[
@@ -72,8 +72,6 @@ async def search_shared_conversations(
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
assert limit > 0
assert limit <= 100
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
@@ -127,7 +125,11 @@ async def batch_get_shared_conversations(
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
assert len(ids) <= 100
if len(ids) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 conversations at once, got {len(ids)}',
)
uuids = [UUID(id_) for id_ in ids]
shared_conversation_info = (
await shared_conversation_service.batch_get_shared_conversation_info(uuids)

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from server.sharing.shared_event_service import (
SharedEventService,
SharedEventServiceInjector,
@@ -77,13 +77,11 @@ async def search_shared_events(
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
Query(title='The max number of results in the page', gt=0, le=100),
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation."""
assert limit > 0
assert limit <= 100
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
@@ -134,7 +132,11 @@ async def batch_get_shared_events(
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
if len(id) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 events at once, got {len(id)}',
)
event_ids = [UUID(id_) for id_ in id]
events = await shared_event_service.batch_get_shared_events(
UUID(conversation_id), event_ids

View File

@@ -234,7 +234,7 @@ async def search_app_conversations(
Query(
title='The max number of results in the page',
gt=0,
lte=100,
le=100,
),
] = 100,
include_sub_conversations: Annotated[
@@ -248,8 +248,6 @@ async def search_app_conversations(
),
) -> AppConversationPage:
"""Search / List sandboxed conversations."""
assert limit > 0
assert limit <= 100
return await app_conversation_service.search_app_conversations(
title__contains=title__contains,
created_at__gte=created_at__gte,
@@ -422,7 +420,7 @@ async def search_app_conversation_start_tasks(
Query(
title='The max number of results in the page',
gt=0,
lte=100,
le=100,
),
] = 100,
app_conversation_start_task_service: AppConversationStartTaskService = (
@@ -430,8 +428,6 @@ async def search_app_conversation_start_tasks(
),
) -> AppConversationStartTaskPage:
"""Search / List conversation start tasks."""
assert limit > 0
assert limit <= 100
return (
await app_conversation_start_task_service.search_app_conversation_start_tasks(
conversation_id__eq=conversation_id__eq,
@@ -472,7 +468,11 @@ async def batch_get_app_conversation_start_tasks(
),
) -> list[AppConversationStartTask | None]:
"""Get a batch of start app conversation tasks given their ids. Return None for any missing."""
assert len(ids) < 100
if len(ids) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 start tasks at once, got {len(ids)}',
)
start_tasks = await app_conversation_start_task_service.batch_get_app_conversation_start_tasks(
ids
)

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Query
from fastapi import APIRouter, HTTPException, Query
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.config import depends_event_service
@@ -51,13 +51,11 @@ async def search_events(
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
Query(title='The max number of results in the page', gt=0, le=100),
] = 100,
event_service: EventService = event_service_dependency,
) -> EventPage:
"""Search / List events."""
assert limit > 0
assert limit <= 100
return await event_service.search_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
@@ -102,7 +100,11 @@ async def batch_get_events(
event_service: EventService = event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events given their ids, returning null for any missing event."""
if len(id) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 events at once, got {len(id)}',
)
event_ids = [UUID(id_) for id_ in id]
assert len(id) <= 100
events = await event_service.batch_get_events(UUID(conversation_id), event_ids)
return events

View File

@@ -44,13 +44,11 @@ async def search_sandboxes(
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
Query(title='The max number of results in the page', gt=0, le=100),
] = 100,
sandbox_service: SandboxService = sandbox_service_dependency,
) -> SandboxPage:
"""Search / list sandboxes owned by the current user."""
assert limit > 0
assert limit <= 100
return await sandbox_service.search_sandboxes(page_id=page_id, limit=limit)
@@ -60,7 +58,11 @@ async def batch_get_sandboxes(
sandbox_service: SandboxService = sandbox_service_dependency,
) -> list[SandboxInfo | None]:
"""Get a batch of sandboxes given their ids, returning null for any missing."""
assert len(id) < 100
if len(id) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 sandboxes at once, got {len(id)}',
)
sandboxes = await sandbox_service.batch_get_sandboxes(id)
return sandboxes

View File

@@ -2,7 +2,7 @@
from typing import Annotated
from fastapi import APIRouter, Query
from fastapi import APIRouter, HTTPException, Query
from openhands.app_server.config import depends_sandbox_spec_service
from openhands.app_server.sandbox.sandbox_spec_models import (
@@ -35,13 +35,11 @@ async def search_sandbox_specs(
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
Query(title='The max number of results in the page', gt=0, le=100),
] = 100,
sandbox_spec_service: SandboxSpecService = sandbox_spec_service_dependency,
) -> SandboxSpecInfoPage:
"""Search / List sandbox specs."""
assert limit > 0
assert limit <= 100
return await sandbox_spec_service.search_sandbox_specs(page_id=page_id, limit=limit)
@@ -51,6 +49,10 @@ async def batch_get_sandbox_specs(
sandbox_spec_service: SandboxSpecService = sandbox_spec_service_dependency,
) -> list[SandboxSpecInfo | None]:
"""Get a batch of sandbox specs given their ids, returning null for any missing."""
assert len(id) <= 100
if len(id) > 100:
raise HTTPException(
status_code=400,
detail=f'Cannot request more than 100 sandbox specs at once, got {len(id)}',
)
sandbox_specs = await sandbox_spec_service.batch_get_sandbox_specs(id)
return sandbox_specs

View File

@@ -0,0 +1,200 @@
"""Unit tests for the event_router endpoints.
This module tests the event router endpoints,
focusing on limit validation and error handling.
"""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from fastapi import FastAPI, HTTPException, status
from fastapi.testclient import TestClient
from openhands.app_server.event.event_router import batch_get_events, router
from openhands.server.dependencies import check_session_api_key
def _make_mock_event_service(search_return=None, batch_get_return=None):
"""Create a mock EventService for testing."""
service = MagicMock()
service.search_events = AsyncMock(return_value=search_return)
service.batch_get_events = AsyncMock(return_value=batch_get_return or [])
return service
@pytest.fixture
def test_client():
"""Create a test client with the actual event router and mocked dependencies.
We override check_session_api_key to bypass auth checks.
This allows us to test the actual Query parameter validation in the router.
"""
app = FastAPI()
app.include_router(router)
# Override the auth dependency to always pass
app.dependency_overrides[check_session_api_key] = lambda: None
client = TestClient(app, raise_server_exceptions=False)
yield client
# Clean up
app.dependency_overrides.clear()
class TestSearchEventsValidation:
"""Test suite for search_events endpoint limit validation via FastAPI."""
def test_returns_422_for_limit_exceeding_100(self, test_client):
"""Test that limit > 100 returns 422 Unprocessable Entity.
FastAPI's Query validation (le=100) should reject limit=200.
"""
conversation_id = str(uuid4())
response = test_client.get(
f'/conversation/{conversation_id}/events/search',
params={'limit': 200},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Verify the error message mentions the constraint
error_detail = response.json()['detail']
assert any(
'less than or equal to 100' in str(err).lower() or 'le' in str(err).lower()
for err in error_detail
)
def test_returns_422_for_limit_zero(self, test_client):
"""Test that limit=0 returns 422 Unprocessable Entity.
FastAPI's Query validation (gt=0) should reject limit=0.
"""
conversation_id = str(uuid4())
response = test_client.get(
f'/conversation/{conversation_id}/events/search',
params={'limit': 0},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_returns_422_for_negative_limit(self, test_client):
"""Test that negative limit returns 422 Unprocessable Entity.
FastAPI's Query validation (gt=0) should reject limit=-1.
"""
conversation_id = str(uuid4())
response = test_client.get(
f'/conversation/{conversation_id}/events/search',
params={'limit': -1},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_accepts_valid_limit_100(self, test_client):
"""Test that limit=100 is accepted (boundary case).
Verify that limit=100 passes FastAPI validation and doesn't return 422.
"""
conversation_id = str(uuid4())
response = test_client.get(
f'/conversation/{conversation_id}/events/search',
params={'limit': 100},
)
# Should pass validation (not 422) - may fail on other errors like missing service
assert response.status_code != status.HTTP_422_UNPROCESSABLE_ENTITY
def test_accepts_valid_limit_1(self, test_client):
"""Test that limit=1 is accepted (boundary case).
Verify that limit=1 passes FastAPI validation and doesn't return 422.
"""
conversation_id = str(uuid4())
response = test_client.get(
f'/conversation/{conversation_id}/events/search',
params={'limit': 1},
)
# Should pass validation (not 422) - may fail on other errors like missing service
assert response.status_code != status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
class TestBatchGetEvents:
"""Test suite for batch_get_events endpoint."""
async def test_returns_400_for_more_than_100_ids(self):
"""Test that requesting more than 100 IDs returns 400 Bad Request.
Arrange: Create list with 101 IDs
Act: Call batch_get_events
Assert: HTTPException is raised with 400 status
"""
# Arrange
conversation_id = str(uuid4())
ids = [str(uuid4()) for _ in range(101)]
mock_service = _make_mock_event_service()
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await batch_get_events(
conversation_id=conversation_id,
id=ids,
event_service=mock_service,
)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'Cannot request more than 100 events' in exc_info.value.detail
assert '101' in exc_info.value.detail
async def test_accepts_exactly_100_ids(self):
"""Test that exactly 100 IDs is accepted.
Arrange: Create list with 100 IDs
Act: Call batch_get_events
Assert: No exception is raised and service is called
"""
# Arrange
conversation_id = str(uuid4())
ids = [str(uuid4()) for _ in range(100)]
mock_return = [None] * 100
mock_service = _make_mock_event_service(batch_get_return=mock_return)
# Act
result = await batch_get_events(
conversation_id=conversation_id,
id=ids,
event_service=mock_service,
)
# Assert
assert result == mock_return
mock_service.batch_get_events.assert_called_once()
async def test_accepts_empty_list(self):
"""Test that empty list of IDs is accepted.
Arrange: Create empty list of IDs
Act: Call batch_get_events
Assert: No exception is raised
"""
# Arrange
conversation_id = str(uuid4())
mock_service = _make_mock_event_service(batch_get_return=[])
# Act
result = await batch_get_events(
conversation_id=conversation_id,
id=[],
event_service=mock_service,
)
# Assert
assert result == []
mock_service.batch_get_events.assert_called_once()