Refactor: Don't serialize matching events when searching event stream (#6509)

This commit is contained in:
Rohit Malhotra
2025-01-28 18:17:44 -05:00
committed by GitHub
parent 35346068d1
commit eb760f32c7
3 changed files with 22 additions and 11 deletions

View File

@@ -384,7 +384,7 @@ class EventStream:
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list:
) -> list[type[Event]]:
"""Get matching events from the event stream based on filters.
Args:
@@ -414,7 +414,7 @@ class EventStream:
):
continue
matching_events.append(event_to_dict(event))
matching_events.append(event)
# Stop if we have enough events
if len(matching_events) >= limit:

View File

@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_from_dict
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.base import Runtime
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@@ -156,6 +156,8 @@ async def search_events(
has_more = len(matching_events) > limit
if has_more:
matching_events = matching_events[:limit] # Remove the extra event
matching_events = [event_to_dict(event) for event in matching_events]
return {
'events': matching_events,
'has_more': has_more,

View File

@@ -3,6 +3,7 @@ import json
import pytest
from pytest import TempPathFactory
from openhands.core.schema.observation import ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
@@ -78,12 +79,15 @@ def test_get_matching_events_type_filter(temp_dir: str):
# Filter by NullAction
events = event_stream.get_matching_events(event_types=(NullAction,))
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
assert all(isinstance(e, NullAction) for e in events)
# Filter by NullObservation
events = event_stream.get_matching_events(event_types=(NullObservation,))
assert len(events) == 1
assert events[0]['observation'] == 'null'
assert (
isinstance(events[0], NullObservation)
and events[0].observation == ObservationType.NULL
)
# Filter by NullAction and MessageAction
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
@@ -91,7 +95,7 @@ def test_get_matching_events_type_filter(temp_dir: str):
# Filter in reverse
events = event_stream.get_matching_events(reverse=True, limit=1)
assert events[0]['message'] == 'test'
assert isinstance(events[0], MessageAction) and events[0].content == 'test'
def test_get_matching_events_query_search(temp_dir: str):
@@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str):
# Filter by AGENT source
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2
assert all(e['source'] == 'agent' for e in events)
assert all(
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
)
# Filter by ENVIRONMENT source
events = event_stream.get_matching_events(source='environment')
assert len(events) == 1
assert events[0]['source'] == 'environment'
assert (
isinstance(events[0], NullObservation)
and events[0].source == EventSource.ENVIRONMENT
)
def test_get_matching_events_pagination(temp_dir: str):
@@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str):
# Test start_id
events = event_stream.get_matching_events(start_id=2)
assert len(events) == 3
assert events[0]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'
# Test combination of start_id and limit
events = event_stream.get_matching_events(start_id=1, limit=2)
assert len(events) == 2
assert events[0]['content'] == 'test1'
assert events[1]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'
def test_get_matching_events_limit_validation(temp_dir: str):