mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Improve fileop logic to handle multiple cases
This commit is contained in:
parent
04e9caab8b
commit
d0ba806f35
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -19,10 +21,8 @@ from .base import ExecutableAction
|
||||
# The LLM sometimes returns paths with this prefix, so we need to remove it
|
||||
PATH_PREFIX = '/workspace/'
|
||||
|
||||
# claude generated this and I have no clue if it works properly
|
||||
|
||||
|
||||
def validate_file_content(file_path, content):
|
||||
def validate_file_content(file_path: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Validates the content of a code file by checking for syntax errors.
|
||||
|
||||
@ -47,7 +47,7 @@ def validate_file_content(file_path, content):
|
||||
ignore_types = ['txt', 'md', 'doc', 'pdf']
|
||||
|
||||
if file_extension in ignore_types:
|
||||
return ''
|
||||
return None
|
||||
|
||||
elif file_extension in validation_commands:
|
||||
try:
|
||||
@ -62,13 +62,13 @@ def validate_file_content(file_path, content):
|
||||
except subprocess.CalledProcessError as e:
|
||||
return e.stderr.strip()
|
||||
else:
|
||||
return ''
|
||||
return None
|
||||
else:
|
||||
# If the file extension is not recognized, return a default error message
|
||||
return f'Unsupported file type: {file_extension}'
|
||||
|
||||
|
||||
def resolve_path(base_path, file_path):
|
||||
def resolve_path(base_path: str, file_path: str) -> str:
|
||||
if file_path.startswith(PATH_PREFIX):
|
||||
file_path = file_path[len(PATH_PREFIX):]
|
||||
return os.path.join(base_path, file_path)
|
||||
@ -82,22 +82,33 @@ class FileReadAction(ExecutableAction):
|
||||
"""
|
||||
path: str
|
||||
start_index: int = 0
|
||||
max_lines: int = 100
|
||||
action: str = ActionType.READ
|
||||
|
||||
def run(self, controller) -> FileReadObservation:
|
||||
path = resolve_path(controller.workdir, self.path)
|
||||
with open(path, 'r', encoding='utf-8') as file:
|
||||
all_lines = file.readlines()
|
||||
total_lines = len(all_lines)
|
||||
if total_lines >= 100:
|
||||
end_index = self.start_index + 100 if total_lines - \
|
||||
self.start_index - 100 >= 0 else -1
|
||||
code_slice = all_lines[self.start_index: end_index]
|
||||
else:
|
||||
code_slice = all_lines[:]
|
||||
if isinstance(code_slice, list) and len(code_slice) > 1:
|
||||
# def run(self, controller) -> FileReadObservation:
|
||||
# path = resolve_path(controller.workdir, self.path)
|
||||
def run(self):
|
||||
path = resolve_path('./workspace', self.path)
|
||||
if not os.path.exists(path):
|
||||
return FileReadObservation(path=path, content='File not found')
|
||||
|
||||
try:
|
||||
all_lines = []
|
||||
with open(path, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
all_lines.append(line.strip('\n'))
|
||||
total_lines = len(all_lines)
|
||||
if total_lines >= self.max_lines:
|
||||
end_index = self.start_index + self.max_lines - 1 if total_lines - \
|
||||
self.start_index - self.max_lines >= 0 else -1
|
||||
code_slice = all_lines[self.start_index - 1: end_index]
|
||||
else:
|
||||
code_slice = all_lines[:]
|
||||
code_view = '\n'.join(code_slice)
|
||||
return FileReadObservation(path=path, content=code_view)
|
||||
except (IOError, UnicodeDecodeError) as e:
|
||||
return FileReadObservation(path=path, content=f'Error reading file: {e}')
|
||||
|
||||
return FileReadObservation(path=path, content=code_view)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@ -113,21 +124,46 @@ class FileWriteAction(ExecutableAction):
|
||||
action: str = ActionType.WRITE
|
||||
|
||||
def run(self, controller) -> Observation:
|
||||
whole_path = resolve_path(controller.workdir, self.path)
|
||||
full_path = resolve_path(controller.workdir, self.path)
|
||||
parent_dir = os.path.dirname(full_path)
|
||||
Path(parent_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(whole_path, 'w', encoding='utf-8') as file:
|
||||
all_lines = file.readlines()
|
||||
insert = self.content.split('\n')
|
||||
new_file = all_lines[:self.start] + insert + all_lines[self.end:]
|
||||
content_str = '\n'.join(new_file)
|
||||
validation_error = validate_file_content(whole_path, content_str)
|
||||
if validation_error:
|
||||
file.write(content_str)
|
||||
return FileWriteObservation(content='', path=self.path)
|
||||
else:
|
||||
# Revert to the old file content
|
||||
file.write('\n'.join(all_lines))
|
||||
return AgentErrorObservation(content=validation_error)
|
||||
all_lines = []
|
||||
try:
|
||||
with open(full_path, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
all_lines.append(line.strip('\n'))
|
||||
except (IOError, UnicodeDecodeError):
|
||||
all_lines = []
|
||||
|
||||
# Split the content into lines
|
||||
new_lines = self.content.split('\n')
|
||||
|
||||
# Check if the start and end indices are valid
|
||||
if self.start < 0 or self.end < 0:
|
||||
return AgentErrorObservation(content=f'Invalid start or end index: {self.start}, {self.end}')
|
||||
elif self.start <= len(all_lines) and self.start + len(new_lines) <= len(all_lines):
|
||||
new_file_lines = all_lines[:self.start-1] + \
|
||||
new_lines + all_lines[len(all_lines)-1:]
|
||||
elif self.start <= len(all_lines) and self.start + len(new_lines) > len(all_lines):
|
||||
new_file_lines = all_lines[:self.start-1] + new_lines
|
||||
elif self.start > len(all_lines):
|
||||
new_file_lines = all_lines + \
|
||||
['' for i in range(len(all_lines), self.start-1)] + new_lines
|
||||
|
||||
new_file_content = '\n'.join(new_file_lines)
|
||||
|
||||
validation_status = validate_file_content(full_path, new_file_content)
|
||||
|
||||
if not validation_status:
|
||||
try:
|
||||
with open(full_path, 'w', encoding='utf-8') as file:
|
||||
file.write(new_file_content)
|
||||
except IOError as e:
|
||||
return AgentErrorObservation(content=f'Error writing file: {e}')
|
||||
return FileWriteObservation(content=self.content, path=self.path)
|
||||
else:
|
||||
return AgentErrorObservation(content=validation_status)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user