From a1a9d2f17575bd4f202577d1d298a1b37a133dee Mon Sep 17 00:00:00 2001 From: tofarr Date: Mon, 11 Nov 2024 15:36:07 -0700 Subject: [PATCH] Refactor websocket (#4879) Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com> --- .../components/chat/chat-interface.test.tsx | 14 +- .../__tests__/hooks/use-terminal.test.tsx | 20 +- frontend/src/components/AgentControlBar.tsx | 4 +- frontend/src/components/chat-interface.tsx | 4 +- .../components/chat/ConfirmationButtons.tsx | 4 +- frontend/src/components/event-handler.tsx | 188 +++++++++++ .../project-menu/ProjectMenuCard.tsx | 4 +- frontend/src/context/socket.tsx | 146 --------- frontend/src/context/ws-client-provider.tsx | 175 ++++++++++ frontend/src/entry.client.tsx | 11 +- frontend/src/hooks/useTerminal.ts | 4 +- frontend/src/routes/_oh.app.tsx | 309 ++++-------------- frontend/src/routes/_oh.tsx | 18 +- frontend/src/services/actions.ts | 31 +- frontend/src/services/agentStateService.ts | 9 +- frontend/src/services/chatService.ts | 2 +- frontend/src/services/terminalService.ts | 2 +- frontend/src/utils/verified-models.ts | 1 + frontend/test-utils.tsx | 4 +- 19 files changed, 486 insertions(+), 464 deletions(-) create mode 100644 frontend/src/components/event-handler.tsx delete mode 100644 frontend/src/context/socket.tsx create mode 100644 frontend/src/context/ws-client-provider.tsx diff --git a/frontend/__tests__/components/chat/chat-interface.test.tsx b/frontend/__tests__/components/chat/chat-interface.test.tsx index 71116e9651..fc4c03e3f6 100644 --- a/frontend/__tests__/components/chat/chat-interface.test.tsx +++ b/frontend/__tests__/components/chat/chat-interface.test.tsx @@ -16,14 +16,14 @@ describe("Empty state", () => { send: vi.fn(), })); - const { useSocket: useSocketMock } = vi.hoisted(() => ({ - useSocket: vi.fn(() => ({ send: sendMock, runtimeActive: true })), + const { useWsClient: useWsClientMock } = vi.hoisted(() => ({ + useWsClient: vi.fn(() => ({ send: sendMock, runtimeActive: true })), })); beforeAll(() => { vi.mock("#/context/socket", async (importActual) => ({ - ...(await importActual()), - useSocket: useSocketMock, + ...(await importActual()), + useWsClient: useWsClientMock, })); }); @@ -77,7 +77,7 @@ describe("Empty state", () => { "should load the a user message to the input when selecting", async () => { // this is to test that the message is in the UI before the socket is called - useSocketMock.mockImplementation(() => ({ + useWsClientMock.mockImplementation(() => ({ send: sendMock, runtimeActive: false, // mock an inactive runtime setup })); @@ -106,7 +106,7 @@ describe("Empty state", () => { it.fails( "should send the message to the socket only if the runtime is active", async () => { - useSocketMock.mockImplementation(() => ({ + useWsClientMock.mockImplementation(() => ({ send: sendMock, runtimeActive: false, // mock an inactive runtime setup })); @@ -123,7 +123,7 @@ describe("Empty state", () => { await user.click(displayedSuggestions[0]); expect(sendMock).not.toHaveBeenCalled(); - useSocketMock.mockImplementation(() => ({ + useWsClientMock.mockImplementation(() => ({ send: sendMock, runtimeActive: true, // mock an active runtime setup })); diff --git a/frontend/__tests__/hooks/use-terminal.test.tsx b/frontend/__tests__/hooks/use-terminal.test.tsx index aec9633d3a..7a5d6e9c54 100644 --- a/frontend/__tests__/hooks/use-terminal.test.tsx +++ b/frontend/__tests__/hooks/use-terminal.test.tsx @@ -2,8 +2,9 @@ import { beforeAll, describe, expect, it, vi } from "vitest"; import { render } from "@testing-library/react"; import { afterEach } from "node:test"; import { useTerminal } from "#/hooks/useTerminal"; -import { SocketProvider } from "#/context/socket"; import { Command } from "#/state/commandSlice"; +import { WsClientProvider } from "#/context/ws-client-provider"; +import { ReactNode } from "react"; interface TestTerminalComponentProps { commands: Command[]; @@ -18,6 +19,17 @@ function TestTerminalComponent({ return
; } +interface WrapperProps { + children: ReactNode; +} + + +function Wrapper({children}: WrapperProps) { + return ( + {children} + ) +} + describe("useTerminal", () => { const mockTerminal = vi.hoisted(() => ({ loadAddon: vi.fn(), @@ -50,7 +62,7 @@ describe("useTerminal", () => { it("should render", () => { render(, { - wrapper: SocketProvider, + wrapper: Wrapper, }); }); @@ -61,7 +73,7 @@ describe("useTerminal", () => { ]; render(, { - wrapper: SocketProvider, + wrapper: Wrapper, }); expect(mockTerminal.writeln).toHaveBeenNthCalledWith(1, "echo hello"); @@ -85,7 +97,7 @@ describe("useTerminal", () => { secrets={[secret, anotherSecret]} />, { - wrapper: SocketProvider, + wrapper: Wrapper, }, ); diff --git a/frontend/src/components/AgentControlBar.tsx b/frontend/src/components/AgentControlBar.tsx index 7dfc0e3817..f6bcea8090 100644 --- a/frontend/src/components/AgentControlBar.tsx +++ b/frontend/src/components/AgentControlBar.tsx @@ -6,7 +6,7 @@ import PlayIcon from "#/assets/play"; import { generateAgentStateChangeEvent } from "#/services/agentStateService"; import { RootState } from "#/store"; import AgentState from "#/types/AgentState"; -import { useSocket } from "#/context/socket"; +import { useWsClient } from "#/context/ws-client-provider"; const IgnoreTaskStateMap: Record = { [AgentState.PAUSED]: [ @@ -72,7 +72,7 @@ function ActionButton({ } function AgentControlBar() { - const { send } = useSocket(); + const { send } = useWsClient(); const { curAgentState } = useSelector((state: RootState) => state.agent); const handleAction = (action: AgentState) => { diff --git a/frontend/src/components/chat-interface.tsx b/frontend/src/components/chat-interface.tsx index 54c38e6641..626166f462 100644 --- a/frontend/src/components/chat-interface.tsx +++ b/frontend/src/components/chat-interface.tsx @@ -1,7 +1,6 @@ import { useDispatch, useSelector } from "react-redux"; import React from "react"; import posthog from "posthog-js"; -import { useSocket } from "#/context/socket"; import { convertImageToBase64 } from "#/utils/convert-image-to-base-64"; import { ChatMessage } from "./chat-message"; import { FeedbackActions } from "./feedback-actions"; @@ -22,13 +21,14 @@ import { ScrollToBottomButton } from "./scroll-to-bottom-button"; import { Suggestions } from "./suggestions"; import { SUGGESTIONS } from "#/utils/suggestions"; import BuildIt from "#/icons/build-it.svg?react"; +import { useWsClient } from "#/context/ws-client-provider"; const isErrorMessage = ( message: Message | ErrorMessage, ): message is ErrorMessage => "error" in message; export function ChatInterface() { - const { send } = useSocket(); + const { send } = useWsClient(); const dispatch = useDispatch(); const scrollRef = React.useRef(null); const { scrollDomToBottom, onChatBodyScroll, hitBottom } = diff --git a/frontend/src/components/chat/ConfirmationButtons.tsx b/frontend/src/components/chat/ConfirmationButtons.tsx index fb1f64f14a..c06dd76fe0 100644 --- a/frontend/src/components/chat/ConfirmationButtons.tsx +++ b/frontend/src/components/chat/ConfirmationButtons.tsx @@ -5,7 +5,7 @@ import RejectIcon from "#/assets/reject"; import { I18nKey } from "#/i18n/declaration"; import AgentState from "#/types/AgentState"; import { generateAgentStateChangeEvent } from "#/services/agentStateService"; -import { useSocket } from "#/context/socket"; +import { useWsClient } from "#/context/ws-client-provider"; interface ActionTooltipProps { type: "confirm" | "reject"; @@ -37,7 +37,7 @@ function ActionTooltip({ type, onClick }: ActionTooltipProps) { function ConfirmationButtons() { const { t } = useTranslation(); - const { send } = useSocket(); + const { send } = useWsClient(); const handleStateChange = (state: AgentState) => { const event = generateAgentStateChangeEvent(state); diff --git a/frontend/src/components/event-handler.tsx b/frontend/src/components/event-handler.tsx new file mode 100644 index 0000000000..75eef41164 --- /dev/null +++ b/frontend/src/components/event-handler.tsx @@ -0,0 +1,188 @@ +import React from "react"; +import { + useFetcher, + useLoaderData, + useRouteLoaderData, +} from "@remix-run/react"; +import { useDispatch, useSelector } from "react-redux"; +import toast from "react-hot-toast"; + +import posthog from "posthog-js"; +import { + useWsClient, + WsClientProviderStatus, +} from "#/context/ws-client-provider"; +import { ErrorObservation } from "#/types/core/observations"; +import { addErrorMessage, addUserMessage } from "#/state/chatSlice"; +import { handleAssistantMessage } from "#/services/actions"; +import { + getCloneRepoCommand, + getGitHubTokenCommand, +} from "#/services/terminalService"; +import { + clearFiles, + clearSelectedRepository, + setImportedProjectZip, +} from "#/state/initial-query-slice"; +import { clientLoader as appClientLoader } from "#/routes/_oh.app"; +import store, { RootState } from "#/store"; +import { createChatMessage } from "#/services/chatService"; +import { clientLoader as rootClientLoader } from "#/routes/_oh"; +import { isGitHubErrorReponse } from "#/api/github"; +import OpenHands from "#/api/open-hands"; +import { base64ToBlob } from "#/utils/base64-to-blob"; +import { setCurrentAgentState } from "#/state/agentSlice"; +import AgentState from "#/types/AgentState"; +import { getSettings } from "#/services/settings"; + +interface ServerError { + error: boolean | string; + message: string; + [key: string]: unknown; +} + +const isServerError = (data: object): data is ServerError => "error" in data; + +const isErrorObservation = (data: object): data is ErrorObservation => + "observation" in data && data.observation === "error"; + +export function EventHandler({ children }: React.PropsWithChildren) { + const { events, status, send } = useWsClient(); + const statusRef = React.useRef(null); + const runtimeActive = status === WsClientProviderStatus.ACTIVE; + const fetcher = useFetcher(); + const dispatch = useDispatch(); + const { files, importedProjectZip } = useSelector( + (state: RootState) => state.initalQuery, + ); + const { ghToken, repo } = useLoaderData(); + const initialQueryRef = React.useRef( + store.getState().initalQuery.initialQuery, + ); + + const sendInitialQuery = (query: string, base64Files: string[]) => { + const timestamp = new Date().toISOString(); + send(createChatMessage(query, base64Files, timestamp)); + }; + const data = useRouteLoaderData("routes/_oh"); + const userId = React.useMemo(() => { + if (data?.user && !isGitHubErrorReponse(data.user)) return data.user.id; + return null; + }, [data?.user]); + const userSettings = getSettings(); + + React.useEffect(() => { + if (!events.length) { + return; + } + const event = events[events.length - 1]; + if (event.token) { + fetcher.submit({ token: event.token as string }, { method: "post" }); + return; + } + + if (isServerError(event)) { + if (event.error_code === 401) { + toast.error("Session expired."); + fetcher.submit({}, { method: "POST", action: "/end-session" }); + return; + } + + if (typeof event.error === "string") { + toast.error(event.error); + } else { + toast.error(event.message); + } + return; + } + + if (isErrorObservation(event)) { + dispatch( + addErrorMessage({ + id: event.extras?.error_id, + message: event.message, + }), + ); + return; + } + handleAssistantMessage(event); + }, [events.length]); + + React.useEffect(() => { + if (statusRef.current === status) { + return; // This is a check because of strict mode - if the status did not change, don't do anything + } + statusRef.current = status; + const initialQuery = initialQueryRef.current; + + if (status === WsClientProviderStatus.ACTIVE) { + let additionalInfo = ""; + if (ghToken && repo) { + send(getCloneRepoCommand(ghToken, repo)); + additionalInfo = `Repository ${repo} has been cloned to /workspace. Please check the /workspace for files.`; + dispatch(clearSelectedRepository()); // reset selected repository; maybe better to move this to '/'? + } + // if there's an uploaded project zip, add it to the chat + else if (importedProjectZip) { + additionalInfo = `Files have been uploaded. Please check the /workspace for files.`; + } + + if (initialQuery) { + if (additionalInfo) { + sendInitialQuery(`${initialQuery}\n\n[${additionalInfo}]`, files); + } else { + sendInitialQuery(initialQuery, files); + } + dispatch(clearFiles()); // reset selected files + initialQueryRef.current = null; + } + } + + if (status === WsClientProviderStatus.OPENING && initialQuery) { + dispatch( + addUserMessage({ + content: initialQuery, + imageUrls: files, + timestamp: new Date().toISOString(), + }), + ); + } + + if (status === WsClientProviderStatus.STOPPED) { + store.dispatch(setCurrentAgentState(AgentState.STOPPED)); + } + }, [status]); + + React.useEffect(() => { + if (runtimeActive && userId && ghToken) { + // Export if the user valid, this could happen mid-session so it is handled here + send(getGitHubTokenCommand(ghToken)); + } + }, [userId, ghToken, runtimeActive]); + + React.useEffect(() => { + (async () => { + if (runtimeActive && importedProjectZip) { + // upload files action + try { + const blob = base64ToBlob(importedProjectZip); + const file = new File([blob], "imported-project.zip", { + type: blob.type, + }); + await OpenHands.uploadFiles([file]); + dispatch(setImportedProjectZip(null)); + } catch (error) { + toast.error("Failed to upload project files."); + } + } + })(); + }, [runtimeActive, importedProjectZip]); + + React.useEffect(() => { + if (userSettings.LLM_API_KEY) { + posthog.capture("user_activated"); + } + }, [userSettings.LLM_API_KEY]); + + return children; +} diff --git a/frontend/src/components/project-menu/ProjectMenuCard.tsx b/frontend/src/components/project-menu/ProjectMenuCard.tsx index 923833ca38..1a32c2f802 100644 --- a/frontend/src/components/project-menu/ProjectMenuCard.tsx +++ b/frontend/src/components/project-menu/ProjectMenuCard.tsx @@ -6,13 +6,13 @@ import EllipsisH from "#/icons/ellipsis-h.svg?react"; import { ModalBackdrop } from "../modals/modal-backdrop"; import { ConnectToGitHubModal } from "../modals/connect-to-github-modal"; import { addUserMessage } from "#/state/chatSlice"; -import { useSocket } from "#/context/socket"; import { createChatMessage } from "#/services/chatService"; import { ProjectMenuCardContextMenu } from "./project.menu-card-context-menu"; import { ProjectMenuDetailsPlaceholder } from "./project-menu-details-placeholder"; import { ProjectMenuDetails } from "./project-menu-details"; import { downloadWorkspace } from "#/utils/download-workspace"; import { LoadingSpinner } from "../modals/LoadingProject"; +import { useWsClient } from "#/context/ws-client-provider"; interface ProjectMenuCardProps { isConnectedToGitHub: boolean; @@ -27,7 +27,7 @@ export function ProjectMenuCard({ isConnectedToGitHub, githubData, }: ProjectMenuCardProps) { - const { send } = useSocket(); + const { send } = useWsClient(); const dispatch = useDispatch(); const [contextMenuIsOpen, setContextMenuIsOpen] = React.useState(false); diff --git a/frontend/src/context/socket.tsx b/frontend/src/context/socket.tsx deleted file mode 100644 index 7bf1ab1d57..0000000000 --- a/frontend/src/context/socket.tsx +++ /dev/null @@ -1,146 +0,0 @@ -import React from "react"; -import { Data } from "ws"; -import posthog from "posthog-js"; -import EventLogger from "#/utils/event-logger"; - -interface WebSocketClientOptions { - token: string | null; - onOpen?: (event: Event) => void; - onMessage?: (event: MessageEvent) => void; - onError?: (event: Event) => void; - onClose?: (event: Event) => void; -} - -interface WebSocketContextType { - send: (data: string | ArrayBufferLike | Blob | ArrayBufferView) => void; - start: (options?: WebSocketClientOptions) => void; - stop: () => void; - setRuntimeIsInitialized: () => void; - runtimeActive: boolean; - isConnected: boolean; - events: Record[]; -} - -const SocketContext = React.createContext( - undefined, -); - -interface SocketProviderProps { - children: React.ReactNode; -} - -function SocketProvider({ children }: SocketProviderProps) { - const wsRef = React.useRef(null); - const [isConnected, setIsConnected] = React.useState(false); - const [runtimeActive, setRuntimeActive] = React.useState(false); - const [events, setEvents] = React.useState[]>([]); - - const setRuntimeIsInitialized = () => { - setRuntimeActive(true); - }; - - const start = React.useCallback((options?: WebSocketClientOptions): void => { - if (wsRef.current) { - EventLogger.warning( - "WebSocket connection is already established, but a new one is starting anyways.", - ); - } - - const baseUrl = - import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host; - const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; - const sessionToken = options?.token || "NO_JWT"; // not allowed to be empty or duplicated - const ghToken = localStorage.getItem("ghToken") || "NO_GITHUB"; - - const ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [ - "openhands", - sessionToken, - ghToken, - ]); - - ws.addEventListener("open", (event) => { - posthog.capture("socket_opened"); - setIsConnected(true); - options?.onOpen?.(event); - }); - - ws.addEventListener("message", (event) => { - EventLogger.message(event); - - setEvents((prevEvents) => [...prevEvents, JSON.parse(event.data)]); - options?.onMessage?.(event); - }); - - ws.addEventListener("error", (event) => { - posthog.capture("socket_error"); - EventLogger.event(event, "SOCKET ERROR"); - options?.onError?.(event); - }); - - ws.addEventListener("close", (event) => { - posthog.capture("socket_closed"); - EventLogger.event(event, "SOCKET CLOSE"); - - setIsConnected(false); - setRuntimeActive(false); - wsRef.current = null; - options?.onClose?.(event); - }); - - wsRef.current = ws; - }, []); - - const stop = React.useCallback((): void => { - if (wsRef.current) { - wsRef.current.close(); - wsRef.current = null; - } - }, []); - - const send = React.useCallback( - (data: string | ArrayBufferLike | Blob | ArrayBufferView) => { - if (!wsRef.current) { - EventLogger.error("WebSocket is not connected."); - return; - } - setEvents((prevEvents) => [...prevEvents, JSON.parse(data.toString())]); - wsRef.current.send(data); - }, - [], - ); - - const value = React.useMemo( - () => ({ - send, - start, - stop, - setRuntimeIsInitialized, - runtimeActive, - isConnected, - events, - }), - [ - send, - start, - stop, - setRuntimeIsInitialized, - runtimeActive, - isConnected, - events, - ], - ); - - return ( - {children} - ); -} - -function useSocket() { - const context = React.useContext(SocketContext); - if (context === undefined) { - throw new Error("useSocket must be used within a SocketProvider"); - } - return context; -} - -export { SocketProvider, useSocket }; diff --git a/frontend/src/context/ws-client-provider.tsx b/frontend/src/context/ws-client-provider.tsx new file mode 100644 index 0000000000..bfecefadd0 --- /dev/null +++ b/frontend/src/context/ws-client-provider.tsx @@ -0,0 +1,175 @@ +import posthog from "posthog-js"; +import React from "react"; +import { Settings } from "#/services/settings"; +import ActionType from "#/types/ActionType"; +import EventLogger from "#/utils/event-logger"; +import AgentState from "#/types/AgentState"; + +export enum WsClientProviderStatus { + STOPPED, + OPENING, + ACTIVE, + ERROR, +} + +interface UseWsClient { + status: WsClientProviderStatus; + events: Record[]; + send: (event: Record) => void; +} + +const WsClientContext = React.createContext({ + status: WsClientProviderStatus.STOPPED, + events: [], + send: () => { + throw new Error("not connected"); + }, +}); + +interface WsClientProviderProps { + enabled: boolean; + token: string | null; + ghToken: string | null; + settings: Settings | null; +} + +export function WsClientProvider({ + enabled, + token, + ghToken, + settings, + children, +}: React.PropsWithChildren) { + const wsRef = React.useRef(null); + const tokenRef = React.useRef(token); + const ghTokenRef = React.useRef(ghToken); + const closeRef = React.useRef | null>(null); + const [status, setStatus] = React.useState(WsClientProviderStatus.STOPPED); + const [events, setEvents] = React.useState[]>([]); + + function send(event: Record) { + if (!wsRef.current) { + EventLogger.error("WebSocket is not connected."); + return; + } + wsRef.current.send(JSON.stringify(event)); + } + + function handleOpen() { + setStatus(WsClientProviderStatus.OPENING); + const initEvent = { + action: ActionType.INIT, + args: settings, + }; + send(initEvent); + } + + function handleMessage(messageEvent: MessageEvent) { + const event = JSON.parse(messageEvent.data); + setEvents((prevEvents) => [...prevEvents, event]); + if (event.extras?.agent_state === AgentState.INIT) { + setStatus(WsClientProviderStatus.ACTIVE); + } + if ( + status !== WsClientProviderStatus.ACTIVE && + event?.observation === "error" + ) { + setStatus(WsClientProviderStatus.ERROR); + } + } + + function handleClose() { + setStatus(WsClientProviderStatus.STOPPED); + setEvents([]); + wsRef.current = null; + } + + function handleError(event: Event) { + posthog.capture("socket_error"); + EventLogger.event(event, "SOCKET ERROR"); + setStatus(WsClientProviderStatus.ERROR); + } + + // Connect websocket + React.useEffect(() => { + let ws = wsRef.current; + + // If disabled close any existing websockets... + if (!enabled) { + if (ws) { + ws.close(); + } + wsRef.current = null; + return () => {}; + } + + // If there is no websocket or the tokens have changed or the current websocket is closed, + // create a new one + if ( + !ws || + (tokenRef.current && token !== tokenRef.current) || + ghToken !== ghTokenRef.current || + ws.readyState === WebSocket.CLOSED || + ws.readyState === WebSocket.CLOSING + ) { + ws?.close(); + const baseUrl = + import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host; + const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; + ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [ + "openhands", + token || "NO_JWT", + ghToken || "NO_GITHUB", + ]); + } + ws.addEventListener("open", handleOpen); + ws.addEventListener("message", handleMessage); + ws.addEventListener("error", handleError); + ws.addEventListener("close", handleClose); + wsRef.current = ws; + tokenRef.current = token; + ghTokenRef.current = ghToken; + + return () => { + ws.removeEventListener("open", handleOpen); + ws.removeEventListener("message", handleMessage); + ws.removeEventListener("error", handleError); + ws.removeEventListener("close", handleClose); + }; + }, [enabled, token, ghToken]); + + // Strict mode mounts and unmounts each component twice, so we have to wait in the destructor + // before actually closing the socket and cancel the operation if the component gets remounted. + React.useEffect(() => { + const timeout = closeRef.current; + if (timeout != null) { + clearTimeout(timeout); + } + + return () => { + closeRef.current = setTimeout(() => { + wsRef.current?.close(); + }, 100); + }; + }, []); + + const value = React.useMemo( + () => ({ + status, + events, + send, + }), + [status, events], + ); + + return ( + + {children} + + ); +} + +export function useWsClient() { + const context = React.useContext(WsClientContext); + return context; +} diff --git a/frontend/src/entry.client.tsx b/frontend/src/entry.client.tsx index 8a6d4fac2d..4fe347f703 100644 --- a/frontend/src/entry.client.tsx +++ b/frontend/src/entry.client.tsx @@ -10,7 +10,6 @@ import React, { startTransition, StrictMode } from "react"; import { hydrateRoot } from "react-dom/client"; import { Provider } from "react-redux"; import posthog from "posthog-js"; -import { SocketProvider } from "./context/socket"; import "./i18n"; import store from "./store"; @@ -43,12 +42,10 @@ prepareApp().then(() => hydrateRoot( document, - - - - - - + + + + , ); }), diff --git a/frontend/src/hooks/useTerminal.ts b/frontend/src/hooks/useTerminal.ts index b45618eeb1..1409fdb7c4 100644 --- a/frontend/src/hooks/useTerminal.ts +++ b/frontend/src/hooks/useTerminal.ts @@ -4,7 +4,7 @@ import React from "react"; import { Command } from "#/state/commandSlice"; import { getTerminalCommand } from "#/services/terminalService"; import { parseTerminalOutput } from "#/utils/parseTerminalOutput"; -import { useSocket } from "#/context/socket"; +import { useWsClient } from "#/context/ws-client-provider"; /* NOTE: Tests for this hook are indirectly covered by the tests for the XTermTerminal component. @@ -15,7 +15,7 @@ export const useTerminal = ( commands: Command[] = [], secrets: string[] = [], ) => { - const { send } = useSocket(); + const { send } = useWsClient(); const terminal = React.useRef(null); const fitAddon = React.useRef(null); const ref = React.useRef(null); diff --git a/frontend/src/routes/_oh.app.tsx b/frontend/src/routes/_oh.app.tsx index c072e55bca..8bdfe10835 100644 --- a/frontend/src/routes/_oh.app.tsx +++ b/frontend/src/routes/_oh.app.tsx @@ -2,72 +2,29 @@ import { useDisclosure } from "@nextui-org/react"; import React from "react"; import { Outlet, - useFetcher, useLoaderData, json, ClientActionFunctionArgs, - useRouteLoaderData, } from "@remix-run/react"; -import { useDispatch, useSelector } from "react-redux"; -import WebSocket from "ws"; -import toast from "react-hot-toast"; -import posthog from "posthog-js"; +import { useDispatch } from "react-redux"; import { getSettings } from "#/services/settings"; import Security from "../components/modals/security/Security"; import { Controls } from "#/components/controls"; -import store, { RootState } from "#/store"; +import store from "#/store"; import { Container } from "#/components/container"; -import ActionType from "#/types/ActionType"; -import { handleAssistantMessage } from "#/services/actions"; -import { - addErrorMessage, - addUserMessage, - clearMessages, -} from "#/state/chatSlice"; -import { useSocket } from "#/context/socket"; -import { - getGitHubTokenCommand, - getCloneRepoCommand, -} from "#/services/terminalService"; +import { clearMessages } from "#/state/chatSlice"; import { clearTerminal } from "#/state/commandSlice"; import { useEffectOnce } from "#/utils/use-effect-once"; import CodeIcon from "#/icons/code.svg?react"; import GlobeIcon from "#/icons/globe.svg?react"; import ListIcon from "#/icons/list-type-number.svg?react"; -import { createChatMessage } from "#/services/chatService"; -import { - clearFiles, - clearInitialQuery, - clearSelectedRepository, - setImportedProjectZip, -} from "#/state/initial-query-slice"; +import { clearInitialQuery } from "#/state/initial-query-slice"; import { isGitHubErrorReponse, retrieveLatestGitHubCommit } from "#/api/github"; -import OpenHands from "#/api/open-hands"; -import AgentState from "#/types/AgentState"; -import { base64ToBlob } from "#/utils/base64-to-blob"; -import { clientLoader as rootClientLoader } from "#/routes/_oh"; import { clearJupyter } from "#/state/jupyterSlice"; import { FilesProvider } from "#/context/files"; -import { ErrorObservation } from "#/types/core/observations"; import { ChatInterface } from "#/components/chat-interface"; - -interface ServerError { - error: boolean | string; - message: string; - [key: string]: unknown; -} - -const isServerError = (data: object): data is ServerError => "error" in data; - -const isErrorObservation = (data: object): data is ErrorObservation => - "observation" in data && data.observation === "error"; - -const isAgentStateChange = ( - data: object, -): data is { extras: { agent_state: AgentState } } => - "extras" in data && - data.extras instanceof Object && - "agent_state" in data.extras; +import { WsClientProvider } from "#/context/ws-client-provider"; +import { EventHandler } from "#/components/event-handler"; export const clientLoader = async () => { const ghToken = localStorage.getItem("ghToken"); @@ -117,179 +74,26 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => { function App() { const dispatch = useDispatch(); - const { files, importedProjectZip } = useSelector( - (state: RootState) => state.initalQuery, - ); - const { start, send, setRuntimeIsInitialized, runtimeActive } = useSocket(); - const { settings, token, ghToken, repo, q, lastCommit } = + const { settings, token, ghToken, lastCommit } = useLoaderData(); - const fetcher = useFetcher(); - const data = useRouteLoaderData("routes/_oh"); const secrets = React.useMemo( () => [ghToken, token].filter((secret) => secret !== null), [ghToken, token], ); - // To avoid re-rendering the component when the user object changes, we memoize the user ID. - // We use this to ensure the github token is valid before exporting it to the terminal. - const userId = React.useMemo(() => { - if (data?.user && !isGitHubErrorReponse(data.user)) return data.user.id; - return null; - }, [data?.user]); - const Terminal = React.useMemo( () => React.lazy(() => import("../components/terminal/Terminal")), [], ); - const addIntialQueryToChat = ( - query: string, - base64Files: string[], - timestamp = new Date().toISOString(), - ) => { - dispatch( - addUserMessage({ - content: query, - imageUrls: base64Files, - timestamp, - }), - ); - }; - - const sendInitialQuery = (query: string, base64Files: string[]) => { - const timestamp = new Date().toISOString(); - send(createChatMessage(query, base64Files, timestamp)); - - const userSettings = getSettings(); - if (userSettings.LLM_API_KEY) { - posthog.capture("user_activated"); - } - }; - - const handleOpen = React.useCallback(() => { - const initEvent = { - action: ActionType.INIT, - args: settings, - }; - send(JSON.stringify(initEvent)); - - // display query in UI, but don't send it to the server - if (q) addIntialQueryToChat(q, files); - }, [settings]); - - const handleMessage = React.useCallback( - (message: MessageEvent) => { - // set token received from the server - const parsed = JSON.parse(message.data.toString()); - if ("token" in parsed) { - fetcher.submit({ token: parsed.token }, { method: "post" }); - return; - } - - if (isServerError(parsed)) { - if (parsed.error_code === 401) { - toast.error("Session expired."); - fetcher.submit({}, { method: "POST", action: "/end-session" }); - return; - } - - if (typeof parsed.error === "string") { - toast.error(parsed.error); - } else { - toast.error(parsed.message); - } - - return; - } - if (isErrorObservation(parsed)) { - dispatch( - addErrorMessage({ - id: parsed.extras?.error_id, - message: parsed.message, - }), - ); - return; - } - - handleAssistantMessage(message.data.toString()); - - // handle first time connection - if ( - isAgentStateChange(parsed) && - parsed.extras.agent_state === AgentState.INIT - ) { - setRuntimeIsInitialized(); - - // handle new session - if (!token) { - let additionalInfo = ""; - if (ghToken && repo) { - send(getCloneRepoCommand(ghToken, repo)); - additionalInfo = `Repository ${repo} has been cloned to /workspace. Please check the /workspace for files.`; - dispatch(clearSelectedRepository()); // reset selected repository; maybe better to move this to '/'? - } - // if there's an uploaded project zip, add it to the chat - else if (importedProjectZip) { - additionalInfo = `Files have been uploaded. Please check the /workspace for files.`; - } - - if (q) { - if (additionalInfo) { - sendInitialQuery(`${q}\n\n[${additionalInfo}]`, files); - } else { - sendInitialQuery(q, files); - } - dispatch(clearFiles()); // reset selected files - } - } - } - }, - [token, ghToken, repo, q, files], - ); - - const startSocketConnection = React.useCallback(() => { - start({ - token, - onOpen: handleOpen, - onMessage: handleMessage, - }); - }, [token, handleOpen, handleMessage]); - useEffectOnce(() => { - // clear and restart the socket connection dispatch(clearMessages()); dispatch(clearTerminal()); dispatch(clearJupyter()); dispatch(clearInitialQuery()); // Clear initial query when navigating to /app - startSocketConnection(); }); - React.useEffect(() => { - if (runtimeActive && userId && ghToken) { - // Export if the user valid, this could happen mid-session so it is handled here - send(getGitHubTokenCommand(ghToken)); - } - }, [userId, ghToken, runtimeActive]); - - React.useEffect(() => { - (async () => { - if (runtimeActive && importedProjectZip) { - // upload files action - try { - const blob = base64ToBlob(importedProjectZip); - const file = new File([blob], "imported-project.zip", { - type: blob.type, - }); - await OpenHands.uploadFiles([file]); - dispatch(setImportedProjectZip(null)); - } catch (error) { - toast.error("Failed to upload project files."); - } - } - })(); - }, [runtimeActive, importedProjectZip]); - const { isOpen: securityModalIsOpen, onOpen: onSecurityModalOpen, @@ -297,53 +101,62 @@ function App() { } = useDisclosure(); return ( -
-
- - - + + +
+
+ + + -
- }, - { label: "Jupyter", to: "jupyter", icon: }, - { - label: "Browser", - to: "browser", - icon: , - isBeta: true, - }, - ]} - > - - - - - {/* Terminal uses some API that is not compatible in a server-environment. For this reason, we lazy load it to ensure - * that it loads only in the client-side. */} - - }> - - - +
+ }, + { label: "Jupyter", to: "jupyter", icon: }, + { + label: "Browser", + to: "browser", + icon: , + isBeta: true, + }, + ]} + > + + + + + {/* Terminal uses some API that is not compatible in a server-environment. For this reason, we lazy load it to ensure + * that it loads only in the client-side. */} + + }> + + + +
+
+ +
+ +
+
-
- -
- -
- -
+ + ); } diff --git a/frontend/src/routes/_oh.tsx b/frontend/src/routes/_oh.tsx index e7f7036882..585bbe2437 100644 --- a/frontend/src/routes/_oh.tsx +++ b/frontend/src/routes/_oh.tsx @@ -21,7 +21,6 @@ import { DangerModal } from "#/components/modals/confirmation-modals/danger-moda import { LoadingSpinner } from "#/components/modals/LoadingProject"; import { ModalBackdrop } from "#/components/modals/modal-backdrop"; import { UserActions } from "#/components/user-actions"; -import { useSocket } from "#/context/socket"; import i18n from "#/i18n"; import { getSettings, settingsAreUpToDate } from "#/services/settings"; import AllHandsLogo from "#/assets/branding/all-hands-logo.svg?react"; @@ -135,7 +134,6 @@ type SettingsFormData = { }; export default function MainApp() { - const { stop, isConnected } = useSocket(); const navigation = useNavigation(); const location = useLocation(); const { @@ -202,14 +200,6 @@ export default function MainApp() { } }, [user]); - React.useEffect(() => { - if (location.pathname === "/") { - // If the user is on the home page, we should stop the socket connection. - // This is relevant when the user redirects here for whatever reason. - if (isConnected) stop(); - } - }, [location.pathname]); - const handleUserLogout = () => { logoutFetcher.submit( {}, @@ -313,11 +303,9 @@ export default function MainApp() {

To continue, connect an OpenAI, Anthropic, or other LLM account

- {isConnected && ( -

- Changing settings during an active session will end the session -

- )} +

+ Changing settings during an active session will end the session +

) { + if (message.action) { + handleActionMessage(message as unknown as ActionMessage); + } else if (message.observation) { + handleObservationMessage(message as unknown as ObservationMessage); + } else if (message.status_update) { + handleStatusMessage(message as unknown as StatusMessage); } else { - socketMessage = data; - } - - if ("action" in socketMessage) { - handleActionMessage(socketMessage); - } else if ("observation" in socketMessage) { - handleObservationMessage(socketMessage); - } else if ("status_update" in socketMessage) { - handleStatusMessage(socketMessage); - } else { - console.error("Unknown message type", socketMessage); + console.error("Unknown message type", message); } } diff --git a/frontend/src/services/agentStateService.ts b/frontend/src/services/agentStateService.ts index ac194a8b10..c07afdb12d 100644 --- a/frontend/src/services/agentStateService.ts +++ b/frontend/src/services/agentStateService.ts @@ -1,8 +1,7 @@ import ActionType from "#/types/ActionType"; import AgentState from "#/types/AgentState"; -export const generateAgentStateChangeEvent = (state: AgentState) => - JSON.stringify({ - action: ActionType.CHANGE_AGENT_STATE, - args: { agent_state: state }, - }); +export const generateAgentStateChangeEvent = (state: AgentState) => ({ + action: ActionType.CHANGE_AGENT_STATE, + args: { agent_state: state }, +}); diff --git a/frontend/src/services/chatService.ts b/frontend/src/services/chatService.ts index a623e776c1..a19240f812 100644 --- a/frontend/src/services/chatService.ts +++ b/frontend/src/services/chatService.ts @@ -9,5 +9,5 @@ export function createChatMessage( action: ActionType.MESSAGE, args: { content: message, images_urls, timestamp }, }; - return JSON.stringify(event); + return event; } diff --git a/frontend/src/services/terminalService.ts b/frontend/src/services/terminalService.ts index 121f96bc36..0af45c3747 100644 --- a/frontend/src/services/terminalService.ts +++ b/frontend/src/services/terminalService.ts @@ -2,7 +2,7 @@ import ActionType from "#/types/ActionType"; export function getTerminalCommand(command: string, hidden: boolean = false) { const event = { action: ActionType.RUN, args: { command, hidden } }; - return JSON.stringify(event); + return event; } export function getGitHubTokenCommand(gitHubToken: string) { diff --git a/frontend/src/utils/verified-models.ts b/frontend/src/utils/verified-models.ts index f2643a526a..885bd7ac7e 100644 --- a/frontend/src/utils/verified-models.ts +++ b/frontend/src/utils/verified-models.ts @@ -20,6 +20,7 @@ export const VERIFIED_ANTHROPIC_MODELS = [ "claude-2", "claude-2.1", "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", "claude-3-haiku-20240307", "claude-3-opus-20240229", "claude-3-sonnet-20240229", diff --git a/frontend/test-utils.tsx b/frontend/test-utils.tsx index b88ee1063b..9558019bc1 100644 --- a/frontend/test-utils.tsx +++ b/frontend/test-utils.tsx @@ -6,7 +6,7 @@ import { configureStore } from "@reduxjs/toolkit"; // eslint-disable-next-line import/no-extraneous-dependencies import { RenderOptions, render } from "@testing-library/react"; import { AppStore, RootState, rootReducer } from "./src/store"; -import { SocketProvider } from "#/context/socket"; +import { WsClientProvider } from "#/context/ws-client-provider"; const setupStore = (preloadedState?: Partial): AppStore => configureStore({ @@ -35,7 +35,7 @@ export function renderWithProviders( function Wrapper({ children }: PropsWithChildren): JSX.Element { return ( - {children} + {children} ); }