Refactor: Use type[Event] instead of str to filter events (#6480)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2025-01-27 13:58:09 -05:00
committed by GitHub
parent 4bde644fab
commit 604534905f
3 changed files with 25 additions and 8 deletions

View File

@@ -319,7 +319,7 @@ class EventStream:
self,
event,
query: str | None = None,
event_type: str | None = None,
event_type: type[Event] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
@@ -329,7 +329,7 @@ class EventStream:
Args:
event: The event to check
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
@@ -337,7 +337,7 @@ class EventStream:
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_type and not event.__class__.__name__ == event_type:
if event_type and not isinstance(event, event_type):
return True
if source and not event.source.value == source:
@@ -361,7 +361,7 @@ class EventStream:
def get_matching_events(
self,
query: str | None = None,
event_type: str | None = None,
event_type: type[Event] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
@@ -372,7 +372,7 @@ class EventStream:
Args:
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)

View File

@@ -2,11 +2,26 @@ from fastapi import APIRouter, HTTPException, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_from_dict
from openhands.runtime.base import Runtime
app = APIRouter(prefix='/api/conversations/{conversation_id}')
def str_to_event_type(event: str | None) -> Event | None:
if not event:
return None
for event_type in ['observation', 'action']:
try:
return event_from_dict({event_type: event})
except Exception:
continue
return None
@app.get('/config')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.
@@ -126,9 +141,11 @@ async def search_events(
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream
cast_event_type = str_to_event_type(event_type)
matching_events = event_stream.get_matching_events(
query=query,
event_type=event_type,
event_type=cast_event_type,
source=source,
start_date=start_date,
end_date=end_date,

View File

@@ -74,12 +74,12 @@ def test_get_matching_events_type_filter(temp_dir: str):
event_stream.add_event(NullAction(), EventSource.AGENT)
# Filter by NullAction
events = event_stream.get_matching_events(event_type='NullAction')
events = event_stream.get_matching_events(event_type=NullAction)
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
# Filter by NullObservation
events = event_stream.get_matching_events(event_type='NullObservation')
events = event_stream.get_matching_events(event_type=NullObservation)
assert len(events) == 1
assert events[0]['observation'] == 'null'