From 604534905f944cd208d77aa2a565769879e900c1 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Mon, 27 Jan 2025 13:58:09 -0500 Subject: [PATCH] Refactor: Use type[Event] instead of str to filter events (#6480) Co-authored-by: openhands --- openhands/events/stream.py | 10 +++++----- openhands/server/routes/conversation.py | 19 ++++++++++++++++++- tests/unit/test_event_stream.py | 4 ++-- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 2ef6047f24..50b24ed848 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -319,7 +319,7 @@ class EventStream: self, event, query: str | None = None, - event_type: str | None = None, + event_type: type[Event] | None = None, source: str | None = None, start_date: str | None = None, end_date: str | None = None, @@ -329,7 +329,7 @@ class EventStream: Args: event: The event to check query (str, optional): Text to search for in event content - event_type (str, optional): Filter by event type (e.g., "FileReadAction") + event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction) source (str, optional): Filter by event source start_date (str, optional): Filter events after this date (ISO format) end_date (str, optional): Filter events before this date (ISO format) @@ -337,7 +337,7 @@ class EventStream: Returns: bool: True if the event should be filtered out, False if it matches all criteria """ - if event_type and not event.__class__.__name__ == event_type: + if event_type and not isinstance(event, event_type): return True if source and not event.source.value == source: @@ -361,7 +361,7 @@ class EventStream: def get_matching_events( self, query: str | None = None, - event_type: str | None = None, + event_type: type[Event] | None = None, source: str | None = None, start_date: str | None = None, end_date: str | None = None, @@ -372,7 +372,7 @@ class EventStream: Args: query (str, optional): Text to search for in event content - event_type (str, optional): Filter by event type (e.g., "FileReadAction") + event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction) source (str, optional): Filter by event source start_date (str, optional): Filter events after this date (ISO format) end_date (str, optional): Filter events before this date (ISO format) diff --git a/openhands/server/routes/conversation.py b/openhands/server/routes/conversation.py index 1fcda47635..d5fab4515a 100644 --- a/openhands/server/routes/conversation.py +++ b/openhands/server/routes/conversation.py @@ -2,11 +2,26 @@ from fastapi import APIRouter, HTTPException, Request, status from fastapi.responses import JSONResponse from openhands.core.logger import openhands_logger as logger +from openhands.events.event import Event +from openhands.events.serialization.event import event_from_dict from openhands.runtime.base import Runtime app = APIRouter(prefix='/api/conversations/{conversation_id}') +def str_to_event_type(event: str | None) -> Event | None: + if not event: + return None + + for event_type in ['observation', 'action']: + try: + return event_from_dict({event_type: event}) + except Exception: + continue + + return None + + @app.get('/config') async def get_remote_runtime_config(request: Request): """Retrieve the runtime configuration. @@ -126,9 +141,11 @@ async def search_events( ) # Get matching events from the stream event_stream = request.state.conversation.event_stream + + cast_event_type = str_to_event_type(event_type) matching_events = event_stream.get_matching_events( query=query, - event_type=event_type, + event_type=cast_event_type, source=source, start_date=start_date, end_date=end_date, diff --git a/tests/unit/test_event_stream.py b/tests/unit/test_event_stream.py index 36d51e78e7..d9ce963bf6 100644 --- a/tests/unit/test_event_stream.py +++ b/tests/unit/test_event_stream.py @@ -74,12 +74,12 @@ def test_get_matching_events_type_filter(temp_dir: str): event_stream.add_event(NullAction(), EventSource.AGENT) # Filter by NullAction - events = event_stream.get_matching_events(event_type='NullAction') + events = event_stream.get_matching_events(event_type=NullAction) assert len(events) == 2 assert all(e['action'] == 'null' for e in events) # Filter by NullObservation - events = event_stream.get_matching_events(event_type='NullObservation') + events = event_stream.get_matching_events(event_type=NullObservation) assert len(events) == 1 assert events[0]['observation'] == 'null'