Separate event store from event stream (#7592)

This commit is contained in:
tofarr 2025-04-02 10:05:59 -06:00 committed by GitHub
parent 5524fe1408
commit f14a0ea011
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 300 additions and 270 deletions

View File

@ -0,0 +1,23 @@
import asyncio
from typing import Any, AsyncIterator
from openhands.events.event import Event
from openhands.events.event_store import EventStore
class AsyncEventStoreWrapper:
def __init__(self, event_store: EventStore, *args: Any, **kwargs: Any) -> None:
self.event_store = event_store
self.args = args
self.kwargs = kwargs
async def __aiter__(self) -> AsyncIterator[Event]:
loop = asyncio.get_running_loop()
# Create an async generator that yields events
for event in self.event_store.get_events(*self.args, **self.kwargs):
# Run the blocking get_events() in a thread pool
def get_event(e: Event = event) -> Event:
return e
yield await loop.run_in_executor(None, get_event)

View File

@ -0,0 +1,239 @@
import json
from dataclasses import dataclass
from typing import Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.storage.files import FileStore
from openhands.storage.locations import (
get_conversation_event_filename,
get_conversation_events_dir,
)
from openhands.utils.shutdown_listener import should_continue
@dataclass
class EventStore:
"""
A stored list of events backing a conversation
"""
sid: str
file_store: FileStore
user_id: str | None
cur_id: int = -1 # We fix this in post init if it is not specified
def __post_init__(self) -> None:
if self.cur_id >= 0:
return
events = []
try:
events_dir = get_conversation_events_dir(self.sid, self.user_id)
events = self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if self.user_id:
# During transition to new location, try old location if user_id is set
# TODO: remove this code after 5/1/2025
try:
events_dir = get_conversation_events_dir(self.sid)
events += self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if not events:
self.cur_id = 0
return
# if we have events, we need to find the highest id to prepare for new events
for event_str in events:
id = self._get_id_from_filename(event_str)
if id >= self.cur_id:
self.cur_id = id + 1
def get_events(
self,
start_id: int = 0,
end_id: int | None = None,
reverse: bool = False,
filter_out_type: tuple[type[Event], ...] | None = None,
filter_hidden: bool = False,
) -> Iterable[Event]:
"""
Retrieve events from the event stream, optionally filtering out events of a given type
and events marked as hidden.
Args:
start_id: The ID of the first event to retrieve. Defaults to 0.
end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
reverse: Whether to retrieve events in reverse order. Defaults to False.
filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
Yields:
Events from the stream that match the criteria.
"""
def should_filter(event: Event) -> bool:
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
return True
if filter_out_type is not None and isinstance(event, filter_out_type):
return True
return False
if reverse:
if end_id is None:
end_id = self.cur_id - 1
event_id = end_id
while event_id >= start_id:
try:
event = self.get_event(event_id)
if not should_filter(event):
yield event
except FileNotFoundError:
logger.debug(f'No event found for ID {event_id}')
event_id -= 1
else:
event_id = start_id
while should_continue():
if end_id is not None and event_id > end_id:
break
try:
event = self.get_event(event_id)
if not should_filter(event):
yield event
except FileNotFoundError:
break
event_id += 1
def get_event(self, id: int) -> Event:
filename = self._get_filename_for_id(id, self.user_id)
try:
content = self.file_store.read(filename)
data = json.loads(content)
return event_from_dict(data)
except FileNotFoundError:
logger.debug(f'File {filename} not found')
# TODO remove this block after 5/1/2025
if self.user_id:
filename = self._get_filename_for_id(id, None)
content = self.file_store.read(filename)
data = json.loads(content)
return event_from_dict(data)
raise
def get_latest_event(self) -> Event:
return self.get_event(self.cur_id - 1)
def get_latest_event_id(self) -> int:
return self.cur_id - 1
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
for event in self.get_events():
if event.source == source:
yield event
def _should_filter_event(
self,
event: Event,
query: str | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
) -> bool:
"""Check if an event should be filtered out based on the given criteria.
Args:
event: The event to check
query: Text to search for in event content
event_type: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_types and not isinstance(event, event_types):
return True
if source:
if event.source is None or event.source.value != source:
return True
if start_date and event.timestamp is not None and event.timestamp < start_date:
return True
if end_date and event.timestamp is not None and event.timestamp > end_date:
return True
# Text search in event content if query provided
if query:
event_dict = event_to_dict(event)
event_str = json.dumps(event_dict).lower()
if query.lower() not in event_str:
return True
return False
def get_matching_events(
self,
query: str | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list[Event]:
"""Get matching events from the event stream based on filters.
Args:
query: Text to search for in event content
event_types: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
start_id: Starting ID in the event stream. Defaults to 0
limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100
reverse: Whether to retrieve events in reverse order. Defaults to False.
Returns:
list: List of matching events (as dicts)
Raises:
ValueError: If limit is less than 1 or greater than 100
"""
if limit < 1 or limit > 100:
raise ValueError('Limit must be between 1 and 100')
matching_events: list = []
for event in self.get_events(start_id=start_id, reverse=reverse):
if self._should_filter_event(
event, query, event_types, source, start_date, end_date
):
continue
matching_events.append(event)
# Stop if we have enough events
if len(matching_events) >= limit:
break
return matching_events
def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
return get_conversation_event_filename(self.sid, id, user_id)
@staticmethod
def _get_id_from_filename(filename: str) -> int:
try:
return int(filename.split('/')[-1].split('.')[0])
except ValueError:
logger.warning(f'get id from filename ({filename}) failed.')
return -1

View File

@ -5,17 +5,16 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterable
from typing import Any, Callable
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event, EventSource
from openhands.events.event_store import EventStore
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.io import json
from openhands.storage import FileStore
from openhands.storage.locations import (
get_conversation_dir,
get_conversation_event_filename,
get_conversation_events_dir,
)
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
@ -42,33 +41,11 @@ async def session_exists(
return False
class AsyncEventStreamWrapper:
def __init__(self, event_stream: 'EventStream', *args: Any, **kwargs: Any) -> None:
self.event_stream = event_stream
self.args = args
self.kwargs = kwargs
async def __aiter__(self) -> AsyncIterator[Event]:
loop = asyncio.get_running_loop()
# Create an async generator that yields events
for event in self.event_stream.get_events(*self.args, **self.kwargs):
# Run the blocking get_events() in a thread pool
def get_event(e: Event = event) -> Event:
return e
yield await loop.run_in_executor(None, get_event)
class EventStream:
sid: str
user_id: str | None
file_store: FileStore
class EventStream(EventStore):
secrets: dict[str, str]
# For each subscriber ID, there is a map of callback functions - useful
# when there are multiple listeners
_subscribers: dict[str, dict[str, Callable]]
_cur_id: int = 0
_lock: threading.Lock
_queue: queue.Queue[Event]
_queue_thread: threading.Thread
@ -77,9 +54,7 @@ class EventStream:
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
self.sid = sid
self.file_store = file_store
self.user_id = user_id
super().__init__(sid, file_store, user_id)
self._stop_flag = threading.Event()
self._queue: queue.Queue[Event] = queue.Queue()
self._thread_pools = {}
@ -90,40 +65,8 @@ class EventStream:
self._queue_thread.start()
self._subscribers = {}
self._lock = threading.Lock()
self._cur_id = 0
self.secrets = {}
# load the stream
self.__post_init__()
def __post_init__(self) -> None:
events = []
try:
events_dir = get_conversation_events_dir(self.sid, self.user_id)
events += self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if self.user_id:
# During transition to new location, try old location if user_id is set
# TODO: remove this code after 5/1/2025
try:
events_dir = get_conversation_events_dir(self.sid)
events += self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if not events:
self._cur_id = 0
return
# if we have events, we need to find the highest id to prepare for new events
for event_str in events:
id = self._get_id_from_filename(event_str)
if id >= self._cur_id:
self._cur_id = id + 1
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -177,94 +120,6 @@ class EventStream:
del self._subscribers[subscriber_id][callback_id]
def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
return get_conversation_event_filename(self.sid, id, user_id)
@staticmethod
def _get_id_from_filename(filename: str) -> int:
try:
return int(filename.split('/')[-1].split('.')[0])
except ValueError:
logger.warning(f'get id from filename ({filename}) failed.')
return -1
def get_events(
self,
start_id: int = 0,
end_id: int | None = None,
reverse: bool = False,
filter_out_type: tuple[type[Event], ...] | None = None,
filter_hidden: bool = False,
) -> Iterable[Event]:
"""
Retrieve events from the event stream, optionally filtering out events of a given type
and events marked as hidden.
Args:
start_id: The ID of the first event to retrieve. Defaults to 0.
end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
reverse: Whether to retrieve events in reverse order. Defaults to False.
filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
Yields:
Events from the stream that match the criteria.
"""
def should_filter(event: Event) -> bool:
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
return True
if filter_out_type is not None and isinstance(event, filter_out_type):
return True
return False
if reverse:
if end_id is None:
end_id = self._cur_id - 1
event_id = end_id
while event_id >= start_id:
try:
event = self.get_event(event_id)
if not should_filter(event):
yield event
except FileNotFoundError:
logger.debug(f'No event found for ID {event_id}')
event_id -= 1
else:
event_id = start_id
while should_continue():
if end_id is not None and event_id > end_id:
break
try:
event = self.get_event(event_id)
if not should_filter(event):
yield event
except FileNotFoundError:
break
event_id += 1
def get_event(self, id: int) -> Event:
filename = self._get_filename_for_id(id, self.user_id)
try:
content = self.file_store.read(filename)
data = json.loads(content)
return event_from_dict(data)
except FileNotFoundError:
logger.debug(f'File {filename} not found')
# TODO remove this block after 5/1/2025
if self.user_id:
filename = self._get_filename_for_id(id, None)
content = self.file_store.read(filename)
data = json.loads(content)
return event_from_dict(data)
raise
def get_latest_event(self) -> Event:
return self.get_event(self._cur_id - 1)
def get_latest_event_id(self) -> int:
return self._cur_id - 1
def subscribe(
self,
subscriber_id: EventStreamSubscriber,
@ -304,8 +159,8 @@ class EventStream:
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
)
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1
event._id = self.cur_id # type: ignore [attr-defined]
self.cur_id += 1
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
event._timestamp = datetime.now().isoformat()
event._source = source # type: ignore [attr-defined]
@ -373,100 +228,3 @@ class EventStream:
raise e
return _handle_callback_error
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
for event in self.get_events():
if event.source == source:
yield event
def _should_filter_event(
self,
event: Event,
query: str | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
) -> bool:
"""Check if an event should be filtered out based on the given criteria.
Args:
event: The event to check
query: Text to search for in event content
event_type: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_types and not isinstance(event, event_types):
return True
if source:
if event.source is None or event.source.value != source:
return True
if start_date and event.timestamp is not None and event.timestamp < start_date:
return True
if end_date and event.timestamp is not None and event.timestamp > end_date:
return True
# Text search in event content if query provided
if query:
event_dict = event_to_dict(event)
event_str = json.dumps(event_dict).lower()
if query.lower() not in event_str:
return True
return False
def get_matching_events(
self,
query: str | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list[Event]:
"""Get matching events from the event stream based on filters.
Args:
query: Text to search for in event content
event_types: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
start_id: Starting ID in the event stream. Defaults to 0
limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100
reverse: Whether to retrieve events in reverse order. Defaults to False.
Returns:
list: List of matching events (as dicts)
Raises:
ValueError: If limit is less than 1 or greater than 100
"""
if limit < 1 or limit > 100:
raise ValueError('Limit must be between 1 and 100')
matching_events: list = []
for event in self.get_events(start_id=start_id, reverse=reverse):
if self._should_filter_event(
event, query, event_types, source, start_date, end_date
):
continue
matching_events.append(event)
# Stop if we have enough events
if len(matching_events) >= limit:
break
return matching_events

View File

@ -6,7 +6,7 @@ import socketio
from openhands.core.config import AppConfig
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream
from openhands.events.event_store import EventStore
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.conversation import Conversation
@ -54,7 +54,7 @@ class ConversationManager(ABC):
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStream | None:
) -> EventStore | None:
"""Join a conversation and return its event stream."""
async def is_agent_loop_running(self, sid: str) -> bool:
@ -83,7 +83,7 @@ class ConversationManager(ABC):
initial_user_msg: MessageAction | None = None,
replay_json: str | None = None,
github_user_id: str | None = None,
) -> EventStream:
) -> EventStore:
"""Start an event loop if one is not already running"""
@abstractmethod

