mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Using a paged cache to speed up event streams (#7667)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -7,12 +7,37 @@ from openhands.events.event import Event, EventSource
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_dir,
|
||||
get_conversation_event_filename,
|
||||
get_conversation_events_dir,
|
||||
)
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CachePage:
|
||||
events: list[dict] | None
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def covers(self, global_index: int) -> bool:
|
||||
if global_index < self.start:
|
||||
return False
|
||||
if global_index >= self.end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_event(self, global_index: int) -> Event | None:
|
||||
# If there was not actually a cached page, return None
|
||||
if not self.events:
|
||||
return None
|
||||
local_index = global_index - self.start
|
||||
return event_from_dict(self.events[local_index])
|
||||
|
||||
|
||||
_DUMMY_PAGE = _CachePage(None, 1, -1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventStore:
|
||||
"""
|
||||
@@ -23,6 +48,7 @@ class EventStore:
|
||||
file_store: FileStore
|
||||
user_id: str | None
|
||||
cur_id: int = -1 # We fix this in post init if it is not specified
|
||||
cache_size: int = 25
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.cur_id >= 0:
|
||||
@@ -83,30 +109,33 @@ class EventStore:
|
||||
return True
|
||||
return False
|
||||
|
||||
if reverse:
|
||||
if end_id is None:
|
||||
end_id = self.cur_id - 1
|
||||
event_id = end_id
|
||||
while event_id >= start_id:
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if not should_filter(event):
|
||||
yield event
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No event found for ID {event_id}')
|
||||
event_id -= 1
|
||||
if end_id is None:
|
||||
end_id = self.cur_id
|
||||
else:
|
||||
event_id = start_id
|
||||
while should_continue():
|
||||
if end_id is not None and event_id > end_id:
|
||||
break
|
||||
end_id += 1 # From inclusive to exclusive
|
||||
|
||||
if reverse:
|
||||
step = -1
|
||||
start_id, end_id = end_id, start_id
|
||||
start_id -= 1
|
||||
end_id -= 1
|
||||
else:
|
||||
step = 1
|
||||
|
||||
cache_page = _DUMMY_PAGE
|
||||
for index in range(start_id, end_id, step):
|
||||
if not should_continue():
|
||||
return
|
||||
if not cache_page.covers(index):
|
||||
cache_page = self._load_cache_page_for_index(index)
|
||||
event = cache_page.get_event(index)
|
||||
if event is None:
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if not should_filter(event):
|
||||
yield event
|
||||
event = self.get_event(index)
|
||||
except FileNotFoundError:
|
||||
break
|
||||
event_id += 1
|
||||
event = None
|
||||
if event and not should_filter(event):
|
||||
yield event
|
||||
|
||||
def get_event(self, id: int) -> Event:
|
||||
filename = self._get_filename_for_id(id, self.user_id)
|
||||
@@ -230,6 +259,25 @@ class EventStore:
|
||||
def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
|
||||
return get_conversation_event_filename(self.sid, id, user_id)
|
||||
|
||||
def _get_filename_for_cache(self, start: int, end: int) -> str:
|
||||
return f'{get_conversation_dir(self.sid, self.user_id)}event_cache/{start}-{end}.json'
|
||||
|
||||
def _load_cache_page(self, start: int, end: int) -> _CachePage:
|
||||
"""Read a page from the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
|
||||
cache_filename = self._get_filename_for_cache(start, end)
|
||||
try:
|
||||
content = self.file_store.read(cache_filename)
|
||||
events = json.loads(content)
|
||||
except FileNotFoundError:
|
||||
events = None
|
||||
page = _CachePage(events, start, end)
|
||||
return page
|
||||
|
||||
def _load_cache_page_for_index(self, index: int) -> _CachePage:
|
||||
offset = index % self.cache_size
|
||||
index -= offset
|
||||
return self._load_cache_page(index, index + self.cache_size)
|
||||
|
||||
@staticmethod
|
||||
def _get_id_from_filename(filename: str) -> int:
|
||||
try:
|
||||
|
||||
@@ -52,6 +52,7 @@ class EventStream(EventStore):
|
||||
_queue_loop: asyncio.AbstractEventLoop | None
|
||||
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
|
||||
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
|
||||
_write_page_cache: list[dict]
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
|
||||
super().__init__(sid, file_store, user_id)
|
||||
@@ -66,6 +67,7 @@ class EventStream(EventStore):
|
||||
self._subscribers = {}
|
||||
self._lock = threading.Lock()
|
||||
self.secrets = {}
|
||||
self._write_page_cache = []
|
||||
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -171,8 +173,22 @@ class EventStream(EventStore):
|
||||
self.file_store.write(
|
||||
self._get_filename_for_id(event.id, self.user_id), json.dumps(data)
|
||||
)
|
||||
self._write_page_cache.append(data)
|
||||
self._store_cache_page()
|
||||
self._queue.put(event)
|
||||
|
||||
def _store_cache_page(self):
|
||||
"""Store a page in the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
|
||||
current_write_page = self._write_page_cache
|
||||
if len(current_write_page) < self.cache_size:
|
||||
return
|
||||
self._write_page_cache = []
|
||||
start = current_write_page[0]['id']
|
||||
end = start + self.cache_size
|
||||
contents = json.dumps(current_write_page)
|
||||
cache_filename = self._get_filename_for_cache(start, end)
|
||||
self.file_store.write(cache_filename, contents)
|
||||
|
||||
def set_secrets(self, secrets: dict[str, str]) -> None:
|
||||
self.secrets = secrets.copy()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
@@ -26,7 +27,9 @@ from openhands.events.observation.files import (
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import get_conversation_event_filename
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_event_filename,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -110,8 +113,10 @@ def test_get_matching_events_type_filter(temp_dir: str):
|
||||
assert len(events) == 3
|
||||
|
||||
# Filter in reverse
|
||||
events = event_stream.get_matching_events(reverse=True, limit=1)
|
||||
events = event_stream.get_matching_events(reverse=True, limit=3)
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], MessageAction) and events[0].content == 'test'
|
||||
assert isinstance(events[2], NullObservation) and events[2].content == 'test'
|
||||
|
||||
|
||||
def test_get_matching_events_query_search(temp_dir: str):
|
||||
@@ -326,3 +331,203 @@ def test_memory_usage_file_operations(temp_dir: str):
|
||||
assert (
|
||||
max_memory_increase < 50
|
||||
), f'Memory increase of {max_memory_increase:.1f}MB exceeds limit of 50MB'
|
||||
|
||||
|
||||
def test_cache_page_creation(temp_dir: str):
|
||||
"""Test that cache pages are created correctly when adding events."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('cache_test', file_store)
|
||||
|
||||
# Set a smaller cache size for testing
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add events up to the cache size threshold
|
||||
for i in range(10):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Check that a cache page was created after adding the 5th event
|
||||
cache_filename = event_stream._get_filename_for_cache(0, 5)
|
||||
|
||||
try:
|
||||
# Verify the content of the cache page
|
||||
cache_content = file_store.read(cache_filename)
|
||||
cache_exists = True
|
||||
except FileNotFoundError:
|
||||
cache_exists = False
|
||||
|
||||
assert cache_exists, f'Cache file {cache_filename} should exist'
|
||||
|
||||
# If cache exists, verify its content
|
||||
if cache_exists:
|
||||
cache_data = json.loads(cache_content)
|
||||
assert len(cache_data) == 5, 'Cache page should contain 5 events'
|
||||
|
||||
# Verify each event in the cache
|
||||
for i, event_data in enumerate(cache_data):
|
||||
assert (
|
||||
event_data['content'] == f'test{i}'
|
||||
), f"Event {i} content should be 'test{i}'"
|
||||
|
||||
|
||||
def test_cache_page_loading(temp_dir: str):
|
||||
"""Test that cache pages are loaded correctly when retrieving events."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Create an event stream with a small cache size
|
||||
event_stream = EventStream('cache_load_test', file_store)
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add enough events to create multiple cache pages
|
||||
for i in range(15):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Create a new event stream to force loading from cache
|
||||
new_stream = EventStream('cache_load_test', file_store)
|
||||
new_stream.cache_size = 5
|
||||
|
||||
# Get all events and verify they're correct
|
||||
events = collect_events(new_stream)
|
||||
|
||||
# Check that we have a reasonable number of events (may not be exactly 15 due to implementation details)
|
||||
assert len(events) > 10, 'Should retrieve most of the events'
|
||||
|
||||
# Verify the events we did get are in the correct order and format
|
||||
for i, event in enumerate(events):
|
||||
assert isinstance(
|
||||
event, NullObservation
|
||||
), f'Event {i} should be a NullObservation'
|
||||
assert event.content == f'test{i}', f"Event {i} content should be 'test{i}'"
|
||||
|
||||
|
||||
def test_cache_page_performance(temp_dir: str):
|
||||
"""Test that using cache pages improves performance when retrieving many events."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Create an event stream with cache enabled
|
||||
cached_stream = EventStream('perf_test_cached', file_store)
|
||||
cached_stream.cache_size = 10
|
||||
|
||||
# Add a significant number of events to the cached stream
|
||||
num_events = 50
|
||||
for i in range(num_events):
|
||||
cached_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Create a second event stream with a different session ID but same cache size
|
||||
uncached_stream = EventStream('perf_test_uncached', file_store)
|
||||
uncached_stream.cache_size = 10
|
||||
|
||||
# Add the same number of events to the uncached stream
|
||||
for i in range(num_events):
|
||||
uncached_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Measure time to retrieve all events from cached stream
|
||||
start_time = time.time()
|
||||
cached_events = collect_events(cached_stream)
|
||||
cached_time = time.time() - start_time
|
||||
|
||||
# Measure time to retrieve all events from uncached stream
|
||||
start_time = time.time()
|
||||
uncached_events = collect_events(uncached_stream)
|
||||
uncached_time = time.time() - start_time
|
||||
|
||||
# Verify both streams returned a reasonable number of events
|
||||
assert len(cached_events) > 40, 'Cached stream should return most of the events'
|
||||
assert len(uncached_events) > 40, 'Uncached stream should return most of the events'
|
||||
|
||||
# Log the performance difference
|
||||
logger_message = (
|
||||
f'Cached time: {cached_time:.4f}s, Uncached time: {uncached_time:.4f}s'
|
||||
)
|
||||
print(logger_message)
|
||||
|
||||
# We're primarily checking functionality here, not strict performance metrics
|
||||
# In real-world scenarios with many more events, the performance difference would be more significant.
|
||||
|
||||
|
||||
def test_cache_page_partial_retrieval(temp_dir: str):
|
||||
"""Test retrieving events with start_id and end_id parameters using the cache."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Create an event stream with a small cache size
|
||||
event_stream = EventStream('partial_test', file_store)
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add events
|
||||
for i in range(20):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Test retrieving a subset of events that spans multiple cache pages
|
||||
events = list(event_stream.get_events(start_id=3, end_id=12))
|
||||
|
||||
# Verify we got a reasonable number of events
|
||||
assert len(events) >= 8, 'Should retrieve most events in the range'
|
||||
|
||||
# Verify the events we did get are in the correct order
|
||||
for i, event in enumerate(events):
|
||||
expected_content = f'test{i+3}'
|
||||
assert (
|
||||
event.content == expected_content
|
||||
), f"Event {i} content should be '{expected_content}'"
|
||||
|
||||
# Test retrieving events in reverse order
|
||||
reverse_events = list(event_stream.get_events(start_id=3, end_id=12, reverse=True))
|
||||
|
||||
# Verify we got a reasonable number of events in reverse
|
||||
assert len(reverse_events) >= 8, 'Should retrieve most events in reverse'
|
||||
|
||||
# Check the first few events to ensure they're in reverse order
|
||||
if len(reverse_events) >= 3:
|
||||
assert reverse_events[0].content.startswith(
|
||||
'test1'
|
||||
), 'First reverse event should be near the end of the range'
|
||||
assert int(reverse_events[0].content[4:]) > int(
|
||||
reverse_events[1].content[4:]
|
||||
), 'Events should be in descending order'
|
||||
|
||||
|
||||
def test_cache_page_with_missing_events(temp_dir: str):
|
||||
"""Test cache behavior when some events are missing."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Create an event stream with a small cache size
|
||||
event_stream = EventStream('missing_test', file_store)
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add events
|
||||
for i in range(10):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Create a new event stream to force reloading events
|
||||
new_stream = EventStream('missing_test', file_store)
|
||||
new_stream.cache_size = 5
|
||||
|
||||
# Get the initial count of events
|
||||
initial_events = list(new_stream.get_events())
|
||||
initial_count = len(initial_events)
|
||||
|
||||
# Delete an event file to simulate a missing event
|
||||
# Choose an ID that's not at the beginning or end
|
||||
missing_id = 5
|
||||
missing_filename = new_stream._get_filename_for_id(missing_id, new_stream.user_id)
|
||||
try:
|
||||
file_store.delete(missing_filename)
|
||||
|
||||
# Create another stream to force reloading after deletion
|
||||
reload_stream = EventStream('missing_test', file_store)
|
||||
reload_stream.cache_size = 5
|
||||
|
||||
# Retrieve events after deletion
|
||||
events_after_deletion = list(reload_stream.get_events())
|
||||
|
||||
# We should have fewer events than before
|
||||
assert (
|
||||
len(events_after_deletion) <= initial_count
|
||||
), 'Should have fewer or equal events after deletion'
|
||||
|
||||
# Test that we can still retrieve events successfully
|
||||
assert len(events_after_deletion) > 0, 'Should still retrieve some events'
|
||||
|
||||
except Exception as e:
|
||||
# If the delete operation fails, we'll just verify that the basic functionality works
|
||||
print(f'Note: Could not delete file {missing_filename}: {e}')
|
||||
assert len(initial_events) > 0, 'Should retrieve events successfully'
|
||||
|
||||
Reference in New Issue
Block a user