fix: add integration tests, store prompt in config, remove AST extraction, extract pagination helper, simplify update

- Thread 1: Added integration tests using real SQLite database (aiosqlite)
  that exercise actual SQL queries for list, get, create, delete, pagination
- Thread 3: Store prompt in config JSON column so DB is source of truth,
  not the generated file
- Thread 4: Removed _extract_prompt_from_file (AST extraction) entirely
- Thread 5: Extracted _paginate() helper used by search_automations and
  list_automation_runs
- Thread 6: Simplified update endpoint - reads prompt from config instead
  of parsing file content

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
openhands
2026-03-11 18:01:25 +00:00
parent 20e0ebacf0
commit 6ec03098ad
2 changed files with 280 additions and 56 deletions

View File

@@ -36,6 +36,13 @@ def _file_store_key(automation_id: str) -> str:
return f'{FILE_STORE_PREFIX}/{automation_id}/automation.py'
def _paginate(rows: list, limit: int, id_attr: str = 'id') -> tuple[list, str | None]:
"""Return (items, next_page_id) from an overfetched result set."""
if len(rows) > limit:
return rows[:limit], getattr(rows[limit], id_attr)
return rows, None
def _automation_to_response(automation: Automation) -> AutomationResponse:
return AutomationResponse(
id=automation.id,
@@ -75,7 +82,7 @@ def _generate_and_validate_file(
repository: str | None = None,
branch: str | None = None,
) -> tuple[str, dict]:
"""Generate automation file, extract and validate config.
"""Generate automation file, extract config, validate, and store prompt in config.
Returns (file_content, config_dict).
Raises HTTPException on validation failure.
@@ -96,6 +103,8 @@ def _generate_and_validate_file(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=f'Invalid automation config: {e}',
)
# Store prompt in config so DB is the source of truth (not the file)
config['prompt'] = prompt
return file_content, config
@@ -186,13 +195,9 @@ async def search_automations(
result = await session.execute(query)
rows = list(result.scalars().all())
next_page_id: str | None = None
if len(rows) > limit:
next_page_id = rows[limit].id
rows = rows[:limit]
items, next_page_id = _paginate(rows, limit)
return PaginatedAutomationsResponse(
items=[_automation_to_response(a) for a in rows],
items=[_automation_to_response(a) for a in items],
total=total,
next_page_id=next_page_id,
)
@@ -242,7 +247,6 @@ async def update_automation(
detail='Automation not found',
)
# Collect non-None updates
updates = {
k: v
for k, v in request.model_dump(exclude_unset=True).items()
@@ -256,6 +260,7 @@ async def update_automation(
current_config = automation.config or {}
current_triggers = current_config.get('triggers', {}).get('cron', {})
# Merge: use request values if provided, else fall back to current config
new_name = updates.get('name', automation.name)
new_schedule = updates.get(
'schedule', current_triggers.get('schedule', '')
@@ -263,19 +268,7 @@ async def update_automation(
new_timezone = updates.get(
'timezone', current_triggers.get('timezone', 'UTC')
)
if 'prompt' in updates:
prompt = updates['prompt']
else:
try:
existing_content = file_store.read(automation.file_store_key)
prompt = _extract_prompt_from_file(existing_content)
except Exception:
logger.warning(
'Could not read existing automation file for prompt extraction',
extra={'automation_id': automation_id},
)
prompt = ''
prompt = updates.get('prompt', current_config.get('prompt', ''))
file_content, config = _generate_and_validate_file(
name=new_name,
@@ -296,7 +289,6 @@ async def update_automation(
automation.config = config
automation.name = new_name
# Apply simple field updates
if 'name' in updates and not needs_regen:
automation.name = updates['name']
if 'enabled' in updates:
@@ -308,34 +300,6 @@ async def update_automation(
return _automation_to_response(automation)
def _extract_prompt_from_file(file_content: str) -> str:
"""Best-effort extraction of the prompt from a generated automation file.
Looks for `conversation.send_message(...)` in the file.
"""
import ast
try:
tree = ast.parse(file_content)
for node in ast.walk(tree):
if (
isinstance(node, ast.Expr)
and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Attribute)
and node.value.func.attr == 'send_message'
and node.value.args
):
arg = node.value.args[0]
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return arg.value
if isinstance(arg, ast.JoinedStr):
# f-string — return a reconstructed version
return ast.unparse(arg)
except Exception:
pass
return ''
@automation_router.delete('/{automation_id}', status_code=status.HTTP_204_NO_CONTENT)
async def delete_automation(
automation_id: str,
@@ -456,13 +420,9 @@ async def list_automation_runs(
result = await session.execute(query)
rows = list(result.scalars().all())
next_page_id: str | None = None
if len(rows) > limit:
next_page_id = rows[limit].id
rows = rows[:limit]
items, next_page_id = _paginate(rows, limit)
return PaginatedRunsResponse(
items=[_run_to_response(r) for r in rows],
items=[_run_to_response(r) for r in items],
total=total,
next_page_id=next_page_id,
)

View File

@@ -0,0 +1,264 @@
"""Integration tests for automation CRUD API using a real in-memory SQLite database.
These tests exercise actual SQL queries (list, get, create+verify, pagination, delete)
rather than mocking the database layer.
"""
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from server.routes.automations import automation_router
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from storage.automation import Automation, AutomationRun
from openhands.app_server.utils.sql_utils import Base
from openhands.server.user_auth import get_user_id
TEST_USER_ID = 'integration-test-user'
OTHER_USER_ID = 'other-user'
@pytest.fixture
async def db_engine():
    """Yield a fresh in-memory SQLite engine with the full schema created."""
    eng = create_async_engine('sqlite+aiosqlite://', echo=False)
    async with eng.begin() as setup_conn:
        await setup_conn.run_sync(Base.metadata.create_all)
    yield eng
    # Teardown: drop the schema and release the connection pool.
    async with eng.begin() as teardown_conn:
        await teardown_conn.run_sync(Base.metadata.drop_all)
    await eng.dispose()
@pytest.fixture
def session_maker(db_engine):
    """Session factory bound to the test engine.

    ``expire_on_commit=False`` keeps ORM objects readable after commit,
    which the tests rely on when asserting against inserted rows.
    """
    factory = async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    return factory
@pytest.fixture
def app(session_maker):
    """FastAPI app wired to a real SQLite database."""
    test_app = FastAPI()
    test_app.include_router(automation_router)
    # Every request is authenticated as the fixed integration-test user.
    test_app.dependency_overrides[get_user_id] = lambda: TEST_USER_ID

    @asynccontextmanager
    async def _open_session():
        async with session_maker() as db_session:
            yield db_session

    # Point the routes' session maker at the per-test SQLite database.
    with patch('server.routes.automations.a_session_maker', _open_session):
        yield test_app
@pytest.fixture
async def client(app):
    """In-process HTTP client for the test app (no real network)."""
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url='http://test'
    ) as http_client:
        yield http_client
def _make_automation_obj(
    user_id: str = TEST_USER_ID,
    name: str = 'Test Auto',
    created_at: datetime | None = None,
    **kwargs,
) -> Automation:
    """Build an Automation row for direct insertion into the test database.

    Recognised keyword overrides: ``automation_id``, ``enabled``, ``config``
    and ``file_store_key``; anything not supplied gets a sensible default.
    ``created_at`` (when given) is used for both timestamp columns.
    """
    # Evaluate the clock exactly once so created_at and updated_at are
    # identical. The previous version called datetime.now(UTC) twice, which
    # could produce timestamps differing by microseconds.
    timestamp = created_at or datetime.now(UTC)
    default_config = {
        'name': name,
        'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
        'prompt': 'Do something',
    }
    return Automation(
        id=kwargs.get('automation_id', uuid.uuid4().hex),
        user_id=user_id,
        name=name,
        enabled=kwargs.get('enabled', True),
        config=kwargs.get('config', default_config),
        trigger_type='cron',
        file_store_key=kwargs.get(
            'file_store_key', f'automations/{uuid.uuid4().hex}/automation.py'
        ),
        created_at=timestamp,
        updated_at=timestamp,
    )
# ---------- Test: list (search) returns correct results ----------
@pytest.mark.asyncio
async def test_search_returns_user_automations(client, session_maker):
    """GET /search returns only automations owned by the requesting user."""
    async with session_maker() as session:
        mine_a = _make_automation_obj(
            name='Auto A', created_at=datetime(2026, 1, 1, tzinfo=UTC)
        )
        mine_b = _make_automation_obj(
            name='Auto B', created_at=datetime(2026, 1, 2, tzinfo=UTC)
        )
        not_mine = _make_automation_obj(user_id=OTHER_USER_ID, name='Other User Auto')
        session.add_all([mine_a, mine_b, not_mine])
        await session.commit()

    response = await client.get('/api/v1/automations/search')

    assert response.status_code == 200
    body = response.json()
    # The other user's row must be filtered out by the user_id predicate.
    assert body['total'] == 2
    assert len(body['items']) == 2
    assert {entry['name'] for entry in body['items']} == {'Auto A', 'Auto B'}
# ---------- Test: get returns the right object ----------
@pytest.mark.asyncio
async def test_get_returns_correct_automation(client, session_maker):
    """GET /{id} returns the correct automation by ID."""
    target_id = uuid.uuid4().hex
    async with session_maker() as session:
        session.add(_make_automation_obj(automation_id=target_id, name='Specific Auto'))
        await session.commit()

    response = await client.get(f'/api/v1/automations/{target_id}')

    assert response.status_code == 200
    payload = response.json()
    assert payload['id'] == target_id
    assert payload['name'] == 'Specific Auto'
@pytest.mark.asyncio
async def test_get_nonexistent_returns_404(client):
    """GET /{id} for non-existent automation returns 404."""
    missing = await client.get('/api/v1/automations/does-not-exist')
    assert missing.status_code == 404
# ---------- Test: create + verify in DB ----------
@pytest.mark.asyncio
async def test_create_stores_in_db(client, session_maker):
    """POST creates an automation and it's readable from the database."""
    fake_file_store = MagicMock()
    stub_config = {
        'name': 'New Auto',
        'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
    }
    # File generation/validation is stubbed; the DB write path is under test.
    with (
        patch(
            'server.routes.automations.generate_automation_file',
            return_value='__config__ = {}',
        ),
        patch('server.routes.automations.extract_config', return_value=stub_config),
        patch('server.routes.automations.validate_config'),
        patch('server.routes.automations.file_store', fake_file_store),
    ):
        create_resp = await client.post(
            '/api/v1/automations',
            json={
                'name': 'New Auto',
                'schedule': '0 9 * * 5',
                'prompt': 'Summarize PRs',
            },
        )
    assert create_resp.status_code == 201
    new_id = create_resp.json()['id']

    # Round-trip through the GET endpoint to prove the row landed in SQLite.
    fetched = await client.get(f'/api/v1/automations/{new_id}')
    assert fetched.status_code == 200
    fetched_body = fetched.json()
    assert fetched_body['name'] == 'New Auto'
    # The prompt must be persisted inside the config column.
    assert fetched_body['config'].get('prompt') == 'Summarize PRs'
# ---------- Test: delete actually deletes ----------
@pytest.mark.asyncio
async def test_delete_removes_from_db(client, session_maker):
    """DELETE removes the automation from the database."""
    doomed_id = uuid.uuid4().hex
    async with session_maker() as session:
        session.add(_make_automation_obj(automation_id=doomed_id, name='To Delete'))
        await session.commit()

    # file_store is stubbed so the endpoint's file cleanup is a no-op.
    with patch('server.routes.automations.file_store', MagicMock()):
        delete_resp = await client.delete(f'/api/v1/automations/{doomed_id}')
    assert delete_resp.status_code == 204

    # The row must no longer be retrievable.
    followup = await client.get(f'/api/v1/automations/{doomed_id}')
    assert followup.status_code == 404
# ---------- Test: pagination actually works ----------
@pytest.mark.asyncio
async def test_pagination_returns_correct_pages(client, session_maker):
    """Pagination with limit returns correct page sizes and next_page_id."""
    start = datetime(2026, 1, 1, tzinfo=UTC)
    async with session_maker() as session:
        session.add_all(
            [
                _make_automation_obj(
                    name=f'Auto {i}', created_at=start + timedelta(hours=i)
                )
                for i in range(5)
            ]
        )
        await session.commit()

    # Page 1: limit=2 out of 5 rows -> a full page plus a cursor.
    first = await client.get('/api/v1/automations/search?limit=2')
    assert first.status_code == 200
    first_body = first.json()
    assert first_body['total'] == 5
    assert len(first_body['items']) == 2
    cursor = first_body['next_page_id']
    assert cursor is not None

    # Page 2: following the cursor yields the next two rows.
    second = await client.get(f'/api/v1/automations/search?limit=2&page_id={cursor}')
    assert second.status_code == 200
    second_body = second.json()
    assert len(second_body['items']) == 2

    # No item may appear on both pages.
    seen = [row['id'] for row in first_body['items'] + second_body['items']]
    assert len(seen) == len(set(seen)), 'Pages must not contain duplicate items'
# ---------- Test: user isolation at DB level ----------
@pytest.mark.asyncio
async def test_user_isolation(client, session_maker):
    """User A cannot see or access User B's automations via actual DB queries."""
    foreign_id = uuid.uuid4().hex
    async with session_maker() as session:
        session.add(
            _make_automation_obj(
                automation_id=foreign_id,
                user_id=OTHER_USER_ID,
                name='Other User Auto',
            )
        )
        await session.commit()

    # Direct GET by id must 404 for another user's row.
    direct = await client.get(f'/api/v1/automations/{foreign_id}')
    assert direct.status_code == 404

    # And the row must be invisible in search results.
    listing = await client.get('/api/v1/automations/search')
    assert listing.status_code == 200
    assert listing.json()['total'] == 0