View File

@ -11,7 +11,8 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists
from openhands.events.event_store import EventStore
from openhands.events.stream import EventStreamSubscriber, session_exists
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
@ -115,7 +116,7 @@ class StandaloneConversationManager(ConversationManager):
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStream:
) -> EventStore:
logger.info(
f'join_conversation:{sid}:{connection_id}',
extra={'session_id': sid, 'user_id': user_id},
@ -254,7 +255,7 @@ class StandaloneConversationManager(ConversationManager):
initial_user_msg: MessageAction | None = None,
replay_json: str | None = None,
github_user_id: str | None = None,
) -> EventStream:
) -> EventStore:
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
session: Session | None = None
if not await self.is_agent_loop_running(sid):
@ -300,23 +301,29 @@ class StandaloneConversationManager(ConversationManager):
except ValueError:
pass # Already subscribed - take no action
event_stream = await self._get_event_stream(sid, user_id)
if not event_stream:
event_store = await self._get_event_store(sid, user_id)
if not event_store:
logger.error(
f'No event stream after starting agent loop: {sid}',
extra={'session_id': sid},
)
raise RuntimeError(f'no_event_stream:{sid}')
return event_stream
return event_store
async def _get_event_stream(
async def _get_event_store(
self, sid: str, user_id: str | None
) -> EventStream | None:
logger.info(f'_get_event_stream:{sid}', extra={'session_id': sid})
) -> EventStore | None:
logger.info(f'_get_event_store:{sid}', extra={'session_id': sid})
session = self._local_agent_loops_by_sid.get(sid)
if session:
logger.info(f'found_local_agent_loop:{sid}', extra={'session_id': sid})
return session.agent_session.event_stream
event_stream = session.agent_session.event_stream
return EventStore(
event_stream.sid,
event_stream.file_store,
event_stream.user_id,
event_stream.cur_id,
)
return None
async def send_to_event_stream(self, connection_id: str, data: dict):

