Refactor file service (#7533)

This commit is contained in:
sp.wack 2025-04-08 19:41:22 +04:00 committed by GitHub
parent c8904e4672
commit 255e209886
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 135 additions and 361 deletions

View File

@ -0,0 +1,29 @@
import { describe, expect, it } from "vitest";
import { FileService } from "#/api/file-service/file-service.api";
import {
FILE_VARIANTS_1,
FILE_VARIANTS_2,
} from "#/mocks/file-service-handlers";
/**
* File service API tests. The actual API calls are mocked using MSW.
* You can find the mock handlers in `frontend/src/mocks/file-service-handlers.ts`.
*/
describe("FileService", () => {
it("should get a list of files", async () => {
await expect(FileService.getFiles("test-conversation-id")).resolves.toEqual(
FILE_VARIANTS_1,
);
await expect(
FileService.getFiles("test-conversation-id-2"),
).resolves.toEqual(FILE_VARIANTS_2);
});
it("should get content of a file", async () => {
await expect(
FileService.getFile("test-conversation-id", "file1.txt"),
).resolves.toEqual("Content of file1.txt");
});
});

View File

@ -1,15 +1,12 @@
import { screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import { describe, it, expect, vi, Mock, afterEach } from "vitest";
import toast from "#/utils/toast";
import { describe, it, expect, vi, afterEach } from "vitest";
import { AgentState } from "#/types/agent-state";
import OpenHands from "#/api/open-hands";
import { FileExplorer } from "#/components/features/file-explorer/file-explorer";
import { FileService } from "#/api/file-service/file-service.api";
const toastSpy = vi.spyOn(toast, "error");
const uploadFilesSpy = vi.spyOn(OpenHands, "uploadFiles");
const getFilesSpy = vi.spyOn(OpenHands, "getFiles");
const getFilesSpy = vi.spyOn(FileService, "getFiles");
vi.mock("../../services/fileService", async () => ({
uploadFiles: vi.fn(),
@ -64,41 +61,4 @@ describe.skip("FileExplorer", () => {
expect(folder1).toBeInTheDocument();
expect(folder1).not.toBeVisible();
});
it("should upload files", async () => {
const user = userEvent.setup();
renderFileExplorerWithRunningAgentState();
const file = new File([""], "file-name");
const uploadFileInput = await screen.findByTestId("file-input");
await user.upload(uploadFileInput, file);
// TODO: Improve this test by passing expected argument to `uploadFiles`
expect(uploadFilesSpy).toHaveBeenCalledOnce();
expect(getFilesSpy).toHaveBeenCalled();
const file2 = new File([""], "file-name-2");
const uploadDirInput = await screen.findByTestId("file-input");
await user.upload(uploadDirInput, [file, file2]);
expect(uploadFilesSpy).toHaveBeenCalledTimes(2);
expect(getFilesSpy).toHaveBeenCalled();
});
it("should display an error toast if file upload fails", async () => {
(uploadFilesSpy as Mock).mockRejectedValue(new Error());
const user = userEvent.setup();
renderFileExplorerWithRunningAgentState();
const uploadFileInput = await screen.findByTestId("file-input");
const file = new File([""], "test");
await user.upload(uploadFileInput, file);
expect(uploadFilesSpy).rejects.toThrow();
expect(toastSpy).toHaveBeenCalledWith(
expect.stringContaining("upload-error"),
expect.any(String),
);
});
});

View File

@ -3,10 +3,10 @@ import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import { vi, describe, afterEach, it, expect } from "vitest";
import TreeNode from "#/components/features/file-explorer/tree-node";
import OpenHands from "#/api/open-hands";
import { FileService } from "#/api/file-service/file-service.api";
const getFileSpy = vi.spyOn(OpenHands, "getFile");
const getFilesSpy = vi.spyOn(OpenHands, "getFiles");
const getFileSpy = vi.spyOn(FileService, "getFile");
const getFilesSpy = vi.spyOn(FileService, "getFiles");
vi.mock("../../services/fileService", async () => ({
uploadFile: vi.fn(),

View File

@ -0,0 +1,38 @@
import { openHands } from "../open-hands-axios";
import { GetFilesResponse, GetFileResponse } from "./file-service.types";
import { getConversationUrl } from "./file-service.utils";
export class FileService {
/**
* Retrieve the list of files available in the workspace
* @param conversationId ID of the conversation
* @param path Path to list files from. If provided, it lists all the files in the given path
* @returns List of files available in the given path. If path is not provided, it lists all the files in the workspace
*/
static async getFiles(
conversationId: string,
path?: string,
): Promise<GetFilesResponse> {
const url = `${getConversationUrl(conversationId)}/list-files`;
const { data } = await openHands.get<GetFilesResponse>(url, {
params: { path },
});
return data;
}
/**
* Retrieve the content of a file
* @param conversationId ID of the conversation
* @param path Full path of the file to retrieve
* @returns Code content of the file
*/
static async getFile(conversationId: string, path: string): Promise<string> {
const url = `${getConversationUrl(conversationId)}/select-file`;
const { data } = await openHands.get<GetFileResponse>(url, {
params: { file: path },
});
return data.code;
}
}

View File

@ -0,0 +1,5 @@
export type GetFilesResponse = string[];
export interface GetFileResponse {
code: string;
}

View File

@ -0,0 +1,7 @@
/**
* Returns a URL compatible for the file service
* @param conversationId ID of the conversation
* @returns URL of the conversation
*/
export const getConversationUrl = (conversationId: string) =>
`/api/conversations/${conversationId}`;

View File

@ -1,10 +1,7 @@
import {
SaveFileSuccessResponse,
FileUploadSuccessResponse,
Feedback,
FeedbackResponse,
GitHubAccessTokenResponse,
ErrorResponse,
GetConfigResponse,
GetVSCodeUrlResponse,
AuthenticateResponse,
@ -53,80 +50,6 @@ class OpenHands {
return data;
}
/**
* Retrieve the list of files available in the workspace
* @param path Path to list files from
* @returns List of files available in the given path. If path is not provided, it lists all the files in the workspace
*/
static async getFiles(
conversationId: string,
path?: string,
): Promise<string[]> {
const url = `/api/conversations/${conversationId}/list-files`;
const { data } = await openHands.get<string[]>(url, {
params: { path },
});
return data;
}
/**
* Retrieve the content of a file
* @param path Full path of the file to retrieve
* @returns Content of the file
*/
static async getFile(conversationId: string, path: string): Promise<string> {
const url = `/api/conversations/${conversationId}/select-file`;
const { data } = await openHands.get<{ code: string }>(url, {
params: { file: path },
});
return data.code;
}
/**
* Save the content of a file
* @param path Full path of the file to save
* @param content Content to save in the file
* @returns Success message or error message
*/
static async saveFile(
conversationId: string,
path: string,
content: string,
): Promise<SaveFileSuccessResponse> {
const url = `/api/conversations/${conversationId}/save-file`;
const { data } = await openHands.post<
SaveFileSuccessResponse | ErrorResponse
>(url, {
filePath: path,
content,
});
if ("error" in data) throw new Error(data.error);
return data;
}
/**
* Upload a file to the workspace
* @param file File to upload
* @returns Success message or error message
*/
static async uploadFiles(
conversationId: string,
files: File[],
): Promise<FileUploadSuccessResponse> {
const url = `/api/conversations/${conversationId}/upload-files`;
const formData = new FormData();
files.forEach((file) => formData.append("files", file));
const { data } = await openHands.post<
FileUploadSuccessResponse | ErrorResponse
>(url, formData);
if ("error" in data) throw new Error(data.error);
return data;
}
/**
* Send feedback to the server
* @param data Feedback data

View File

@ -1,16 +0,0 @@
import { useMutation } from "@tanstack/react-query";
import OpenHands from "#/api/open-hands";
import { useConversation } from "#/context/conversation-context";
type UploadFilesArgs = {
files: File[];
};
export const useUploadFiles = () => {
const { conversationId } = useConversation();
return useMutation({
mutationFn: ({ files }: UploadFilesArgs) =>
OpenHands.uploadFiles(conversationId, files),
});
};

View File

@ -1,6 +1,6 @@
import { useQuery } from "@tanstack/react-query";
import OpenHands from "#/api/open-hands";
import { useConversation } from "#/context/conversation-context";
import { FileService } from "#/api/file-service/file-service.api";
interface UseListFileConfig {
path: string;
@ -9,8 +9,8 @@ interface UseListFileConfig {
export const useListFile = (config: UseListFileConfig) => {
const { conversationId } = useConversation();
return useQuery({
queryKey: ["file", conversationId, config.path],
queryFn: () => OpenHands.getFile(conversationId, config.path),
queryKey: ["files", conversationId, config.path],
queryFn: () => FileService.getFile(conversationId, config.path),
enabled: false, // don't fetch by default, trigger manually via `refetch`
});
};

View File

@ -1,9 +1,9 @@
import { useQuery } from "@tanstack/react-query";
import { useSelector } from "react-redux";
import OpenHands from "#/api/open-hands";
import { useConversation } from "#/context/conversation-context";
import { RootState } from "#/store";
import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state";
import { FileService } from "#/api/file-service/file-service.api";
interface UseListFilesConfig {
path?: string;
@ -17,12 +17,12 @@ const DEFAULT_CONFIG: UseListFilesConfig = {
export const useListFiles = (config: UseListFilesConfig = DEFAULT_CONFIG) => {
const { conversationId } = useConversation();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const isActive = !RUNTIME_INACTIVE_STATES.includes(curAgentState);
const runtimeIsActive = !RUNTIME_INACTIVE_STATES.includes(curAgentState);
return useQuery({
queryKey: ["files", conversationId, config?.path],
queryFn: () => OpenHands.getFiles(conversationId, config?.path),
enabled: !!(isActive && config?.enabled),
queryFn: () => FileService.getFiles(conversationId, config?.path),
enabled: runtimeIsActive && !!config?.enabled,
staleTime: 1000 * 60 * 5, // 5 minutes
gcTime: 1000 * 60 * 15, // 15 minutes
});

View File

@ -0,0 +1,39 @@
import { delay, http, HttpResponse } from "msw";
export const FILE_VARIANTS_1 = ["file1.txt", "file2.txt", "file3.txt"];
export const FILE_VARIANTS_2 = [
"reboot_skynet.exe",
"target_list.txt",
"terminator_blueprint.txt",
];
export const FILE_SERVICE_HANDLERS = [
http.get(
"/api/conversations/:conversationId/list-files",
async ({ params }) => {
await delay();
const cid = params.conversationId?.toString();
if (!cid) return HttpResponse.json(null, { status: 400 });
return cid === "test-conversation-id-2"
? HttpResponse.json(FILE_VARIANTS_2)
: HttpResponse.json(FILE_VARIANTS_1);
},
),
http.get(
"/api/conversations/:conversationId/select-file",
async ({ request }) => {
await delay();
const url = new URL(request.url);
const file = url.searchParams.get("file")?.toString();
if (file) {
return HttpResponse.json({ code: `Content of ${file}` });
}
return HttpResponse.json(null, { status: 404 });
},
),
];

View File

@ -7,6 +7,7 @@ import {
import { DEFAULT_SETTINGS } from "#/services/settings";
import { STRIPE_BILLING_HANDLERS } from "./billing-handlers";
import { ApiSettings, PostApiSettings } from "#/types/settings";
import { FILE_SERVICE_HANDLERS } from "./file-service-handlers";
import { GitUser } from "#/types/git";
export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = {
@ -91,52 +92,6 @@ const openHandsHandlers = [
HttpResponse.json(["mock-invariant"]),
),
http.get(
"http://localhost:3001/api/conversations/:conversationId/list-files",
async ({ params }) => {
await delay();
const cid = params.conversationId?.toString();
if (!cid) return HttpResponse.json([], { status: 404 });
let data = ["file1.txt", "file2.txt", "file3.txt"];
if (cid === "3") {
data = [
"reboot_skynet.exe",
"target_list.txt",
"terminator_blueprint.txt",
];
}
return HttpResponse.json(data);
},
),
http.post("http://localhost:3001/api/save-file", () =>
HttpResponse.json(null, { status: 200 }),
),
http.get("http://localhost:3001/api/select-file", async ({ request }) => {
await delay();
const token = request.headers
.get("Authorization")
?.replace("Bearer", "")
.trim();
if (!token) {
return HttpResponse.json([], { status: 401 });
}
const url = new URL(request.url);
const file = url.searchParams.get("file")?.toString();
if (file) {
return HttpResponse.json({ code: `Content of ${file}` });
}
return HttpResponse.json(null, { status: 404 });
}),
http.post("http://localhost:3001/api/submit-feedback", async () => {
await delay(1200);
@ -149,6 +104,7 @@ const openHandsHandlers = [
export const handlers = [
...STRIPE_BILLING_HANDLERS,
...FILE_SERVICE_HANDLERS,
...openHandsHandlers,
http.get("/api/user/repositories", () =>
HttpResponse.json([

View File

@ -1,11 +1,9 @@
import os
import tempfile
from fastapi import (
APIRouter,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.responses import FileResponse, JSONResponse
@ -17,19 +15,14 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
FileReadAction,
FileWriteAction,
)
from openhands.events.observation import (
ErrorObservation,
FileReadObservation,
FileWriteObservation,
)
from openhands.runtime.base import Runtime
from openhands.server.file_config import (
FILES_TO_IGNORE,
MAX_FILE_SIZE_MB,
is_extension_allowed,
sanitize_filename,
)
from openhands.utils.async_utils import call_sync_from_async
@ -37,7 +30,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}')
@app.get('/list-files')
async def list_files(request: Request, conversation_id: str, path: str | None = None):
async def list_files(request: Request, path: str | None = None):
"""List files in the specified path.
This function retrieves a list of files from the agent's runtime file store,
@ -148,168 +141,8 @@ async def select_file(file: str, request: Request):
)
@app.post('/upload-files')
async def upload_file(request: Request, conversation_id: str, files: list[UploadFile]):
"""Upload a list of files to the workspace.
To upload a files:
```sh
curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/conversations/{conversation_id}/upload-files
```
Args:
request (Request): The incoming request object.
files (list[UploadFile]): A list of files to be uploaded.
Returns:
dict: A message indicating the success of the upload operation.
Raises:
HTTPException: If there's an error saving the files.
"""
try:
uploaded_files = []
skipped_files = []
for file in files:
safe_filename = sanitize_filename(file.filename)
file_contents = await file.read()
if (
MAX_FILE_SIZE_MB > 0
and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
):
skipped_files.append(
{
'name': safe_filename,
'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
}
)
continue
if not is_extension_allowed(safe_filename):
skipped_files.append(
{'name': safe_filename, 'reason': 'File type not allowed'}
)
continue
# copy the file to the runtime
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_file_path = os.path.join(tmp_dir, safe_filename)
with open(tmp_file_path, 'wb') as tmp_file:
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.conversation.runtime
try:
await call_sync_from_async(
runtime.copy_to,
tmp_file_path,
runtime.config.workspace_mount_path_in_sandbox,
)
except AgentRuntimeUnavailableError as e:
logger.error(f'Error saving file {safe_filename}: {e}')
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': f'Error saving file: {e}'},
)
uploaded_files.append(safe_filename)
response_content = {
'message': 'File upload process completed',
'uploaded_files': uploaded_files,
'skipped_files': skipped_files,
}
if not uploaded_files and skipped_files:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
**response_content,
'error': 'No files were uploaded successfully',
},
)
return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)
except Exception as e:
logger.error(f'Error during file upload: {e}')
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
'error': f'Error during file upload: {str(e)}',
'uploaded_files': [],
'skipped_files': [],
},
)
@app.post('/save-file')
async def save_file(request: Request):
"""Save a file to the agent's runtime file store.
This endpoint allows saving a file when the agent is in a paused, finished,
or awaiting user input state. It checks the agent's state before proceeding
with the file save operation.
Args:
request (Request): The incoming FastAPI request object.
Returns:
JSONResponse: A JSON response indicating the success of the operation.
Raises:
HTTPException:
- 403 error if the agent is not in an allowed state for editing.
- 400 error if the file path or content is missing.
- 500 error if there's an unexpected error during the save operation.
"""
try:
# Extract file path and content from the request
data = await request.json()
file_path = data.get('filePath')
content = data.get('content')
# Validate the presence of required data
if not file_path or content is None:
raise HTTPException(status_code=400, detail='Missing filePath or content')
# Save the file to the agent's runtime file store
runtime: Runtime = request.state.conversation.runtime
file_path = os.path.join(
runtime.config.workspace_mount_path_in_sandbox, file_path
)
write_action = FileWriteAction(file_path, content)
try:
observation = await call_sync_from_async(runtime.run_action, write_action)
except AgentRuntimeUnavailableError as e:
logger.error(f'Error saving file: {e}')
return JSONResponse(
status_code=500,
content={'error': f'Error saving file: {e}'},
)
if isinstance(observation, FileWriteObservation):
return JSONResponse(
status_code=200, content={'message': 'File saved successfully'}
)
elif isinstance(observation, ErrorObservation):
return JSONResponse(
status_code=500,
content={'error': f'Failed to save file: {observation}'},
)
else:
return JSONResponse(
status_code=500,
content={'error': f'Unexpected observation: {observation}'},
)
except Exception as e:
# Log the error and return a 500 response
logger.error(f'Error saving file: {e}')
raise HTTPException(status_code=500, detail=f'Error saving file: {e}')
@app.get('/zip-directory')
def zip_current_workspace(request: Request, conversation_id: str):
def zip_current_workspace(request: Request):
try:
logger.debug('Zipping workspace')
runtime: Runtime = request.state.conversation.runtime