mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Optimize memory usage in FileEditObservation (#6622)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
parent
ff48f8beba
commit
93d2e4a338
@ -1,3 +1,5 @@
|
||||
"""File-related observation classes for tracking file operations."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
@ -16,32 +18,40 @@ class FileReadObservation(Observation):
|
||||
|
||||
@property
def message(self) -> str:
    """Return a one-sentence, human-readable summary of the file read."""
    summary = f'I read the file {self.path}.'
    return summary
|
||||
|
||||
def __str__(self) -> str:
    """Render the observation as a success banner followed by the content."""
    # First line is the status banner; the observation content follows it.
    banner = f'[Read from {self.path} is successful.]\n'
    return banner + f'{self.content}'
|
||||
|
||||
|
||||
@dataclass
class FileWriteObservation(Observation):
    """Observation describing the outcome of a file write operation."""

    # Path of the file that was written.
    path: str
    # Observation type tag; defaults to the WRITE observation type.
    observation: str = ObservationType.WRITE

    @property
    def message(self) -> str:
        """Return a one-sentence, human-readable summary of the file write."""
        summary = f'I wrote to the file {self.path}.'
        return summary

    def __str__(self) -> str:
        """Render the observation as a success banner followed by the content."""
        # NOTE(review): `content` is not declared here; presumably it is
        # inherited from Observation - confirm.
        banner = f'[Write to {self.path} is successful.]\n'
        return banner + f'{self.content}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileEditObservation(Observation):
|
||||
"""This data class represents a file edit operation"""
|
||||
"""This data class represents a file edit operation.
|
||||
|
||||
The observation includes both the old and new content of the file, and can
|
||||
generate a diff visualization showing the changes. The diff is computed lazily
|
||||
and cached to improve performance.
|
||||
"""
|
||||
|
||||
# content: str will be a unified diff patch string that includes NO context lines
|
||||
path: str
|
||||
prev_exist: bool
|
||||
old_content: str
|
||||
@ -49,22 +59,31 @@ class FileEditObservation(Observation):
|
||||
observation: str = ObservationType.EDIT
|
||||
impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT
|
||||
formatted_output_and_error: str = ''
|
||||
_diff_cache: str | None = None # Cache for the diff visualization
|
||||
|
||||
@property
def message(self) -> str:
    """Return a one-sentence, human-readable summary of the file edit."""
    summary = f'I edited the file {self.path}.'
    return summary
|
||||
|
||||
def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
|
||||
"""Get the edit groups of the file edit."""
|
||||
"""Get the edit groups showing changes between old and new content.
|
||||
|
||||
Args:
|
||||
n_context_lines: Number of context lines to show around each change.
|
||||
|
||||
Returns:
|
||||
A list of edit groups, where each group contains before/after edits.
|
||||
"""
|
||||
old_lines = self.old_content.split('\n')
|
||||
new_lines = self.new_content.split('\n')
|
||||
# Borrowed from difflib.unified_diff to directly parse into structured format.
|
||||
# Borrowed from difflib.unified_diff to directly parse into structured format
|
||||
edit_groups: list[dict] = []
|
||||
for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
|
||||
n_context_lines
|
||||
):
|
||||
# take the max line number in the group
|
||||
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix
|
||||
# Take the max line number in the group
|
||||
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for "*" prefix
|
||||
cur_group: dict[str, list[str]] = {
|
||||
'before_edits': [],
|
||||
'after_edits': [],
|
||||
@ -72,23 +91,27 @@ class FileEditObservation(Observation):
|
||||
for tag, i1, i2, j1, j2 in group:
|
||||
if tag == 'equal':
|
||||
for idx, line in enumerate(old_lines[i1:i2]):
|
||||
line_num = i1 + idx + 1
|
||||
cur_group['before_edits'].append(
|
||||
f'{i1+idx+1:>{_indent_pad_size}}|{line}'
|
||||
f'{line_num:>{_indent_pad_size}}|{line}'
|
||||
)
|
||||
for idx, line in enumerate(new_lines[j1:j2]):
|
||||
line_num = j1 + idx + 1
|
||||
cur_group['after_edits'].append(
|
||||
f'{j1+idx+1:>{_indent_pad_size}}|{line}'
|
||||
f'{line_num:>{_indent_pad_size}}|{line}'
|
||||
)
|
||||
continue
|
||||
if tag in {'replace', 'delete'}:
|
||||
for idx, line in enumerate(old_lines[i1:i2]):
|
||||
line_num = i1 + idx + 1
|
||||
cur_group['before_edits'].append(
|
||||
f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
|
||||
f'-{line_num:>{_indent_pad_size-1}}|{line}'
|
||||
)
|
||||
if tag in {'replace', 'insert'}:
|
||||
for idx, line in enumerate(new_lines[j1:j2]):
|
||||
line_num = j1 + idx + 1
|
||||
cur_group['after_edits'].append(
|
||||
f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
|
||||
f'+{line_num:>{_indent_pad_size-1}}|{line}'
|
||||
)
|
||||
edit_groups.append(cur_group)
|
||||
return edit_groups
|
||||
@ -100,24 +123,37 @@ class FileEditObservation(Observation):
|
||||
) -> str:
|
||||
"""Visualize the diff of the file edit.
|
||||
|
||||
Instead of showing the diff line by line, this function
|
||||
shows each hunk of changes as a separate entity.
|
||||
Instead of showing the diff line by line, this function shows each hunk
|
||||
of changes as a separate entity.
|
||||
|
||||
Args:
|
||||
n_context_lines: The number of lines of context to show before and after the changes.
|
||||
change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors).
|
||||
n_context_lines: Number of context lines to show before/after changes.
|
||||
change_applied: Whether changes are applied. If false, shows as
|
||||
attempted edit.
|
||||
|
||||
Returns:
|
||||
A string containing the formatted diff visualization.
|
||||
"""
|
||||
if change_applied and self.content.strip() == '':
|
||||
# diff patch is empty
|
||||
return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'
|
||||
# Use cached diff if available
|
||||
if self._diff_cache is not None:
|
||||
return self._diff_cache
|
||||
|
||||
# Check if there are any changes
|
||||
if change_applied and self.old_content == self.new_content:
|
||||
msg = '(no changes detected. Please make sure your edits change '
|
||||
msg += 'the content of the existing file.)\n'
|
||||
self._diff_cache = msg
|
||||
return self._diff_cache
|
||||
|
||||
edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)
|
||||
|
||||
result = [
|
||||
f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
|
||||
if change_applied
|
||||
else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
|
||||
]
|
||||
if change_applied:
|
||||
header = f'[Existing file {self.path} is edited with '
|
||||
header += f'{len(edit_groups)} changes.]'
|
||||
else:
|
||||
header = f"[Changes are NOT applied to {self.path} - Here's how "
|
||||
header += 'the file looks like if changes are applied.]'
|
||||
result = [header]
|
||||
|
||||
op_type = 'edit' if change_applied else 'ATTEMPTED edit'
|
||||
for i, cur_edit_group in enumerate(edit_groups):
|
||||
@ -129,18 +165,21 @@ class FileEditObservation(Observation):
|
||||
result.append(f'(content after {op_type})')
|
||||
result.extend(cur_edit_group['after_edits'])
|
||||
result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
|
||||
return '\n'.join(result)
|
||||
|
||||
# Cache the result
|
||||
self._diff_cache = '\n'.join(result)
|
||||
return self._diff_cache
|
||||
|
||||
def __str__(self) -> str:
    """Render the edit observation as text.

    Returns:
        The pre-formatted output for OH_ACI edits, a "new file" banner when
        the file did not previously exist, or otherwise the diff
        visualization with trailing whitespace stripped and a single
        newline appended.
    """
    # OH_ACI edits: return the stored formatted output field unchanged.
    if self.impl_source == FileEditSource.OH_ACI:
        return self.formatted_output_and_error

    if self.prev_exist:
        # visualize_diff() is called with its default arguments here.
        rendered = self.visualize_diff()
        return rendered.rstrip() + '\n'

    # A file that did not exist before cannot have any previous content.
    assert (
        self.old_content == ''
    ), 'old_content should be empty if the file is new (prev_exist=False).'
    return f'[New file {self.path} is created with the provided content.]\n'
|
||||
|
||||
@ -120,6 +120,10 @@ class EventStream:
|
||||
for callback_id in callback_ids:
|
||||
self._clean_up_subscriber(subscriber_id, callback_id)
|
||||
|
||||
# Clear queue
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
|
||||
|
||||
@ -1,15 +1,29 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.core.schema import ActionType, ObservationType
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.observation import NullObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@ -185,3 +199,103 @@ def test_get_matching_events_limit_validation(temp_dir: str):
|
||||
assert len(events) == 1
|
||||
events = event_stream.get_matching_events(limit=100)
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
def test_memory_usage_file_operations(temp_dir: str):
    """Check that repeated file events do not inflate process memory.

    Twenty times over, a fresh EventStream receives a read, a write and an
    edit action/observation pair (each carrying 100kb of content), is closed,
    and garbage is collected. The largest RSS growth seen relative to the
    starting point must stay under 50MB.
    """

    def rss_mb():
        """Return the resident set size of the current process in MB."""
        return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024

    # 100kb payload, also written to disk so the path refers to a real file.
    payload = 'x' * (100 * 1024)
    target_path = os.path.join(temp_dir, 'test_file.txt')
    with open(target_path, 'w') as handle:
        handle.write(payload)

    store = get_file_store('local', temp_dir)

    # Baseline is taken after a collection so earlier garbage is not counted.
    gc.collect()
    baseline_mb = rss_mb()
    peak_growth_mb = 0

    for _ in range(20):
        stream = EventStream('test_session', store)

        # 1. Read: action from the agent, observation from the environment.
        stream.add_event(
            FileReadAction(
                path=target_path,
                start=0,
                end=-1,
                thought='Reading file',
                action=ActionType.READ,
                impl_source=FileReadSource.DEFAULT,
            ),
            EventSource.AGENT,
        )
        stream.add_event(
            FileReadObservation(
                path=target_path,
                impl_source=FileReadSource.DEFAULT,
                content=payload,
            ),
            EventSource.ENVIRONMENT,
        )

        # 2. Write.
        stream.add_event(
            FileWriteAction(
                path=target_path,
                content=payload,
                start=0,
                end=-1,
                thought='Writing file',
                action=ActionType.WRITE,
            ),
            EventSource.AGENT,
        )
        stream.add_event(
            FileWriteObservation(path=target_path, content=payload),
            EventSource.ENVIRONMENT,
        )

        # 3. Edit (old and new content are the same payload).
        stream.add_event(
            FileEditAction(
                path=target_path,
                content=payload,
                start=1,
                end=-1,
                thought='Editing file',
                action=ActionType.EDIT,
                impl_source=FileEditSource.LLM_BASED_EDIT,
            ),
            EventSource.AGENT,
        )
        stream.add_event(
            FileEditObservation(
                path=target_path,
                prev_exist=True,
                old_content=payload,
                new_content=payload,
                impl_source=FileEditSource.LLM_BASED_EDIT,
                content=payload,
            ),
            EventSource.ENVIRONMENT,
        )

        # Close the stream and collect before sampling memory.
        stream.close()
        gc.collect()

        growth_mb = rss_mb() - baseline_mb
        if growth_mb > peak_growth_mb:
            peak_growth_mb = growth_mb

    # Remove the on-disk fixture file.
    os.remove(target_path)

    # Memory increase should be reasonable (less than 50MB after 20 iterations).
    assert (
        peak_growth_mb < 50
    ), f'Memory increase of {peak_growth_mb:.1f}MB exceeds limit of 50MB'
|
||||
|
||||
135
tests/unit/test_file_edit_observation.py
Normal file
135
tests/unit/test_file_edit_observation.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""Tests for FileEditObservation class."""
|
||||
|
||||
from openhands.events.event import FileEditSource
|
||||
from openhands.events.observation.files import FileEditObservation
|
||||
|
||||
|
||||
def test_file_edit_observation_basic():
    """Verify that constructor arguments are exposed unchanged as attributes."""
    target = '/test/file.txt'
    before = 'Hello\nWorld\n'
    after = 'Hello\nNew World\n'

    obs = FileEditObservation(
        content=before,  # content starts out equal to the old content
        path=target,
        prev_exist=True,
        old_content=before,
        new_content=after,
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    assert obs.path == target
    assert obs.prev_exist is True
    assert obs.old_content == before
    assert obs.new_content == after
    assert obs.impl_source == FileEditSource.LLM_BASED_EDIT
    assert obs.message == 'I edited the file /test/file.txt.'
|
||||
|
||||
|
||||
def test_file_edit_observation_diff_cache():
    """Verify that visualize_diff fills the cache and repeats its result."""
    before = 'Hello\nWorld\n'
    obs = FileEditObservation(
        content=before,  # content starts out equal to the old content
        path='/test/file.txt',
        prev_exist=True,
        old_content=before,
        new_content='Hello\nNew World\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    # The first call has to compute the diff and should populate the cache.
    first_render = obs.visualize_diff()
    assert obs._diff_cache is not None

    # A repeated call must yield the same text.
    second_render = obs.visualize_diff()
    assert first_render == second_render
|
||||
|
||||
|
||||
def test_file_edit_observation_no_changes():
    """Verify the "no changes" notice when old and new content are equal."""
    unchanged = 'Hello\nWorld\n'
    obs = FileEditObservation(
        content=unchanged,  # content starts out equal to the old content
        path='/test/file.txt',
        prev_exist=True,
        old_content=unchanged,
        new_content=unchanged,
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    rendered = obs.visualize_diff()
    assert '(no changes detected' in rendered
|
||||
|
||||
|
||||
def test_file_edit_observation_get_edit_groups():
    """Verify the shape and content of the groups from get_edit_groups."""
    before = 'Line 1\nLine 2\nLine 3\nLine 4\n'
    obs = FileEditObservation(
        content=before,  # content starts out equal to the old content
        path='/test/file.txt',
        prev_exist=True,
        old_content=before,
        new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    groups = obs.get_edit_groups(n_context_lines=1)
    assert len(groups) > 0

    # Every group must expose both sides of the edit as lists.
    expected_keys = ('before_edits', 'after_edits')
    for group in groups:
        assert all(key in group for key in expected_keys)
        assert all(isinstance(group[key], list) for key in expected_keys)

    # The first group should mention the changed second line on both sides.
    leading = groups[0]
    assert any('Line 2' in entry for entry in leading['before_edits'])
    assert any('New Line 2' in entry for entry in leading['after_edits'])
|
||||
|
||||
|
||||
def test_file_edit_observation_new_file():
    """Verify string rendering and diffing for a file that did not exist."""
    obs = FileEditObservation(
        content='',  # content starts out equal to the (empty) old content
        path='/test/new_file.txt',
        prev_exist=False,
        old_content='',
        new_content='Hello\nWorld\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    assert obs.prev_exist is False
    assert obs.old_content == ''

    expected_banner = (
        '[New file /test/new_file.txt is created with the provided content.]\n'
    )
    assert str(obs) == expected_banner

    # Asking for a diff of a brand-new file should still produce a result.
    rendered = obs.visualize_diff()
    assert rendered is not None
|
||||
|
||||
|
||||
def test_file_edit_observation_context_lines():
    """Verify that a larger context window yields more lines in the groups."""
    before = 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n'
    obs = FileEditObservation(
        content=before,  # content starts out equal to the old content
        path='/test/file.txt',
        prev_exist=True,
        old_content=before,
        new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\nLine 5\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
    )

    def count_lines(groups):
        """Total number of before/after lines across all edit groups."""
        return sum(
            len(group['before_edits']) + len(group['after_edits'])
            for group in groups
        )

    # Compare no context against two lines of context.
    without_context = obs.get_edit_groups(n_context_lines=0)
    with_context = obs.get_edit_groups(n_context_lines=2)

    lines_without = count_lines(without_context)
    lines_with = count_lines(with_context)
    assert lines_with > lines_without
|
||||
Loading…
x
Reference in New Issue
Block a user