Optimize memory usage in FileEditObservation (#6622)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
Graham Neubig 2025-02-07 08:19:32 -05:00 committed by GitHub
parent ff48f8beba
commit 93d2e4a338
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 324 additions and 32 deletions

View File

@ -1,3 +1,5 @@
"""File-related observation classes for tracking file operations."""
from dataclasses import dataclass
from difflib import SequenceMatcher
@ -16,32 +18,40 @@ class FileReadObservation(Observation):
@property
def message(self) -> str:
"""Get a human-readable message describing the file read operation."""
return f'I read the file {self.path}.'
def __str__(self) -> str:
return f'[Read from {self.path} is successful.]\n' f'{self.content}'
"""Get a string representation of the file read observation."""
return f'[Read from {self.path} is successful.]\n{self.content}'
@dataclass
class FileWriteObservation(Observation):
"""This data class represents a file write operation"""
"""This data class represents a file write operation."""
path: str
observation: str = ObservationType.WRITE
@property
def message(self) -> str:
"""Get a human-readable message describing the file write operation."""
return f'I wrote to the file {self.path}.'
def __str__(self) -> str:
return f'[Write to {self.path} is successful.]\n' f'{self.content}'
"""Get a string representation of the file write observation."""
return f'[Write to {self.path} is successful.]\n{self.content}'
@dataclass
class FileEditObservation(Observation):
"""This data class represents a file edit operation"""
"""This data class represents a file edit operation.
The observation includes both the old and new content of the file, and can
generate a diff visualization showing the changes. The diff is computed lazily
and cached to improve performance.
"""
# content: str will be a unified diff patch string include NO context lines
path: str
prev_exist: bool
old_content: str
@ -49,22 +59,31 @@ class FileEditObservation(Observation):
observation: str = ObservationType.EDIT
impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT
formatted_output_and_error: str = ''
_diff_cache: str | None = None # Cache for the diff visualization
@property
def message(self) -> str:
"""Get a human-readable message describing the file edit operation."""
return f'I edited the file {self.path}.'
def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
"""Get the edit groups of the file edit."""
"""Get the edit groups showing changes between old and new content.
Args:
n_context_lines: Number of context lines to show around each change.
Returns:
A list of edit groups, where each group contains before/after edits.
"""
old_lines = self.old_content.split('\n')
new_lines = self.new_content.split('\n')
# Borrowed from difflib.unified_diff to directly parse into structured format.
# Borrowed from difflib.unified_diff to directly parse into structured format
edit_groups: list[dict] = []
for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
n_context_lines
):
# take the max line number in the group
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix
# Take the max line number in the group
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for "*" prefix
cur_group: dict[str, list[str]] = {
'before_edits': [],
'after_edits': [],
@ -72,23 +91,27 @@ class FileEditObservation(Observation):
for tag, i1, i2, j1, j2 in group:
if tag == 'equal':
for idx, line in enumerate(old_lines[i1:i2]):
line_num = i1 + idx + 1
cur_group['before_edits'].append(
f'{i1+idx+1:>{_indent_pad_size}}|{line}'
f'{line_num:>{_indent_pad_size}}|{line}'
)
for idx, line in enumerate(new_lines[j1:j2]):
line_num = j1 + idx + 1
cur_group['after_edits'].append(
f'{j1+idx+1:>{_indent_pad_size}}|{line}'
f'{line_num:>{_indent_pad_size}}|{line}'
)
continue
if tag in {'replace', 'delete'}:
for idx, line in enumerate(old_lines[i1:i2]):
line_num = i1 + idx + 1
cur_group['before_edits'].append(
f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
f'-{line_num:>{_indent_pad_size-1}}|{line}'
)
if tag in {'replace', 'insert'}:
for idx, line in enumerate(new_lines[j1:j2]):
line_num = j1 + idx + 1
cur_group['after_edits'].append(
f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
f'+{line_num:>{_indent_pad_size-1}}|{line}'
)
edit_groups.append(cur_group)
return edit_groups
@ -100,24 +123,37 @@ class FileEditObservation(Observation):
) -> str:
"""Visualize the diff of the file edit.
Instead of showing the diff line by line, this function
shows each hunk of changes as a separate entity.
Instead of showing the diff line by line, this function shows each hunk
of changes as a separate entity.
Args:
n_context_lines: The number of lines of context to show before and after the changes.
change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors).
n_context_lines: Number of context lines to show before/after changes.
change_applied: Whether changes are applied. If false, shows as
attempted edit.
Returns:
A string containing the formatted diff visualization.
"""
if change_applied and self.content.strip() == '':
# diff patch is empty
return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'
# Use cached diff if available
if self._diff_cache is not None:
return self._diff_cache
# Check if there are any changes
if change_applied and self.old_content == self.new_content:
msg = '(no changes detected. Please make sure your edits change '
msg += 'the content of the existing file.)\n'
self._diff_cache = msg
return self._diff_cache
edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)
result = [
f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
if change_applied
else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
]
if change_applied:
header = f'[Existing file {self.path} is edited with '
header += f'{len(edit_groups)} changes.]'
else:
header = f"[Changes are NOT applied to {self.path} - Here's how "
header += 'the file looks like if changes are applied.]'
result = [header]
op_type = 'edit' if change_applied else 'ATTEMPTED edit'
for i, cur_edit_group in enumerate(edit_groups):
@ -129,18 +165,21 @@ class FileEditObservation(Observation):
result.append(f'(content after {op_type})')
result.extend(cur_edit_group['after_edits'])
result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
return '\n'.join(result)
# Cache the result
self._diff_cache = '\n'.join(result)
return self._diff_cache
def __str__(self) -> str:
"""Get a string representation of the file edit observation."""
if self.impl_source == FileEditSource.OH_ACI:
return self.formatted_output_and_error
ret = ''
if not self.prev_exist:
assert (
self.old_content == ''
), 'old_content should be empty if the file is new (prev_exist=False).'
ret += f'[New file {self.path} is created with the provided content.]\n'
return ret.rstrip() + '\n'
ret += self.visualize_diff()
return ret.rstrip() + '\n'
return f'[New file {self.path} is created with the provided content.]\n'
# Use cached diff if available, otherwise compute it
return self.visualize_diff().rstrip() + '\n'

