mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor: Don't serialize matching events when searching event stream (#6509)
This commit is contained in:
@@ -384,7 +384,7 @@ class EventStream:
|
||||
start_id: int = 0,
|
||||
limit: int = 100,
|
||||
reverse: bool = False,
|
||||
) -> list:
|
||||
) -> list[type[Event]]:
|
||||
"""Get matching events from the event stream based on filters.
|
||||
|
||||
Args:
|
||||
@@ -414,7 +414,7 @@ class EventStream:
|
||||
):
|
||||
continue
|
||||
|
||||
matching_events.append(event_to_dict(event))
|
||||
matching_events.append(event)
|
||||
|
||||
# Stop if we have enough events
|
||||
if len(matching_events) >= limit:
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
@@ -156,6 +156,8 @@ async def search_events(
|
||||
has_more = len(matching_events) > limit
|
||||
if has_more:
|
||||
matching_events = matching_events[:limit] # Remove the extra event
|
||||
|
||||
matching_events = [event_to_dict(event) for event in matching_events]
|
||||
return {
|
||||
'events': matching_events,
|
||||
'has_more': has_more,
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import pytest
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
NullAction,
|
||||
@@ -78,12 +79,15 @@ def test_get_matching_events_type_filter(temp_dir: str):
|
||||
# Filter by NullAction
|
||||
events = event_stream.get_matching_events(event_types=(NullAction,))
|
||||
assert len(events) == 2
|
||||
assert all(e['action'] == 'null' for e in events)
|
||||
assert all(isinstance(e, NullAction) for e in events)
|
||||
|
||||
# Filter by NullObservation
|
||||
events = event_stream.get_matching_events(event_types=(NullObservation,))
|
||||
assert len(events) == 1
|
||||
assert events[0]['observation'] == 'null'
|
||||
assert (
|
||||
isinstance(events[0], NullObservation)
|
||||
and events[0].observation == ObservationType.NULL
|
||||
)
|
||||
|
||||
# Filter by NullAction and MessageAction
|
||||
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
|
||||
@@ -91,7 +95,7 @@ def test_get_matching_events_type_filter(temp_dir: str):
|
||||
|
||||
# Filter in reverse
|
||||
events = event_stream.get_matching_events(reverse=True, limit=1)
|
||||
assert events[0]['message'] == 'test'
|
||||
assert isinstance(events[0], MessageAction) and events[0].content == 'test'
|
||||
|
||||
|
||||
def test_get_matching_events_query_search(temp_dir: str):
|
||||
@@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str):
|
||||
# Filter by AGENT source
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2
|
||||
assert all(e['source'] == 'agent' for e in events)
|
||||
assert all(
|
||||
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
|
||||
)
|
||||
|
||||
# Filter by ENVIRONMENT source
|
||||
events = event_stream.get_matching_events(source='environment')
|
||||
assert len(events) == 1
|
||||
assert events[0]['source'] == 'environment'
|
||||
assert (
|
||||
isinstance(events[0], NullObservation)
|
||||
and events[0].source == EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
|
||||
def test_get_matching_events_pagination(temp_dir: str):
|
||||
@@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str):
|
||||
# Test start_id
|
||||
events = event_stream.get_matching_events(start_id=2)
|
||||
assert len(events) == 3
|
||||
assert events[0]['content'] == 'test2'
|
||||
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'
|
||||
|
||||
# Test combination of start_id and limit
|
||||
events = event_stream.get_matching_events(start_id=1, limit=2)
|
||||
assert len(events) == 2
|
||||
assert events[0]['content'] == 'test1'
|
||||
assert events[1]['content'] == 'test2'
|
||||
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
|
||||
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'
|
||||
|
||||
|
||||
def test_get_matching_events_limit_validation(temp_dir: str):
|
||||
|
||||
Reference in New Issue
Block a user