mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix: add integration tests, store prompt in config, remove AST extraction, extract pagination helper, simplify update
- Thread 1: Added integration tests using real SQLite database (aiosqlite) that exercise actual SQL queries for list, get, create, delete, pagination - Thread 3: Store prompt in config JSON column so DB is source of truth, not the generated file - Thread 4: Removed _extract_prompt_from_file (AST extraction) entirely - Thread 5: Extracted _paginate() helper used by search_automations and list_automation_runs - Thread 6: Simplified update endpoint - reads prompt from config instead of parsing file content Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -36,6 +36,13 @@ def _file_store_key(automation_id: str) -> str:
|
||||
return f'{FILE_STORE_PREFIX}/{automation_id}/automation.py'
|
||||
|
||||
|
||||
def _paginate(rows: list, limit: int, id_attr: str = 'id') -> tuple[list, str | None]:
|
||||
"""Return (items, next_page_id) from an overfetched result set."""
|
||||
if len(rows) > limit:
|
||||
return rows[:limit], getattr(rows[limit], id_attr)
|
||||
return rows, None
|
||||
|
||||
|
||||
def _automation_to_response(automation: Automation) -> AutomationResponse:
|
||||
return AutomationResponse(
|
||||
id=automation.id,
|
||||
@@ -75,7 +82,7 @@ def _generate_and_validate_file(
|
||||
repository: str | None = None,
|
||||
branch: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Generate automation file, extract and validate config.
|
||||
"""Generate automation file, extract config, validate, and store prompt in config.
|
||||
|
||||
Returns (file_content, config_dict).
|
||||
Raises HTTPException on validation failure.
|
||||
@@ -96,6 +103,8 @@ def _generate_and_validate_file(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail=f'Invalid automation config: {e}',
|
||||
)
|
||||
# Store prompt in config so DB is the source of truth (not the file)
|
||||
config['prompt'] = prompt
|
||||
return file_content, config
|
||||
|
||||
|
||||
@@ -186,13 +195,9 @@ async def search_automations(
|
||||
result = await session.execute(query)
|
||||
rows = list(result.scalars().all())
|
||||
|
||||
next_page_id: str | None = None
|
||||
if len(rows) > limit:
|
||||
next_page_id = rows[limit].id
|
||||
rows = rows[:limit]
|
||||
|
||||
items, next_page_id = _paginate(rows, limit)
|
||||
return PaginatedAutomationsResponse(
|
||||
items=[_automation_to_response(a) for a in rows],
|
||||
items=[_automation_to_response(a) for a in items],
|
||||
total=total,
|
||||
next_page_id=next_page_id,
|
||||
)
|
||||
@@ -242,7 +247,6 @@ async def update_automation(
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
# Collect non-None updates
|
||||
updates = {
|
||||
k: v
|
||||
for k, v in request.model_dump(exclude_unset=True).items()
|
||||
@@ -256,6 +260,7 @@ async def update_automation(
|
||||
current_config = automation.config or {}
|
||||
current_triggers = current_config.get('triggers', {}).get('cron', {})
|
||||
|
||||
# Merge: use request values if provided, else fall back to current config
|
||||
new_name = updates.get('name', automation.name)
|
||||
new_schedule = updates.get(
|
||||
'schedule', current_triggers.get('schedule', '')
|
||||
@@ -263,19 +268,7 @@ async def update_automation(
|
||||
new_timezone = updates.get(
|
||||
'timezone', current_triggers.get('timezone', 'UTC')
|
||||
)
|
||||
|
||||
if 'prompt' in updates:
|
||||
prompt = updates['prompt']
|
||||
else:
|
||||
try:
|
||||
existing_content = file_store.read(automation.file_store_key)
|
||||
prompt = _extract_prompt_from_file(existing_content)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'Could not read existing automation file for prompt extraction',
|
||||
extra={'automation_id': automation_id},
|
||||
)
|
||||
prompt = ''
|
||||
prompt = updates.get('prompt', current_config.get('prompt', ''))
|
||||
|
||||
file_content, config = _generate_and_validate_file(
|
||||
name=new_name,
|
||||
@@ -296,7 +289,6 @@ async def update_automation(
|
||||
automation.config = config
|
||||
automation.name = new_name
|
||||
|
||||
# Apply simple field updates
|
||||
if 'name' in updates and not needs_regen:
|
||||
automation.name = updates['name']
|
||||
if 'enabled' in updates:
|
||||
@@ -308,34 +300,6 @@ async def update_automation(
|
||||
return _automation_to_response(automation)
|
||||
|
||||
|
||||
def _extract_prompt_from_file(file_content: str) -> str:
|
||||
"""Best-effort extraction of the prompt from a generated automation file.
|
||||
|
||||
Looks for `conversation.send_message(...)` in the file.
|
||||
"""
|
||||
import ast
|
||||
|
||||
try:
|
||||
tree = ast.parse(file_content)
|
||||
for node in ast.walk(tree):
|
||||
if (
|
||||
isinstance(node, ast.Expr)
|
||||
and isinstance(node.value, ast.Call)
|
||||
and isinstance(node.value.func, ast.Attribute)
|
||||
and node.value.func.attr == 'send_message'
|
||||
and node.value.args
|
||||
):
|
||||
arg = node.value.args[0]
|
||||
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
|
||||
return arg.value
|
||||
if isinstance(arg, ast.JoinedStr):
|
||||
# f-string — return a reconstructed version
|
||||
return ast.unparse(arg)
|
||||
except Exception:
|
||||
pass
|
||||
return ''
|
||||
|
||||
|
||||
@automation_router.delete('/{automation_id}', status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_automation(
|
||||
automation_id: str,
|
||||
@@ -456,13 +420,9 @@ async def list_automation_runs(
|
||||
result = await session.execute(query)
|
||||
rows = list(result.scalars().all())
|
||||
|
||||
next_page_id: str | None = None
|
||||
if len(rows) > limit:
|
||||
next_page_id = rows[limit].id
|
||||
rows = rows[:limit]
|
||||
|
||||
items, next_page_id = _paginate(rows, limit)
|
||||
return PaginatedRunsResponse(
|
||||
items=[_run_to_response(r) for r in rows],
|
||||
items=[_run_to_response(r) for r in items],
|
||||
total=total,
|
||||
next_page_id=next_page_id,
|
||||
)
|
||||
|
||||
264
enterprise/tests/unit/server/test_automations_api_integration.py
Normal file
264
enterprise/tests/unit/server/test_automations_api_integration.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Integration tests for automation CRUD API using a real in-memory SQLite database.
|
||||
|
||||
These tests exercise actual SQL queries (list, get, create+verify, pagination, delete)
|
||||
rather than mocking the database layer.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from server.routes.automations import automation_router
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from storage.automation import Automation, AutomationRun
|
||||
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
TEST_USER_ID = 'integration-test-user'
|
||||
OTHER_USER_ID = 'other-user'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine('sqlite+aiosqlite://', echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(db_engine):
|
||||
return async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(session_maker):
|
||||
"""FastAPI app wired to a real SQLite database."""
|
||||
app = FastAPI()
|
||||
app.include_router(automation_router)
|
||||
app.dependency_overrides[get_user_id] = lambda: TEST_USER_ID
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session_ctx():
|
||||
async with session_maker() as session:
|
||||
yield session
|
||||
|
||||
with patch('server.routes.automations.a_session_maker', _session_ctx):
|
||||
yield app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://test') as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _make_automation_obj(
|
||||
user_id: str = TEST_USER_ID,
|
||||
name: str = 'Test Auto',
|
||||
created_at: datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Automation:
|
||||
return Automation(
|
||||
id=kwargs.get('automation_id', uuid.uuid4().hex),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
enabled=kwargs.get('enabled', True),
|
||||
config=kwargs.get(
|
||||
'config',
|
||||
{
|
||||
'name': name,
|
||||
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
|
||||
'prompt': 'Do something',
|
||||
},
|
||||
),
|
||||
trigger_type='cron',
|
||||
file_store_key=kwargs.get('file_store_key', f'automations/{uuid.uuid4().hex}/automation.py'),
|
||||
created_at=created_at or datetime.now(UTC),
|
||||
updated_at=created_at or datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
# ---------- Test: list (search) returns correct results ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_user_automations(client, session_maker):
|
||||
"""GET /search returns only automations owned by the requesting user."""
|
||||
async with session_maker() as session:
|
||||
a1 = _make_automation_obj(name='Auto A', created_at=datetime(2026, 1, 1, tzinfo=UTC))
|
||||
a2 = _make_automation_obj(name='Auto B', created_at=datetime(2026, 1, 2, tzinfo=UTC))
|
||||
a_other = _make_automation_obj(user_id=OTHER_USER_ID, name='Other User Auto')
|
||||
session.add_all([a1, a2, a_other])
|
||||
await session.commit()
|
||||
|
||||
response = await client.get('/api/v1/automations/search')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 2
|
||||
assert len(data['items']) == 2
|
||||
names = {item['name'] for item in data['items']}
|
||||
assert names == {'Auto A', 'Auto B'}
|
||||
|
||||
|
||||
# ---------- Test: get returns the right object ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_returns_correct_automation(client, session_maker):
|
||||
"""GET /{id} returns the correct automation by ID."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
auto = _make_automation_obj(automation_id=auto_id, name='Specific Auto')
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['id'] == auto_id
|
||||
assert data['name'] == 'Specific Auto'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_returns_404(client):
|
||||
"""GET /{id} for non-existent automation returns 404."""
|
||||
response = await client.get('/api/v1/automations/does-not-exist')
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---------- Test: create + verify in DB ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_stores_in_db(client, session_maker):
|
||||
"""POST creates an automation and it's readable from the database."""
|
||||
mock_file_store = MagicMock()
|
||||
config = {
|
||||
'name': 'New Auto',
|
||||
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {}',
|
||||
),
|
||||
patch('server.routes.automations.extract_config', return_value=config),
|
||||
patch('server.routes.automations.validate_config'),
|
||||
patch('server.routes.automations.file_store', mock_file_store),
|
||||
):
|
||||
response = await client.post(
|
||||
'/api/v1/automations',
|
||||
json={
|
||||
'name': 'New Auto',
|
||||
'schedule': '0 9 * * 5',
|
||||
'prompt': 'Summarize PRs',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
created_id = data['id']
|
||||
|
||||
# Verify it's in the DB via the GET endpoint
|
||||
get_response = await client.get(f'/api/v1/automations/{created_id}')
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()['name'] == 'New Auto'
|
||||
# Verify prompt is stored in config
|
||||
assert get_response.json()['config'].get('prompt') == 'Summarize PRs'
|
||||
|
||||
|
||||
# ---------- Test: delete actually deletes ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_from_db(client, session_maker):
|
||||
"""DELETE removes the automation from the database."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
auto = _make_automation_obj(automation_id=auto_id, name='To Delete')
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
mock_file_store = MagicMock()
|
||||
with patch('server.routes.automations.file_store', mock_file_store):
|
||||
response = await client.delete(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify it's gone
|
||||
get_response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
# ---------- Test: pagination actually works ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_returns_correct_pages(client, session_maker):
|
||||
"""Pagination with limit returns correct page sizes and next_page_id."""
|
||||
base_time = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
async with session_maker() as session:
|
||||
for i in range(5):
|
||||
auto = _make_automation_obj(
|
||||
name=f'Auto {i}',
|
||||
created_at=base_time + timedelta(hours=i),
|
||||
)
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
# First page with limit=2
|
||||
response = await client.get('/api/v1/automations/search?limit=2')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 5
|
||||
assert len(data['items']) == 2
|
||||
assert data['next_page_id'] is not None
|
||||
|
||||
# Second page using cursor — should return remaining items before cursor
|
||||
next_id = data['next_page_id']
|
||||
response2 = await client.get(f'/api/v1/automations/search?limit=2&page_id={next_id}')
|
||||
assert response2.status_code == 200
|
||||
data2 = response2.json()
|
||||
assert len(data2['items']) == 2
|
||||
|
||||
# Collect all items from both pages and verify no duplicates
|
||||
all_ids = [item['id'] for item in data['items']] + [
|
||||
item['id'] for item in data2['items']
|
||||
]
|
||||
assert len(all_ids) == len(set(all_ids)), 'Pages must not contain duplicate items'
|
||||
|
||||
|
||||
# ---------- Test: user isolation at DB level ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(client, session_maker):
|
||||
"""User A cannot see or access User B's automations via actual DB queries."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
other_auto = _make_automation_obj(
|
||||
automation_id=auto_id,
|
||||
user_id=OTHER_USER_ID,
|
||||
name='Other User Auto',
|
||||
)
|
||||
session.add(other_auto)
|
||||
await session.commit()
|
||||
|
||||
# Should not be found by TEST_USER_ID
|
||||
response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 404
|
||||
|
||||
# Should not appear in search
|
||||
search_response = await client.get('/api/v1/automations/search')
|
||||
assert search_response.status_code == 200
|
||||
assert search_response.json()['total'] == 0
|
||||
Reference in New Issue
Block a user