View File

@ -120,6 +120,10 @@ class EventStream:
for callback_id in callback_ids:
self._clean_up_subscriber(subscriber_id, callback_id)
# Clear queue
while not self._queue.empty():
self._queue.get()
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')

View File

@ -1,15 +1,29 @@
import gc
import json
import os
import psutil
import pytest
from pytest import TempPathFactory
from openhands.core.schema.observation import ObservationType
from openhands.core.schema import ActionType, ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
)
from openhands.events.action.files import (
FileEditAction,
FileReadAction,
FileWriteAction,
)
from openhands.events.action.message import MessageAction
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.observation import NullObservation
from openhands.events.observation.files import (
FileEditObservation,
FileReadObservation,
FileWriteObservation,
)
from openhands.storage import get_file_store
@ -185,3 +199,103 @@ def test_get_matching_events_limit_validation(temp_dir: str):
assert len(events) == 1
events = event_stream.get_matching_events(limit=100)
assert len(events) == 1
def test_memory_usage_file_operations(temp_dir: str):
"""Test memory usage during file operations in EventStream.
This test verifies that memory usage during file operations is reasonable
and that memory is properly cleaned up after operations complete.
"""
def get_memory_mb():
"""Get current memory usage in MB"""
process = psutil.Process(os.getpid())
return process.memory_info().rss / 1024 / 1024
# Create a test file with 100kb content
test_file = os.path.join(temp_dir, 'test_file.txt')
test_content = 'x' * (100 * 1024) # 100kb of data
with open(test_file, 'w') as f:
f.write(test_content)
# Initialize FileStore and EventStream
file_store = get_file_store('local', temp_dir)
# Record initial memory usage
gc.collect()
initial_memory = get_memory_mb()
max_memory_increase = 0
# Perform operations 20 times
for i in range(20):
event_stream = EventStream('test_session', file_store)
# 1. Read file
read_action = FileReadAction(
path=test_file,
start=0,
end=-1,
thought='Reading file',
action=ActionType.READ,
impl_source=FileReadSource.DEFAULT,
)
event_stream.add_event(read_action, EventSource.AGENT)
read_obs = FileReadObservation(
path=test_file, impl_source=FileReadSource.DEFAULT, content=test_content
)
event_stream.add_event(read_obs, EventSource.ENVIRONMENT)
# 2. Write file
write_action = FileWriteAction(
path=test_file,
content=test_content,
start=0,
end=-1,
thought='Writing file',
action=ActionType.WRITE,
)
event_stream.add_event(write_action, EventSource.AGENT)
write_obs = FileWriteObservation(path=test_file, content=test_content)
event_stream.add_event(write_obs, EventSource.ENVIRONMENT)
# 3. Edit file
edit_action = FileEditAction(
path=test_file,
content=test_content,
start=1,
end=-1,
thought='Editing file',
action=ActionType.EDIT,
impl_source=FileEditSource.LLM_BASED_EDIT,
)
event_stream.add_event(edit_action, EventSource.AGENT)
edit_obs = FileEditObservation(
path=test_file,
prev_exist=True,
old_content=test_content,
new_content=test_content,
impl_source=FileEditSource.LLM_BASED_EDIT,
content=test_content,
)
event_stream.add_event(edit_obs, EventSource.ENVIRONMENT)
# Close event stream and force garbage collection
event_stream.close()
gc.collect()
# Check memory usage
current_memory = get_memory_mb()
memory_increase = current_memory - initial_memory
max_memory_increase = max(max_memory_increase, memory_increase)
# Clean up
os.remove(test_file)
# Memory increase should be reasonable (less than 50MB after 20 iterations)
assert (
max_memory_increase < 50
), f'Memory increase of {max_memory_increase:.1f}MB exceeds limit of 50MB'