View File

@ -7,6 +7,7 @@ from openhands.events.action import (
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.observation import (
NullObservation,
)
@ -15,7 +16,6 @@ from openhands.events.observation.agent import (
RecallObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.shared import (
SettingsStoreImpl,
config,
@ -60,8 +60,8 @@ async def connect(connection_id: str, environ):
agent_state_changed = None
if event_stream is None:
raise ConnectionRefusedError('Failed to join conversation')
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
async for event in async_stream:
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
async for event in async_store:
logger.info(f'oh_event: {event.__class__.__name__}')
if isinstance(
event,

View File

@ -2,8 +2,8 @@ from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.utils.async_utils import call_sync_from_async
@ -34,11 +34,11 @@ async def submit_feedback(request: Request, conversation_id: str) -> JSONRespons
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
async_stream = AsyncEventStreamWrapper(
async_store = AsyncEventStoreWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_stream:
async for event in async_store:
trajectory.append(event_to_dict(event))
feedback = FeedbackDataModel(
email=body.get('email', ''),

View File

@ -2,8 +2,8 @@ from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_trajectory
from openhands.events.stream import AsyncEventStreamWrapper
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@ -22,11 +22,11 @@ async def get_trajectory(request: Request) -> JSONResponse:
events.
"""
try:
async_stream = AsyncEventStreamWrapper(
async_store = AsyncEventStoreWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_stream:
async for event in async_store:
trajectory.append(event_to_trajectory(event))
return JSONResponse(
status_code=status.HTTP_200_OK, content={'trajectory': trajectory}

View File

@ -39,6 +39,7 @@ def get_mock_sio(get_message: GetMessageMock | None = None):
async def test_init_new_local_session():
session_instance = AsyncMock()
session_instance.agent_session = MagicMock()
session_instance.agent_session.event_stream.cur_id = 1
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
@ -85,6 +86,7 @@ async def test_join_local_session():
session_instance.agent_session = MagicMock()
mock_session = MagicMock()
mock_session.return_value = session_instance
session_instance.agent_session.event_stream.cur_id = 1
sio = get_mock_sio()
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
@ -136,6 +138,7 @@ async def test_add_to_local_event_stream():
session_instance.agent_session = MagicMock()
mock_session = MagicMock()
mock_session.return_value = session_instance
session_instance.agent_session.event_stream.cur_id = 1
sio = get_mock_sio()
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()