Add event search endpoint with pagination and filtering (#4688)

Co-authored-by: AI Assistant <assistant@example.com>
This commit is contained in:
tofarr
2024-11-26 10:18:01 -07:00
committed by GitHub
parent 71be744f2e
commit be6ca4a3ce
3 changed files with 253 additions and 0 deletions

View File

@@ -211,6 +211,95 @@ class EventStream:
if event.source == source:
yield event
def _should_filter_event(
self,
event,
query: str | None = None,
event_type: str | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
) -> bool:
"""Check if an event should be filtered out based on the given criteria.
Args:
event: The event to check
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_type and not event.__class__.__name__ == event_type:
return True
if source and not event.source.value == source:
return True
if start_date and event.timestamp < start_date:
return True
if end_date and event.timestamp > end_date:
return True
# Text search in event content if query provided
if query:
event_dict = event_to_dict(event)
event_str = str(event_dict).lower()
if query.lower() not in event_str:
return True
return False
def get_matching_events(
self,
query: str | None = None,
event_type: str | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
start_id: int = 0,
limit: int = 100,
) -> list:
"""Get matching events from the event stream based on filters.
Args:
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
start_id (int): Starting ID in the event stream. Defaults to 0
limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 100
Returns:
list: List of matching events (as dicts)
Raises:
ValueError: If limit is less than 1 or greater than 100
"""
if limit < 1 or limit > 100:
raise ValueError('Limit must be between 1 and 100')
matching_events: list = []
for event in self.get_events(start_id=start_id):
if self._should_filter_event(
event, query, event_type, source, start_date, end_date
):
continue
matching_events.append(event_to_dict(event))
# Stop if we have enough events
if len(matching_events) >= limit:
break
return matching_events
def clear(self):
self.file_store.delete(f'sessions/{self.sid}')
self._cur_id = 0

View File

@@ -279,6 +279,66 @@ async def attach_session(request: Request, call_next):
return response
@app.get('/api/events/search')
async def search_events(
request: Request,
query: str | None = None,
start_id: int = 0,
limit: int = 20,
event_type: str | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
):
"""Search through the event stream with filtering and pagination.
Args:
request (Request): The incoming request object
query (str, optional): Text to search for in event content
start_id (int): Starting ID in the event stream. Defaults to 0
limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
Returns:
dict: Dictionary containing:
- events: List of matching events
- has_more: Whether there are more matching events after this batch
Raises:
HTTPException: If conversation is not found
ValueError: If limit is less than 1 or greater than 100
"""
if not request.state.conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found'
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream
matching_events = event_stream.get_matching_events(
query=query,
event_type=event_type,
source=source,
start_date=start_date,
end_date=end_date,
start_id=start_id,
limit=limit + 1, # Get one extra to check if there are more
)
# Check if there are more events
has_more = len(matching_events) > limit
if has_more:
matching_events = matching_events[:limit] # Remove the extra event
return {
'events': matching_events,
'has_more': has_more,
}
@app.get('/api/options/models')
async def get_litellm_models() -> list[str]:
"""

View File

@@ -62,3 +62,107 @@ def test_rehydration(temp_dir: str):
assert len(events) == 2
assert events[0].content == 'obs1'
assert events[1].content == 'obs2'
def test_get_matching_events_type_filter(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
# Add mixed event types
event_stream.add_event(NullAction(), EventSource.AGENT)
event_stream.add_event(NullObservation('test'), EventSource.AGENT)
event_stream.add_event(NullAction(), EventSource.AGENT)
# Filter by NullAction
events = event_stream.get_matching_events(event_type='NullAction')
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
# Filter by NullObservation
events = event_stream.get_matching_events(event_type='NullObservation')
assert len(events) == 1
assert events[0]['observation'] == 'null'
def test_get_matching_events_query_search(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
event_stream.add_event(NullObservation('hello world'), EventSource.AGENT)
event_stream.add_event(NullObservation('test message'), EventSource.AGENT)
event_stream.add_event(NullObservation('another hello'), EventSource.AGENT)
# Search for 'hello'
events = event_stream.get_matching_events(query='hello')
assert len(events) == 2
# Search should be case-insensitive
events = event_stream.get_matching_events(query='HELLO')
assert len(events) == 2
# Search for non-existent text
events = event_stream.get_matching_events(query='nonexistent')
assert len(events) == 0
def test_get_matching_events_source_filter(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
event_stream.add_event(NullObservation('test1'), EventSource.AGENT)
event_stream.add_event(NullObservation('test2'), EventSource.ENVIRONMENT)
event_stream.add_event(NullObservation('test3'), EventSource.AGENT)
# Filter by AGENT source
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2
assert all(e['source'] == 'agent' for e in events)
# Filter by ENVIRONMENT source
events = event_stream.get_matching_events(source='environment')
assert len(events) == 1
assert events[0]['source'] == 'environment'
def test_get_matching_events_pagination(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
# Add 5 events
for i in range(5):
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
# Test limit
events = event_stream.get_matching_events(limit=3)
assert len(events) == 3
# Test start_id
events = event_stream.get_matching_events(start_id=2)
assert len(events) == 3
assert events[0]['content'] == 'test2'
# Test combination of start_id and limit
events = event_stream.get_matching_events(start_id=1, limit=2)
assert len(events) == 2
assert events[0]['content'] == 'test1'
assert events[1]['content'] == 'test2'
def test_get_matching_events_limit_validation(temp_dir: str):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
# Test limit less than 1
with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
event_stream.get_matching_events(limit=0)
# Test limit greater than 100
with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
event_stream.get_matching_events(limit=101)
# Test valid limits work
event_stream.add_event(NullObservation('test'), EventSource.AGENT)
events = event_stream.get_matching_events(limit=1)
assert len(events) == 1
events = event_stream.get_matching_events(limit=100)
assert len(events) == 1