View File

@ -0,0 +1,135 @@
"""Tests for FileEditObservation class."""
from openhands.events.event import FileEditSource
from openhands.events.observation.files import FileEditObservation
def test_file_edit_observation_basic():
"""Test basic properties of FileEditObservation."""
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
old_content='Hello\nWorld\n',
new_content='Hello\nNew World\n',
impl_source=FileEditSource.LLM_BASED_EDIT,
content='Hello\nWorld\n', # Initial content is old_content
)
assert obs.path == '/test/file.txt'
assert obs.prev_exist is True
assert obs.old_content == 'Hello\nWorld\n'
assert obs.new_content == 'Hello\nNew World\n'
assert obs.impl_source == FileEditSource.LLM_BASED_EDIT
assert obs.message == 'I edited the file /test/file.txt.'
def test_file_edit_observation_diff_cache():
"""Test that diff visualization is cached."""
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
old_content='Hello\nWorld\n',
new_content='Hello\nNew World\n',
impl_source=FileEditSource.LLM_BASED_EDIT,
content='Hello\nWorld\n', # Initial content is old_content
)
# First call should compute diff
diff1 = obs.visualize_diff()
assert obs._diff_cache is not None
# Second call should use cache
diff2 = obs.visualize_diff()
assert diff1 == diff2
def test_file_edit_observation_no_changes():
"""Test behavior when content hasn't changed."""
content = 'Hello\nWorld\n'
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
old_content=content,
new_content=content,
impl_source=FileEditSource.LLM_BASED_EDIT,
content=content, # Initial content is old_content
)
diff = obs.visualize_diff()
assert '(no changes detected' in diff
def test_file_edit_observation_get_edit_groups():
"""Test the get_edit_groups method."""
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
old_content='Line 1\nLine 2\nLine 3\nLine 4\n',
new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\n',
impl_source=FileEditSource.LLM_BASED_EDIT,
content='Line 1\nLine 2\nLine 3\nLine 4\n', # Initial content is old_content
)
groups = obs.get_edit_groups(n_context_lines=1)
assert len(groups) > 0
# Check structure of edit groups
for group in groups:
assert 'before_edits' in group
assert 'after_edits' in group
assert isinstance(group['before_edits'], list)
assert isinstance(group['after_edits'], list)
# Verify line numbers and content
first_group = groups[0]
assert any('Line 2' in line for line in first_group['before_edits'])
assert any('New Line 2' in line for line in first_group['after_edits'])
def test_file_edit_observation_new_file():
"""Test behavior when editing a new file."""
obs = FileEditObservation(
path='/test/new_file.txt',
prev_exist=False,
old_content='',
new_content='Hello\nWorld\n',
impl_source=FileEditSource.LLM_BASED_EDIT,
content='', # Initial content is old_content (empty for new file)
)
assert obs.prev_exist is False
assert obs.old_content == ''
assert (
str(obs)
== '[New file /test/new_file.txt is created with the provided content.]\n'
)
# Test that trying to visualize diff for a new file works
diff = obs.visualize_diff()
assert diff is not None
def test_file_edit_observation_context_lines():
"""Test diff visualization with different context line settings."""
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
old_content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n',
new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\nLine 5\n',
impl_source=FileEditSource.LLM_BASED_EDIT,
content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n', # Initial content is old_content
)
# Test with 0 context lines
groups_0 = obs.get_edit_groups(n_context_lines=0)
# Test with 2 context lines
groups_2 = obs.get_edit_groups(n_context_lines=2)
# More context should mean more lines in the groups
total_lines_0 = sum(
len(g['before_edits']) + len(g['after_edits']) for g in groups_0
)
total_lines_2 = sum(
len(g['before_edits']) + len(g['after_edits']) for g in groups_2
)
assert total_lines_2 > total_lines_0