mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor: Use type[Event] instead of str to filter events (#6480)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -319,7 +319,7 @@ class EventStream:
|
||||
self,
|
||||
event,
|
||||
query: str | None = None,
|
||||
event_type: str | None = None,
|
||||
event_type: type[Event] | None = None,
|
||||
source: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
@@ -329,7 +329,7 @@ class EventStream:
|
||||
Args:
|
||||
event: The event to check
|
||||
query (str, optional): Text to search for in event content
|
||||
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
|
||||
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
|
||||
source (str, optional): Filter by event source
|
||||
start_date (str, optional): Filter events after this date (ISO format)
|
||||
end_date (str, optional): Filter events before this date (ISO format)
|
||||
@@ -337,7 +337,7 @@ class EventStream:
|
||||
Returns:
|
||||
bool: True if the event should be filtered out, False if it matches all criteria
|
||||
"""
|
||||
if event_type and not event.__class__.__name__ == event_type:
|
||||
if event_type and not isinstance(event, event_type):
|
||||
return True
|
||||
|
||||
if source and not event.source.value == source:
|
||||
@@ -361,7 +361,7 @@ class EventStream:
|
||||
def get_matching_events(
|
||||
self,
|
||||
query: str | None = None,
|
||||
event_type: str | None = None,
|
||||
event_type: type[Event] | None = None,
|
||||
source: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
@@ -372,7 +372,7 @@ class EventStream:
|
||||
|
||||
Args:
|
||||
query (str, optional): Text to search for in event content
|
||||
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
|
||||
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
|
||||
source (str, optional): Filter by event source
|
||||
start_date (str, optional): Filter events after this date (ISO format)
|
||||
end_date (str, optional): Filter events before this date (ISO format)
|
||||
|
||||
@@ -2,11 +2,26 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
def str_to_event_type(event: str | None) -> Event | None:
|
||||
if not event:
|
||||
return None
|
||||
|
||||
for event_type in ['observation', 'action']:
|
||||
try:
|
||||
return event_from_dict({event_type: event})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@app.get('/config')
|
||||
async def get_remote_runtime_config(request: Request):
|
||||
"""Retrieve the runtime configuration.
|
||||
@@ -126,9 +141,11 @@ async def search_events(
|
||||
)
|
||||
# Get matching events from the stream
|
||||
event_stream = request.state.conversation.event_stream
|
||||
|
||||
cast_event_type = str_to_event_type(event_type)
|
||||
matching_events = event_stream.get_matching_events(
|
||||
query=query,
|
||||
event_type=event_type,
|
||||
event_type=cast_event_type,
|
||||
source=source,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
|
||||
@@ -74,12 +74,12 @@ def test_get_matching_events_type_filter(temp_dir: str):
|
||||
event_stream.add_event(NullAction(), EventSource.AGENT)
|
||||
|
||||
# Filter by NullAction
|
||||
events = event_stream.get_matching_events(event_type='NullAction')
|
||||
events = event_stream.get_matching_events(event_type=NullAction)
|
||||
assert len(events) == 2
|
||||
assert all(e['action'] == 'null' for e in events)
|
||||
|
||||
# Filter by NullObservation
|
||||
events = event_stream.get_matching_events(event_type='NullObservation')
|
||||
events = event_stream.get_matching_events(event_type=NullObservation)
|
||||
assert len(events) == 1
|
||||
assert events[0]['observation'] == 'null'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user