Improve fileop logic to handle multiple cases

This commit is contained in:
Anas Dorbani 2024-04-10 16:23:53 +00:00
parent 04e9caab8b
commit d0ba806f35

View File

@ -1,6 +1,8 @@
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
@ -19,10 +21,8 @@ from .base import ExecutableAction
# The LLM sometimes returns paths with this prefix, so we need to remove it
PATH_PREFIX = '/workspace/'
# claude generated this and I have no clue if it works properly
def validate_file_content(file_path, content):
def validate_file_content(file_path: str, content: str) -> Optional[str]:
"""
Validates the content of a code file by checking for syntax errors.
@ -47,7 +47,7 @@ def validate_file_content(file_path, content):
ignore_types = ['txt', 'md', 'doc', 'pdf']
if file_extension in ignore_types:
return ''
return None
elif file_extension in validation_commands:
try:
@ -62,13 +62,13 @@ def validate_file_content(file_path, content):
except subprocess.CalledProcessError as e:
return e.stderr.strip()
else:
return ''
return None
else:
# If the file extension is not recognized, return a default error message
return f'Unsupported file type: {file_extension}'
def resolve_path(base_path, file_path):
def resolve_path(base_path: str, file_path: str) -> str:
if file_path.startswith(PATH_PREFIX):
file_path = file_path[len(PATH_PREFIX):]
return os.path.join(base_path, file_path)
@ -82,22 +82,33 @@ class FileReadAction(ExecutableAction):
"""
path: str
start_index: int = 0
max_lines: int = 100
action: str = ActionType.READ
def run(self, controller) -> FileReadObservation:
path = resolve_path(controller.workdir, self.path)
with open(path, 'r', encoding='utf-8') as file:
all_lines = file.readlines()
total_lines = len(all_lines)
if total_lines >= 100:
end_index = self.start_index + 100 if total_lines - \
self.start_index - 100 >= 0 else -1
code_slice = all_lines[self.start_index: end_index]
else:
code_slice = all_lines[:]
if isinstance(code_slice, list) and len(code_slice) > 1:
# def run(self, controller) -> FileReadObservation:
# path = resolve_path(controller.workdir, self.path)
def run(self):
path = resolve_path('./workspace', self.path)
if not os.path.exists(path):
return FileReadObservation(path=path, content='File not found')
try:
all_lines = []
with open(path, 'r', encoding='utf-8') as file:
for line in file:
all_lines.append(line.strip('\n'))
total_lines = len(all_lines)
if total_lines >= self.max_lines:
end_index = self.start_index + self.max_lines - 1 if total_lines - \
self.start_index - self.max_lines >= 0 else -1
code_slice = all_lines[self.start_index - 1: end_index]
else:
code_slice = all_lines[:]
code_view = '\n'.join(code_slice)
return FileReadObservation(path=path, content=code_view)
except (IOError, UnicodeDecodeError) as e:
return FileReadObservation(path=path, content=f'Error reading file: {e}')
return FileReadObservation(path=path, content=code_view)
@property
def message(self) -> str:
@ -113,21 +124,46 @@ class FileWriteAction(ExecutableAction):
action: str = ActionType.WRITE
def run(self, controller) -> Observation:
whole_path = resolve_path(controller.workdir, self.path)
full_path = resolve_path(controller.workdir, self.path)
parent_dir = os.path.dirname(full_path)
Path(parent_dir).mkdir(parents=True, exist_ok=True)
with open(whole_path, 'w', encoding='utf-8') as file:
all_lines = file.readlines()
insert = self.content.split('\n')
new_file = all_lines[:self.start] + insert + all_lines[self.end:]
content_str = '\n'.join(new_file)
validation_error = validate_file_content(whole_path, content_str)
if validation_error:
file.write(content_str)
return FileWriteObservation(content='', path=self.path)
else:
# Revert to the old file content
file.write('\n'.join(all_lines))
return AgentErrorObservation(content=validation_error)
all_lines = []
try:
with open(full_path, 'r', encoding='utf-8') as file:
for line in file:
all_lines.append(line.strip('\n'))
except (IOError, UnicodeDecodeError):
all_lines = []
# Split the content into lines
new_lines = self.content.split('\n')
# Check if the start and end indices are valid
if self.start < 0 or self.end < 0:
return AgentErrorObservation(content=f'Invalid start or end index: {self.start}, {self.end}')
elif self.start <= len(all_lines) and self.start + len(new_lines) <= len(all_lines):
new_file_lines = all_lines[:self.start-1] + \
new_lines + all_lines[len(all_lines)-1:]
elif self.start <= len(all_lines) and self.start + len(new_lines) > len(all_lines):
new_file_lines = all_lines[:self.start-1] + new_lines
elif self.start > len(all_lines):
new_file_lines = all_lines + \
['' for i in range(len(all_lines), self.start-1)] + new_lines
new_file_content = '\n'.join(new_file_lines)
validation_status = validate_file_content(full_path, new_file_content)
if not validation_status:
try:
with open(full_path, 'w', encoding='utf-8') as file:
file.write(new_file_content)
except IOError as e:
return AgentErrorObservation(content=f'Error writing file: {e}')
return FileWriteObservation(content=self.content, path=self.path)
else:
return AgentErrorObservation(content=validation_status)
@property
def message(self) -> str: