feat: websocket connection management and sandbox bound to session. (#559)

* feat: websocket connection management and sandbox bound to session.

* fix: set default value to id

* feat: add session management.

* fix for mypy

* fix for mypy

* fix the pnpm-lock.

* fix the default model is empty will throw error.
This commit is contained in:
Leo 2024-04-06 01:19:52 +08:00 committed by GitHub
parent fe9815d57b
commit adbcfefd8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 1131 additions and 274 deletions

1
.gitignore vendored
View File

@ -198,6 +198,7 @@ logs
.envrc
/workspace
/debug
cache
# configuration
config.toml

View File

@ -12974,4 +12974,4 @@
"dev": true
}
}
}
}

View File

@ -23,6 +23,7 @@
"@xterm/xterm": "^5.4.0",
"eslint-config-airbnb-typescript": "^18.0.0",
"framer-motion": "^11.0.24",
"jose": "^5.2.3",
"i18next": "^23.10.1",
"i18next-browser-languagedetector": "^7.2.1",
"i18next-http-backend": "^2.5.0",

View File

@ -53,6 +53,9 @@ dependencies:
framer-motion:
specifier: ^11.0.24
version: 11.0.24(react-dom@18.2.0)(react@18.2.0)
jose:
specifier: ^5.2.3
version: 5.2.3
i18next:
specifier: ^23.10.1
version: 23.10.1
@ -6479,6 +6482,10 @@ packages:
resolution: {integrity: sha512-gFqAIbuKyyso/3G2qhiO2OM6shY6EPP/R0+mkDbyspxKazh8BXDC5FiFsUjlczgdNz/vfra0da2y+aHrusLG/Q==}
hasBin: true
/jose@5.2.3:
resolution: {integrity: sha512-KUXdbctm1uHVL8BYhnyHkgp3zDX5KW8ZhAKVFEfUbU2P8Alpzjb+48hHvjOdQIyPshoblhzsuqOwEEAbtHVirA==}
dev: false
/js-tokens@4.0.0:
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}

View File

@ -1,11 +1,17 @@
import React, { useState } from "react";
import React, { useEffect, useState } from "react";
import "./App.css";
import { useSelector } from "react-redux";
import CogTooth from "./assets/cog-tooth";
import ChatInterface from "./components/ChatInterface";
import Errors from "./components/Errors";
import SettingModal from "./components/SettingModal";
import Terminal from "./components/Terminal";
import Workspace from "./components/Workspace";
import store, { RootState } from "./store";
import { setInitialized } from "./state/globalSlice";
import { fetchMsgTotal } from "./services/session";
import LoadMessageModal from "./components/LoadMessageModal";
import { ResFetchMsgTotal } from "./types/ResponseType";
interface Props {
setSettingOpen: (isOpen: boolean) => void;
@ -25,7 +31,22 @@ function LeftNav({ setSettingOpen }: Props): JSX.Element {
}
function App(): JSX.Element {
const { initialized } = useSelector((state: RootState) => state.global);
const [settingOpen, setSettingOpen] = useState(false);
const [loadMsgWarning, setLoadMsgWarning] = useState(false);
useEffect(() => {
if (!initialized) {
fetchMsgTotal()
.then((data: ResFetchMsgTotal) => {
if (data.msg_total > 0) {
setLoadMsgWarning(true);
}
store.dispatch(setInitialized(true));
})
.catch();
}
}, []);
const handleCloseModal = () => {
setSettingOpen(false);
@ -48,6 +69,11 @@ function App(): JSX.Element {
</div>
</div>
<SettingModal isOpen={settingOpen} onClose={handleCloseModal} />
<LoadMessageModal
isOpen={loadMsgWarning}
onClose={() => setLoadMsgWarning(false)}
/>
<Errors />
</div>
);

View File

@ -10,7 +10,7 @@ import {
addAssistantMessageToChat,
setCurrentQueueMarkerState,
setCurrentTypingMsgState,
setTypingAcitve,
setTypingActive,
} from "../services/chatService";
import { Message } from "../state/chatSlice";
import { RootState } from "../store";
@ -34,7 +34,7 @@ function TypingChat() {
const messageContent = useTypingEffect([currentTypingMessage], {
loop: false,
setTypingAcitve,
setTypingActive,
setCurrentQueueMarkerState,
currentQueueMarker,
playbackRate: 0.1,
@ -117,7 +117,7 @@ function MessageList(): JSX.Element {
useEffect(() => {
if (currentTypingMessage === "") return;
if (!typingActive) setTypingAcitve(true);
if (!typingActive) setTypingActive(true);
}, [currentTypingMessage]);
useEffect(() => {

View File

@ -1,15 +1,14 @@
import React from "react";
import { useSelector } from "react-redux";
import { RootState } from "../store";
import "./css/Errors.css";
function Errors(): JSX.Element {
const errors = useSelector((state: RootState) => state.errors.errors);
return (
<div className="errors">
<div className="fixed left-1/2 transform -translate-x-1/2 top-4 z-50">
{errors.map((error, index) => (
<div key={index} className="error">
<div key={index} className="bg-red-800 p-4 rounded-md shadow-md mb-2">
ERROR: {error}
</div>
))}

View File

@ -0,0 +1,78 @@
import React from "react";
import {
Modal,
ModalContent,
ModalHeader,
ModalBody,
ModalFooter,
Button,
} from "@nextui-org/react";
import { fetchMsgs, clearMsgs } from "../services/session";
import { sendChatMessageFromEvent } from "../services/chatService";
import { handleAssistantMessage } from "../services/actions";
import { ResFetchMsg } from "../types/ResponseType";
interface Props {
isOpen: boolean;
onClose: () => void;
}
function LoadMessageModal({ isOpen, onClose }: Props): JSX.Element {
const handleDelMsg = () => {
clearMsgs().then().catch();
onClose();
};
const handleLoadMsg = () => {
fetchMsgs()
.then((data) => {
if (
data === undefined ||
data.messages === undefined ||
data.messages.length === 0
) {
return;
}
const { messages } = data;
messages.forEach((msg: ResFetchMsg) => {
switch (msg.role) {
case "user":
sendChatMessageFromEvent(msg.payload);
break;
case "assistant":
handleAssistantMessage(msg.payload);
break;
default:
}
});
})
.catch();
onClose();
};
return (
<Modal isOpen={isOpen} onClose={onClose} hideCloseButton backdrop="blur">
<ModalContent>
<>
<ModalHeader className="flex flex-col gap-1">
Unfinished Session Detected
</ModalHeader>
<ModalBody>
You have an unfinished session. Do you want to load it?
</ModalBody>
<ModalFooter>
<Button color="danger" variant="light" onPress={handleDelMsg}>
No, start a new session
</Button>
<Button color="primary" onPress={handleLoadMsg}>
Okay, load it
</Button>
</ModalFooter>
</>
</ModalContent>
</Modal>
);
}
export default LoadMessageModal;

View File

@ -15,23 +15,15 @@ import {
} from "@nextui-org/react";
import { KeyboardEvent } from "@react-types/shared/src/events";
import { useTranslation } from "react-i18next";
import i18next from "i18next";
import {
INITIAL_AGENTS,
fetchModels,
fetchAgents,
INITIAL_MODELS,
sendSettings,
saveSettings,
getInitialModel,
} from "../services/settingsService";
import {
setModel,
setAgent,
setWorkspaceDirectory,
setLanguage,
} from "../state/settingsSlice";
import store, { RootState } from "../store";
import socket from "../socket/socket";
import { RootState } from "../store";
import { I18nKey } from "../i18n/declaration";
import { AvailableLanguages } from "../i18n";
@ -48,13 +40,22 @@ const cachedAgents = JSON.parse(
);
function SettingModal({ isOpen, onClose }: Props): JSX.Element {
const { t } = useTranslation();
const model = useSelector((state: RootState) => state.settings.model);
const agent = useSelector((state: RootState) => state.settings.agent);
const workspaceDirectory = useSelector(
const defModel = useSelector((state: RootState) => state.settings.model);
const [model, setModel] = useState(defModel);
const defAgent = useSelector((state: RootState) => state.settings.agent);
const [agent, setAgent] = useState(defAgent);
const defWorkspaceDirectory = useSelector(
(state: RootState) => state.settings.workspaceDirectory,
);
const language = useSelector((state: RootState) => state.settings.language);
const [workspaceDirectory, setWorkspaceDirectory] = useState(
defWorkspaceDirectory,
);
const defLanguage = useSelector(
(state: RootState) => state.settings.language,
);
const [language, setLanguage] = useState(defLanguage);
const { t } = useTranslation();
const [supportedModels, setSupportedModels] = useState(
cachedModels.length > 0 ? cachedModels : INITIAL_MODELS,
@ -64,11 +65,11 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
);
useEffect(() => {
async function setInitialModel() {
const initialModel = await getInitialModel();
store.dispatch(setModel(initialModel));
}
setInitialModel();
getInitialModel()
.then((initialModel) => {
setModel(initialModel);
})
.catch();
fetchModels().then((fetchedModels) => {
setSupportedModels(fetchedModels);
@ -81,24 +82,12 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
}, []);
const handleSaveCfg = () => {
const previousModel = localStorage.getItem("model");
const previousWorkspaceDirectory =
localStorage.getItem("workspaceDirectory");
const previousAgent = localStorage.getItem("agent");
if (
model !== previousModel ||
agent !== previousAgent ||
workspaceDirectory !== previousWorkspaceDirectory
) {
sendSettings(socket, { model, agent, workspaceDirectory, language });
}
localStorage.setItem("model", model);
localStorage.setItem("workspaceDirectory", workspaceDirectory);
localStorage.setItem("agent", agent);
localStorage.setItem("language", language);
i18next.changeLanguage(language);
saveSettings(
{ model, agent, workspaceDirectory, language },
model !== defModel &&
agent !== defAgent &&
workspaceDirectory !== defWorkspaceDirectory,
);
onClose();
};
@ -122,9 +111,7 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
placeholder={t(
I18nKey.CONFIGURATION$OPENDEVIN_WORKSPACE_DIRECTORY_INPUT_PLACEHOLDER,
)}
onChange={(e) =>
store.dispatch(setWorkspaceDirectory(e.target.value))
}
onChange={(e) => setWorkspaceDirectory(e.target.value)}
/>
<Autocomplete
@ -136,7 +123,7 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
placeholder={t(I18nKey.CONFIGURATION$MODEL_SELECT_PLACEHOLDER)}
selectedKey={model}
onSelectionChange={(key) => {
store.dispatch(setModel(key as string));
setModel(key as string);
}}
onKeyDown={(e: KeyboardEvent) => e.continuePropagation()}
defaultFilter={customFilter}
@ -157,7 +144,7 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
placeholder={t(I18nKey.CONFIGURATION$AGENT_SELECT_PLACEHOLDER)}
defaultSelectedKey={agent}
onSelectionChange={(key) => {
store.dispatch(setAgent(key as string));
setAgent(key as string);
}}
onKeyDown={(e: KeyboardEvent) => e.continuePropagation()}
defaultFilter={customFilter}
@ -170,9 +157,7 @@ function SettingModal({ isOpen, onClose }: Props): JSX.Element {
</Autocomplete>
<Select
selectionMode="single"
onChange={(e) => {
store.dispatch(setLanguage(e.target.value));
}}
onChange={(e) => setLanguage(e.target.value)}
selectedKeys={[language]}
label={t(I18nKey.CONFIGURATION$LANGUAGE_SELECT_LABEL)}
>

View File

@ -1,16 +1,15 @@
import React, { useEffect, useRef } from "react";
import { IDisposable, Terminal as XtermTerminal } from "@xterm/xterm";
import "@xterm/xterm/css/xterm.css";
import React, { useEffect, useRef } from "react";
import { useSelector } from "react-redux";
import { FitAddon } from "xterm-addon-fit";
import socket from "../socket/socket";
import Socket from "../services/socket";
import { RootState } from "../store";
class JsonWebsocketAddon {
_socket: WebSocket;
_disposables: IDisposable[];
constructor(_socket: WebSocket) {
this._socket = _socket;
constructor() {
this._disposables = [];
}
@ -18,10 +17,10 @@ class JsonWebsocketAddon {
this._disposables.push(
terminal.onData((data) => {
const payload = JSON.stringify({ action: "terminal", data });
this._socket.send(payload);
Socket.send(payload);
}),
);
this._socket.addEventListener("message", (event) => {
Socket.addEventListener("message", (event) => {
const { action, args, observation, content } = JSON.parse(event.data);
if (action === "run") {
terminal.writeln(args.command);
@ -37,7 +36,7 @@ class JsonWebsocketAddon {
dispose() {
this._disposables.forEach((d) => d.dispose());
this._socket.removeEventListener("message", () => {});
Socket.removeEventListener("message", () => {});
}
}
@ -48,6 +47,7 @@ class JsonWebsocketAddon {
function Terminal(): JSX.Element {
const terminalRef = useRef<HTMLDivElement>(null);
const { commands } = useSelector((state: RootState) => state.cmd);
useEffect(() => {
const bgColor = getComputedStyle(document.documentElement)
@ -80,13 +80,25 @@ function Terminal(): JSX.Element {
fitAddon.fit();
}, 1);
const jsonWebsocketAddon = new JsonWebsocketAddon(socket);
const jsonWebsocketAddon = new JsonWebsocketAddon();
terminal.loadAddon(jsonWebsocketAddon);
// FIXME, temporary solution to display the terminal,
// but it will rerender the terminal every time the commands change
commands.forEach((command) => {
if (command.type === "input") {
terminal.writeln(command.content);
} else {
command.content.split("\n").forEach((line: string) => {
terminal.writeln(line);
});
terminal.write("\n$ ");
}
});
return () => {
terminal.dispose();
};
}, []);
}, [commands]);
return (
<div className="flex flex-col h-full">

View File

@ -1,15 +0,0 @@
.errors {
position: fixed;
left: 50%;
transform: translateX(-50%);
top: 1rem;
z-index: 1000;
}
.error {
background-color: #B00020;
padding: 1rem;
border-radius: 0.5rem;
box-shadow: 1px 1px 5px rgba(0, 0, 0, 0.5);
margin-bottom: 0.5rem;
}

View File

@ -8,7 +8,7 @@ export const useTypingEffect = (
{
loop = false,
playbackRate = 0.1,
setTypingAcitve = () => {},
setTypingActive = () => {},
setCurrentQueueMarkerState = () => {},
currentQueueMarker = 0,
addAssistantMessageToChat = () => {},
@ -16,7 +16,7 @@ export const useTypingEffect = (
}: {
loop?: boolean;
playbackRate?: number;
setTypingAcitve?: (bool: boolean) => void;
setTypingActive?: (bool: boolean) => void;
setCurrentQueueMarkerState?: (marker: number) => void;
currentQueueMarker?: number;
addAssistantMessageToChat?: (msg: Message) => void;
@ -24,7 +24,7 @@ export const useTypingEffect = (
} = {
loop: false,
playbackRate: 0.1,
setTypingAcitve: () => {},
setTypingActive: () => {},
currentQueueMarker: 0,
addAssistantMessageToChat: () => {},
assistantMessageObj: { content: "", sender: "assistant" },
@ -49,7 +49,7 @@ export const useTypingEffect = (
stringIndex++;
if (stringIndex === strings.length) {
if (!loop) {
setTypingAcitve(false);
setTypingActive(false);
setCurrentQueueMarkerState(currentQueueMarker + 1);
addAssistantMessageToChat(assistantMessageObj);
return;

View File

@ -4,15 +4,24 @@ import { setScreenshotSrc, setUrl } from "../state/browserSlice";
import { appendAssistantMessage } from "../state/chatSlice";
import { setCode } from "../state/codeSlice";
import { setInitialized } from "../state/taskSlice";
import { handleObservationMessage } from "./observations";
import { appendInput } from "../state/commandSlice";
import { SocketMessage } from "../types/ResponseType";
let isInitialized = false;
const messageActions = {
initialize: () => {
store.dispatch(setInitialized(true));
if (isInitialized) {
return;
}
store.dispatch(
appendAssistantMessage(
"Hi! I'm OpenDevin, an AI Software Engineer. What would you like to build with me today?",
),
);
isInitialized = true;
},
browse: (message: ActionMessage) => {
const { url, screenshotSrc } = message.args;
@ -28,6 +37,9 @@ const messageActions = {
finish: (message: ActionMessage) => {
store.dispatch(appendAssistantMessage(message.message));
},
run: (message: ActionMessage) => {
store.dispatch(appendInput(message.args.command));
},
};
export function handleActionMessage(message: ActionMessage) {
@ -37,3 +49,19 @@ export function handleActionMessage(message: ActionMessage) {
actionFn(message);
}
}
export function handleAssistantMessage(data: string | SocketMessage) {
let socketMessage: SocketMessage;
if (typeof data === "string") {
socketMessage = JSON.parse(data) as SocketMessage;
} else {
socketMessage = data;
}
if ("action" in socketMessage) {
handleActionMessage(socketMessage);
} else {
handleObservationMessage(socketMessage);
}
}

View File

@ -0,0 +1,44 @@
import * as jose from "jose";
import { ResFetchToken } from "../types/ResponseType";
const fetchToken = async (): Promise<ResFetchToken> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/auth`, { headers });
if (response.status !== 200) {
throw new Error("Get token failed.");
}
const data: ResFetchToken = await response.json();
return data;
};
const validateToken = (token: string): boolean => {
try {
const claims = jose.decodeJwt(token);
return !(claims.sid === undefined || claims.sid === "");
} catch (error) {
return false;
}
};
const getToken = async (): Promise<string> => {
const token = localStorage.getItem("token") ?? "";
if (validateToken(token)) {
return token;
}
const data = await fetchToken();
if (data.token === undefined || data.token === "") {
throw new Error("Get token failed.");
}
const newToken = data.token;
if (validateToken(newToken)) {
localStorage.setItem("token", newToken);
return newToken;
}
throw new Error("Token validation failed.");
};
export { getToken, fetchToken };

View File

@ -1,23 +1,41 @@
import {
Message,
appeendToNewChatSequence,
appendToNewChatSequence,
appendUserMessage,
emptyOutQueuedTyping,
setCurrentQueueMarker,
setCurrentTypingMessage,
toggleTypingActive,
} from "../state/chatSlice";
import socket from "../socket/socket";
import Socket from "./socket";
import store from "../store";
import { SocketMessage } from "../types/ResponseType";
import { ActionMessage } from "../types/Message";
export function sendChatMessage(message: string): void {
store.dispatch(appendUserMessage(message));
const event = { action: "start", args: { task: message } };
const eventString = JSON.stringify(event);
socket.send(eventString);
Socket.send(eventString);
}
export function setTypingAcitve(bool: boolean): void {
export function sendChatMessageFromEvent(event: string | SocketMessage): void {
try {
let data: ActionMessage;
if (typeof event === "string") {
data = JSON.parse(event);
} else {
data = event as ActionMessage;
}
if (data && data.args && data.args.task) {
store.dispatch(appendUserMessage(data.args.task));
}
} catch (error) {
//
}
}
export function setTypingActive(bool: boolean): void {
store.dispatch(toggleTypingActive(bool));
}
@ -32,5 +50,5 @@ export function setCurrentQueueMarkerState(index: number): void {
store.dispatch(setCurrentQueueMarker(index));
}
export function addAssistantMessageToChat(msg: Message): void {
store.dispatch(appeendToNewChatSequence(msg));
store.dispatch(appendToNewChatSequence(msg));
}

View File

@ -0,0 +1,25 @@
import { appendAssistantMessage } from "../state/chatSlice";
import { setUrl, setScreenshotSrc } from "../state/browserSlice";
import store from "../store";
import { ObservationMessage } from "../types/Message";
import { appendOutput } from "../state/commandSlice";
import ObservationType from "../types/ObservationType";
export function handleObservationMessage(message: ObservationMessage) {
switch (message.observation) {
case ObservationType.RUN:
store.dispatch(appendOutput(message.content));
break;
case ObservationType.BROWSE:
if (message.extras?.screenshot) {
store.dispatch(setScreenshotSrc(message.extras.screenshot));
}
if (message.extras?.url) {
store.dispatch(setUrl(message.extras.url));
}
break;
default:
store.dispatch(appendAssistantMessage(message.message));
break;
}
}

View File

@ -0,0 +1,49 @@
import {
ResDelMsg,
ResFetchMsgs,
ResFetchMsgTotal,
} from "../types/ResponseType";
const fetchMsgTotal = async (): Promise<ResFetchMsgTotal> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages/total`, { headers });
if (response.status !== 200) {
throw new Error("Get message total failed.");
}
const data: ResFetchMsgTotal = await response.json();
return data;
};
const fetchMsgs = async (): Promise<ResFetchMsgs> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages`, { headers });
if (response.status !== 200) {
throw new Error("Get messages failed.");
}
const data: ResFetchMsgs = await response.json();
return data;
};
const clearMsgs = async (): Promise<ResDelMsg> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages`, {
method: "DELETE",
headers,
});
if (response.status !== 200) {
throw new Error("Delete messages failed.");
}
const data: ResDelMsg = await response.json();
return data;
};
export { fetchMsgTotal, fetchMsgs, clearMsgs };

View File

@ -1,6 +1,13 @@
import { appendAssistantMessage } from "../state/chatSlice";
import { setInitialized } from "../state/taskSlice";
import store from "../store";
import Socket from "./socket";
import {
setAgent,
setLanguage,
setModel,
setWorkspaceDirectory,
} from "../state/settingsSlice";
export async function getInitialModel() {
if (localStorage.getItem("model")) {
@ -43,24 +50,28 @@ const SETTINGS_MAP = new Map<string, string>([
]);
// Send settings to the server
export function sendSettings(
socket: WebSocket,
export function saveSettings(
reduxSettings: { [id: string]: string },
appendMessages: boolean = true,
needToSend: boolean = false,
): void {
const socketSettings = Object.fromEntries(
Object.entries(reduxSettings).map(([setting, value]) => [
SETTINGS_MAP.get(setting) || setting,
value,
]),
);
const event = { action: "initialize", args: socketSettings };
const eventString = JSON.stringify(event);
socket.send(eventString);
store.dispatch(setInitialized(false));
if (appendMessages) {
for (const [setting, value] of Object.entries(reduxSettings)) {
store.dispatch(appendAssistantMessage(`Set ${setting} to "${value}"`));
}
if (needToSend) {
const socketSettings = Object.fromEntries(
Object.entries(reduxSettings).map(([setting, value]) => [
SETTINGS_MAP.get(setting) || setting,
value,
]),
);
const event = { action: "initialize", args: socketSettings };
const eventString = JSON.stringify(event);
store.dispatch(setInitialized(false));
Socket.send(eventString);
}
for (const [setting, value] of Object.entries(reduxSettings)) {
localStorage.setItem(setting, value);
store.dispatch(appendAssistantMessage(`Set ${setting} to "${value}"`));
}
store.dispatch(setModel(reduxSettings.model));
store.dispatch(setAgent(reduxSettings.agent));
store.dispatch(setWorkspaceDirectory(reduxSettings.workspaceDirectory));
store.dispatch(setLanguage(reduxSettings.language));
}

View File

@ -0,0 +1,99 @@
import store from "../store";
import { appendError, removeError } from "../state/errorsSlice";
import { handleAssistantMessage } from "./actions";
import { getToken } from "./auth";
import ActionType from "../types/ActionType";
class Socket {
private static _socket: WebSocket | null = null;
public static initialize(): void {
getToken()
.then((token) => {
Socket._initialize(token);
})
.catch((err) => {
const msg = `Failed to get token: ${err}.`;
store.dispatch(appendError(msg));
setTimeout(() => {
store.dispatch(removeError(msg));
}, 2000);
});
}
private static _initialize(token: string): void {
if (!Socket._socket || Socket._socket.readyState !== WebSocket.OPEN) {
const WS_URL = `ws://${window.location.host}/ws?token=${token}`;
Socket._socket = new WebSocket(WS_URL);
Socket._socket.onopen = () => {
const model = localStorage.getItem("model") || "gpt-3.5-turbo-1106";
const agent = localStorage.getItem("agent") || "MonologueAgent";
const workspaceDirectory =
localStorage.getItem("workspaceDirectory") || "./workspace";
Socket._socket?.send(
JSON.stringify({
action: ActionType.INIT,
args: {
model,
agent_cls: agent,
directory: workspaceDirectory,
},
}),
);
};
Socket._socket.onmessage = (e) => {
handleAssistantMessage(e.data);
};
Socket._socket.onerror = () => {
const msg = "Failed connection to server";
store.dispatch(appendError(msg));
setTimeout(() => {
store.dispatch(removeError(msg));
}, 2000);
};
Socket._socket.onclose = () => {
// Reconnect after a delay
setTimeout(() => {
Socket.initialize();
}, 3000); // Reconnect after 3 seconds
};
}
}
static send(message: string): void {
Socket.initialize();
if (Socket._socket && Socket._socket.readyState === WebSocket.OPEN) {
Socket._socket.send(message);
} else {
store.dispatch(appendError("WebSocket connection is not ready."));
}
}
static addEventListener(
event: string,
callback: (e: MessageEvent) => void,
): void {
Socket._socket?.addEventListener(
event as keyof WebSocketEventMap,
callback as (
this: WebSocket,
ev: WebSocketEventMap[keyof WebSocketEventMap],
) => never,
);
}
static removeEventListener(
event: string,
listener: (e: Event) => void,
): void {
Socket._socket?.removeEventListener(event, listener);
}
}
Socket.initialize();
export default Socket;

View File

@ -1,16 +0,0 @@
import { appendAssistantMessage } from "../state/chatSlice";
import { setUrl, setScreenshotSrc } from "../state/browserSlice";
import store from "../store";
import { ObservationMessage } from "../types/Message";
export function handleObservationMessage(message: ObservationMessage) {
store.dispatch(appendAssistantMessage(message.message));
if (message.observation === "browse") {
if (message.extras?.screenshot) {
store.dispatch(setScreenshotSrc(message.extras.screenshot));
}
if (message.extras?.url) {
store.dispatch(setUrl(message.extras.url));
}
}
}

View File

@ -1,44 +0,0 @@
import store from "../store";
import { ActionMessage, ObservationMessage } from "../types/Message";
import { appendError } from "../state/errorsSlice";
import { handleActionMessage } from "./actions";
import { handleObservationMessage } from "./observations";
import { sendSettings } from "../services/settingsService";
type SocketMessage = ActionMessage | ObservationMessage;
const WS_URL = `ws://${window.location.host}/ws`;
const socket = new WebSocket(WS_URL);
socket.addEventListener("open", () => {
const settingKeys = ["model", "agent", "workspaceDirectory"];
const settings = settingKeys.reduce(
(acc, key) => {
const value = localStorage.getItem(key);
if (value) {
acc[key] = value;
}
return acc;
},
{} as Record<string, string>,
);
sendSettings(socket, settings, false);
});
socket.addEventListener("message", (event) => {
const socketMessage = JSON.parse(event.data) as SocketMessage;
if ("action" in socketMessage) {
handleActionMessage(socketMessage);
} else {
handleObservationMessage(socketMessage);
}
});
socket.addEventListener("error", () => {
store.dispatch(
appendError(
`Failed connection to server. Please ensure the server is reachable at ${WS_URL}.`,
),
);
});
export default socket;

View File

@ -49,7 +49,7 @@ export const chatSlice = createSlice({
state.currentTypingMessage = action.payload;
// state.currentQueueMarker += 1;
},
appeendToNewChatSequence: (state, action) => {
appendToNewChatSequence: (state, action) => {
state.newChatSequence.push(action.payload);
},
},
@ -62,7 +62,7 @@ export const {
emptyOutQueuedTyping,
setCurrentTypingMessage,
setCurrentQueueMarker,
appeendToNewChatSequence,
appendToNewChatSequence,
} = chatSlice.actions;
export default chatSlice.reducer;

View File

@ -0,0 +1,27 @@
import { createSlice } from "@reduxjs/toolkit";
export type Command = {
content: string;
type: "input" | "output";
};
const initialCommands: Command[] = [];
export const commandSlice = createSlice({
name: "command",
initialState: {
commands: initialCommands,
},
reducers: {
appendInput: (state, action) => {
state.commands.push({ content: action.payload, type: "input" });
},
appendOutput: (state, action) => {
state.commands.push({ content: action.payload, type: "output" });
},
},
});
export const { appendInput, appendOutput } = commandSlice.actions;
export default commandSlice.reducer;

View File

@ -11,9 +11,12 @@ export const errorsSlice = createSlice({
appendError: (state, action) => {
state.errors.push(action.payload);
},
removeError: (state, action) => {
state.errors = state.errors.filter((error) => error !== action.payload);
},
},
});
export const { appendError } = errorsSlice.actions;
export const { appendError, removeError } = errorsSlice.actions;
export default errorsSlice.reducer;

View File

@ -0,0 +1,17 @@
import { createSlice } from "@reduxjs/toolkit";
export const globalSlice = createSlice({
name: "global",
initialState: {
initialized: false,
},
reducers: {
setInitialized: (state, action) => {
state.initialized = action.payload;
},
},
});
export const { setInitialized } = globalSlice.actions;
export default globalSlice.reducer;

View File

@ -1,9 +1,10 @@
import { createSlice } from "@reduxjs/toolkit";
import i18next from "i18next";
export const settingsSlice = createSlice({
name: "settings",
initialState: {
model: localStorage.getItem("model") || "",
model: localStorage.getItem("model") || "gpt-3.5-turbo-1106",
agent: localStorage.getItem("agent") || "MonologueAgent",
workspaceDirectory:
localStorage.getItem("workspaceDirectory") || "./workspace",
@ -11,16 +12,21 @@ export const settingsSlice = createSlice({
},
reducers: {
setModel: (state, action) => {
localStorage.setItem("model", action.payload);
state.model = action.payload;
},
setAgent: (state, action) => {
localStorage.setItem("agent", action.payload);
state.agent = action.payload;
},
setWorkspaceDirectory: (state, action) => {
localStorage.setItem("workspaceDirectory", action.payload);
state.workspaceDirectory = action.payload;
},
setLanguage: (state, action) => {
localStorage.setItem("workspaceDirectory", action.payload);
state.language = action.payload;
i18next.changeLanguage(action.payload);
},
},
});

View File

@ -2,8 +2,10 @@ import { configureStore } from "@reduxjs/toolkit";
import browserReducer from "./state/browserSlice";
import chatReducer from "./state/chatSlice";
import codeReducer from "./state/codeSlice";
import commandReducer from "./state/commandSlice";
import taskReducer from "./state/taskSlice";
import errorsReducer from "./state/errorsSlice";
import globalReducer from "./state/globalSlice";
import settingsReducer from "./state/settingsSlice";
const store = configureStore({
@ -11,8 +13,10 @@ const store = configureStore({
browser: browserReducer,
chat: chatReducer,
code: codeReducer,
cmd: commandReducer,
task: taskReducer,
errors: errorsReducer,
global: globalReducer,
settings: settingsReducer,
},
});

View File

@ -0,0 +1,34 @@
import { ActionMessage, ObservationMessage } from "./Message";
interface ResFetchToken {
token: string;
}
interface ResFetchMsgTotal {
msg_total: number;
}
interface ResFetchMsg {
id: string;
role: string;
payload: SocketMessage;
}
interface ResFetchMsgs {
messages: ResFetchMsg[];
}
interface ResDelMsg {
ok: string;
}
type SocketMessage = ActionMessage | ObservationMessage;
export {
type ResFetchToken,
type ResFetchMsgTotal,
type ResFetchMsg,
type ResFetchMsgs,
type ResDelMsg,
type SocketMessage,
};

View File

@ -1,4 +1,3 @@
import asyncio
import inspect
import traceback
@ -14,23 +13,39 @@ from opendevin.action import (
NullAction,
AgentFinishAction,
AddTaskAction,
ModifyTaskAction
)
from opendevin.observation import (
Observation,
AgentErrorObservation,
NullObservation
ModifyTaskAction,
)
from opendevin.observation import Observation, AgentErrorObservation, NullObservation
from opendevin import config
from .command_manager import CommandManager
ColorType = Literal['red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'light_grey', 'dark_grey', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_magenta', 'light_cyan', 'white']
ColorType = Literal[
"red",
"green",
"yellow",
"blue",
"magenta",
"cyan",
"light_grey",
"dark_grey",
"light_red",
"light_green",
"light_yellow",
"light_blue",
"light_magenta",
"light_cyan",
"white",
]
DISABLE_COLOR_PRINTING = config.get_or_default("DISABLE_COLOR", "false").lower() == "true"
DISABLE_COLOR_PRINTING = (
config.get_or_default("DISABLE_COLOR", "false").lower() == "true"
)
MAX_ITERATIONS = config.get("MAX_ITERATIONS")
def print_with_color(text: Any, print_type: str = "INFO"):
TYPE_TO_COLOR: Mapping[str, ColorType] = {
"BACKGROUND LOG": "blue",
@ -50,19 +65,24 @@ def print_with_color(text: Any, print_type: str = "INFO"):
flush=True,
)
class AgentController:
id: str
def __init__(
self,
agent: Agent,
workdir: str,
id: str = "",
max_iterations: int = MAX_ITERATIONS,
container_image: str | None = None,
callbacks: List[Callable] = [],
):
self.id = id
self.agent = agent
self.max_iterations = max_iterations
self.workdir = workdir
self.command_manager = CommandManager(workdir,container_image)
self.command_manager = CommandManager(self.id, workdir, container_image)
self.callbacks = callbacks
def update_state_for_step(self, i):
@ -120,7 +140,7 @@ class AgentController:
traceback.print_exc()
# TODO Change to more robust error handling
if "The api_key client option must be set" in observation.content:
raise
raise
self.update_state_after_step()
await self._run_callbacks(action)
@ -172,4 +192,6 @@ class AgentController:
except Exception as e:
print("Callback error:" + str(idx), e, flush=True)
pass
await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks
await asyncio.sleep(
0.001
) # Give back control for a tick, so we can await in callbacks

View File

@ -3,10 +3,18 @@ from typing import List
from opendevin.observation import CmdOutputObservation
from opendevin.sandbox.sandbox import DockerInteractive
class CommandManager:
def __init__(self, dir: str, container_image: str | None = None,):
def __init__(
self,
id: str,
dir: str,
container_image: str | None = None,
):
self.directory = dir
self.shell = DockerInteractive(id="default", workspace_dir=dir, container_image=container_image)
self.shell = DockerInteractive(
id=(id or "default"), workspace_dir=dir, container_image=container_image
)
def run_command(self, command: str, background=False) -> CmdOutputObservation:
if background:
@ -17,10 +25,7 @@ class CommandManager:
def _run_immediately(self, command: str) -> CmdOutputObservation:
exit_code, output = self.shell.execute(command)
return CmdOutputObservation(
command_id=-1,
content=output,
command=command,
exit_code=exit_code
command_id=-1, content=output, command=command, exit_code=exit_code
)
def _run_background(self, command: str) -> CmdOutputObservation:
@ -29,7 +34,7 @@ class CommandManager:
content=f"Background command started. To stop it, send a `kill` action with id {bg_cmd.id}",
command_id=bg_cmd.id,
command=command,
exit_code=0
exit_code=0,
)
def kill_command(self, id: int) -> CmdOutputObservation:
@ -38,7 +43,7 @@ class CommandManager:
content=f"Background command with id {id} has been killed.",
command_id=id,
command=cmd.command,
exit_code=0
exit_code=0,
)
def get_background_obs(self) -> List[CmdOutputObservation]:

View File

@ -9,26 +9,61 @@ from opendevin.agent import Agent
from opendevin.controller import AgentController
from opendevin.llm.llm import LLM
def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
def read_task_from_stdin() -> str:
"""Read task from stdin."""
return sys.stdin.read()
def parse_arguments():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(description="Run an agent with a specific task")
parser.add_argument("-d", "--directory", required=True, type=str, help="The working directory for the agent")
parser.add_argument("-t", "--task", type=str, default="", help="The task for the agent to perform")
parser.add_argument("-f", "--file", type=str, help="Path to a file containing the task. Overrides -t if both are provided.")
parser.add_argument("-c", "--agent-cls", default="MonologueAgent", type=str, help="The agent class to use")
parser.add_argument("-m", "--model-name", default=config.get("LLM_MODEL"), type=str, help="The (litellm) model name to use")
parser.add_argument("-i", "--max-iterations", default=config.get("MAX_ITERATIONS"), type=int, help="The maximum number of iterations to run the agent")
parser.add_argument(
"-d",
"--directory",
required=True,
type=str,
help="The working directory for the agent",
)
parser.add_argument(
"-t", "--task", type=str, default="", help="The task for the agent to perform"
)
parser.add_argument(
"-f",
"--file",
type=str,
help="Path to a file containing the task. Overrides -t if both are provided.",
)
parser.add_argument(
"-c",
"--agent-cls",
default="MonologueAgent",
type=str,
help="The agent class to use",
)
parser.add_argument(
"-m",
"--model-name",
default=config.get_or_default("LLM_MODEL", "gpt-4-0125-preview"),
type=str,
help="The (litellm) model name to use",
)
parser.add_argument(
"-i",
"--max-iterations",
default=100,
type=int,
help="The maximum number of iterations to run the agent",
)
return parser.parse_args()
async def main():
"""Main coroutine to run the agent controller with task input flexibility."""
args = parse_arguments()
@ -44,13 +79,18 @@ async def main():
if not task:
raise ValueError("No task provided. Please specify a task through -t, -f.")
print(f"Running agent {args.agent_cls} (model: {args.model_name}, directory: {args.directory}) with task: \"{task}\"")
print(
f'Running agent {args.agent_cls} (model: {args.model_name}, directory: {args.directory}) with task: "{task}"'
)
llm = LLM(args.model_name)
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
agent = AgentCls(llm=llm)
controller = AgentController(agent, workdir=args.directory, max_iterations=args.max_iterations)
controller = AgentController(
agent=agent, workdir=args.directory, max_iterations=args.max_iterations
)
await controller.start_loop(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -15,7 +15,9 @@ from opendevin import config
InputType = namedtuple("InputType", ["content"])
OutputType = namedtuple("OutputType", ["content"])
DIRECTORY_REWRITE = config.get("DIRECTORY_REWRITE") # helpful for docker-in-docker scenarios
DIRECTORY_REWRITE = config.get(
"DIRECTORY_REWRITE"
) # helpful for docker-in-docker scenarios
CONTAINER_IMAGE = config.get("SANDBOX_CONTAINER_IMAGE")
# FIXME: On some containers, the devin user doesn't have enough permission, e.g. to install packages
@ -120,7 +122,8 @@ class DockerInteractive:
self.container_name = f"sandbox-{self.instance_id}"
self.restart_docker_container()
if not self.is_container_running():
self.restart_docker_container()
if RUN_AS_DEVIN:
self.setup_devin_user()
atexit.register(self.cleanup)
@ -150,20 +153,21 @@ class DockerInteractive:
def execute(self, cmd: str) -> Tuple[int, str]:
# TODO: each execute is not stateful! We need to keep track of the current working directory
def run_command(container, command):
return container.exec_run(command,workdir="/workspace")
return container.exec_run(command, workdir="/workspace")
# Use ThreadPoolExecutor to control command and set timeout
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_command, self.container, self.get_exec_cmd(cmd))
future = executor.submit(
run_command, self.container, self.get_exec_cmd(cmd)
)
try:
exit_code, logs = future.result(timeout=self.timeout)
except concurrent.futures.TimeoutError:
print("Command timed out, killing process...")
pid = self.get_pid(cmd)
if pid is not None:
self.container.exec_run(
f"kill -9 {pid}", workdir="/workspace"
)
return -1, f"Command: \"{cmd}\" timed out"
self.container.exec_run(f"kill -9 {pid}", workdir="/workspace")
return -1, f'Command: "{cmd}" timed out'
return exit_code, logs.decode("utf-8")
def execute_in_background(self, cmd: str) -> BackgroundCommand:
@ -179,12 +183,12 @@ class DockerInteractive:
def get_pid(self, cmd):
exec_result = self.container.exec_run("ps aux")
processes = exec_result.output.decode('utf-8').splitlines()
processes = exec_result.output.decode("utf-8").splitlines()
cmd = " ".join(self.get_exec_cmd(cmd))
for process in processes:
if cmd in process:
pid = process.split()[1] # second column is the pid
pid = process.split()[1] # second column is the pid
return pid
return None
@ -193,9 +197,7 @@ class DockerInteractive:
raise ValueError("Invalid background command id")
bg_cmd = self.background_commands[id]
if bg_cmd.pid is not None:
self.container.exec_run(
f"kill -9 {bg_cmd.pid}", workdir="/workspace"
)
self.container.exec_run(f"kill -9 {bg_cmd.pid}", workdir="/workspace")
bg_cmd.result.output.close()
self.background_commands.pop(id)
return bg_cmd
@ -210,7 +212,7 @@ class DockerInteractive:
try:
docker_client = docker.from_env()
except docker.errors.DockerException as e:
print('Please check Docker is running using `docker ps`.')
print("Please check Docker is running using `docker ps`.")
print(f"Error! {e}", flush=True)
raise e
@ -228,12 +230,23 @@ class DockerInteractive:
except docker.errors.NotFound:
pass
def is_container_running(self):
try:
docker_client = docker.from_env()
container = docker_client.containers.get(self.container_name)
if container.status == "running":
self.container = container
return True
return False
except docker.errors.NotFound:
return False
def restart_docker_container(self):
try:
self.stop_docker_container()
except docker.errors.DockerException as e:
print(f"Failed to stop container: {e}")
raise e
raise e
try:
# Initialize docker client. Throws an exception if Docker is not reachable.
@ -273,7 +286,10 @@ class DockerInteractive:
def cleanup(self):
if self.closed:
return
self.container.remove(force=True)
try:
self.container.remove(force=True)
except docker.errors.NotFound:
pass
if __name__ == "__main__":

View File

View File

@ -0,0 +1,3 @@
from .manager import AgentManager
__all__ = ["AgentManager"]

View File

@ -2,8 +2,6 @@ import asyncio
import os
from typing import Optional
from fastapi import WebSocketDisconnect
from opendevin import config
from opendevin.action import (
Action,
@ -14,6 +12,7 @@ from opendevin.agent import Agent
from opendevin.controller import AgentController
from opendevin.llm.llm import LLM
from opendevin.observation import Observation, UserMessageObservation
from opendevin.server.session import session_manager
DEFAULT_API_KEY = config.get("LLM_API_KEY")
DEFAULT_BASE_URL = config.get("LLM_BASE_URL")
@ -22,22 +21,21 @@ LLM_MODEL = config.get("LLM_MODEL")
CONTAINER_IMAGE = config.get("SANDBOX_CONTAINER_IMAGE")
MAX_ITERATIONS = config.get("MAX_ITERATIONS")
class Session:
class AgentManager:
"""Represents a session with an agent.
Attributes:
websocket: The WebSocket connection associated with the session.
controller: The AgentController instance for controlling the agent.
agent: The Agent instance representing the agent.
agent_task: The task representing the agent's execution.
"""
def __init__(self, websocket):
"""Initializes a new instance of the Session class.
Args:
websocket: The WebSocket connection associated with the session.
"""
self.websocket = websocket
sid: str
def __init__(self, sid):
"""Initializes a new instance of the Session class."""
self.sid = sid
self.controller: Optional[AgentController] = None
self.agent: Optional[Agent] = None
self.agent_task = None
@ -48,7 +46,7 @@ class Session:
Args:
message: The error message to send.
"""
await self.send({"error": True, "message": message})
await session_manager.send_error(self.sid, message)
async def send_message(self, message):
"""Sends a message to the client.
@ -56,7 +54,7 @@ class Session:
Args:
message: The message to send.
"""
await self.send({"message": message})
await session_manager.send_message(self.sid, message)
async def send(self, data):
"""Sends data to the client.
@ -64,42 +62,27 @@ class Session:
Args:
data: The data to send.
"""
if self.websocket is None:
await session_manager.send(self.sid, data)
async def dispatch(self, action: str | None, data: dict):
"""Dispatches actions to the agent from the client."""
if action is None:
await self.send_error("Invalid action")
return
try:
await self.websocket.send_json(data)
except Exception as e:
print("Error sending data to client", e)
async def start_listening(self):
"""Starts listening for messages from the client."""
try:
while True:
try:
data = await self.websocket.receive_json()
except ValueError:
await self.send_error("Invalid JSON")
continue
action = data.get("action", None)
if action is None:
await self.send_error("Invalid event")
continue
if action == "initialize":
await self.create_controller(data)
elif action == "start":
await self.start_task(data)
else:
if self.controller is None:
await self.send_error("No agent started. Please wait a second...")
elif action == "chat":
self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
else:
await self.send_error("I didn't recognize this action:" + action)
except WebSocketDisconnect as e:
print("Client websocket disconnected", e)
self.disconnect()
if action == "initialize":
await self.create_controller(data)
elif action == "start":
await self.start_task(data)
else:
if self.controller is None:
await self.send_error("No agent started. Please wait a second...")
elif action == "chat":
self.controller.add_history(
NullAction(), UserMessageObservation(data["message"])
)
else:
await self.send_error("I didn't recognize this action:" + action)
async def create_controller(self, start_event=None):
"""Creates an AgentController instance.
@ -128,6 +111,15 @@ class Session:
max_iterations = MAX_ITERATIONS
if start_event and "max_iterations" in start_event["args"]:
max_iterations = start_event["args"]["max_iterations"]
# double check preventing error occurs
if directory == "":
directory = DEFAULT_WORKSPACE_DIR
if agent_cls == "":
agent_cls = "MonologueAgent"
if model == "":
model = LLM_MODEL
if not os.path.exists(directory):
print(f"Workspace directory {directory} does not exist. Creating it...")
os.makedirs(directory)
@ -136,10 +128,19 @@ class Session:
AgentCls = Agent.get_cls(agent_cls)
self.agent = AgentCls(llm)
try:
self.controller = AgentController(self.agent, workdir=directory, max_iterations=max_iterations, container_image=container_image, callbacks=[self.on_agent_event])
self.controller = AgentController(
id=self.sid,
agent=self.agent,
workdir=directory,
max_iterations=max_iterations,
container_image=container_image,
callbacks=[self.on_agent_event],
)
except Exception:
print("Error creating controller.")
await self.send_error("Error creating controller. Please check Docker is running using `docker ps`.")
await self.send_error(
"Error creating controller. Please check Docker is running using `docker ps`."
)
return
await self.send({"action": "initialize", "message": "Control loop started."})
@ -158,7 +159,9 @@ class Session:
await self.send_error("No agent started. Please wait a second...")
return
try:
self.agent_task = await asyncio.create_task(self.controller.start_loop(task), name="agent loop")
self.agent_task = await asyncio.create_task(
self.controller.start_loop(task), name="agent loop"
)
except Exception:
await self.send_error("Error during task loop.")
@ -174,7 +177,7 @@ class Session:
return
event_dict = event.to_dict()
asyncio.create_task(self.send(event_dict), name="send event in callback")
def disconnect(self):
self.websocket = None
if self.agent_task:

View File

@ -0,0 +1,3 @@
from .auth import get_sid_from_token, sign_token
__all__ = ["get_sid_from_token", "sign_token"]

View File

@ -0,0 +1,27 @@
import os
import jwt
from typing import Dict
JWT_SECRET = os.getenv("JWT_SECRET", "5ecRe7")
def get_sid_from_token(token: str) -> str:
"""Gets the session id from a JWT token."""
try:
payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
if payload is None:
print("Invalid token")
return ""
return payload["sid"]
except Exception as e:
print("Error decoding token:", e)
return ""
def sign_token(payload: Dict[str, object]) -> str:
"""Signs a JWT token."""
# payload = {
# "sid": sid,
# # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
# }
return jwt.encode(payload, JWT_SECRET, algorithm="HS256")

View File

@ -1,13 +1,20 @@
from opendevin.server.session import Session
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
import agenthub # noqa F401 (we import this to get the agents registered)
import litellm
import uuid
from opendevin.server.session import session_manager, message_stack
from opendevin.server.auth import get_sid_from_token, sign_token
from opendevin.agent import Agent
from opendevin.server.agent import AgentManager
import agenthub # noqa F401 (we import this to get the agents registered)
from fastapi import FastAPI, WebSocket, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import litellm
from starlette import status
from starlette.responses import JSONResponse
from opendevin import config
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3001"],
@ -16,14 +23,21 @@ app.add_middleware(
allow_headers=["*"],
)
security_scheme = HTTPBearer()
# This endpoint receives events from the client (i.e. the browser)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
session = Session(websocket)
# TODO: should this use asyncio instead of await?
await session.start_listening()
sid = get_sid_from_token(websocket.query_params.get("token") or "")
if sid == "":
return
session_manager.add_session(sid, websocket)
# TODO: actually the agent_manager is created for each websocket connection, even if the session id is the same,
# we need to manage the agent in memory for reconnecting the same session id to the same agent
agent_manager = AgentManager(sid)
await session_manager.loop_recv(sid, agent_manager.dispatch)
@app.get("/litellm-models")
@ -41,6 +55,60 @@ async def get_litellm_agents():
"""
return Agent.listAgents()
@app.get("/auth")
async def get_token(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
"""
Get token for authentication when starts a websocket connection.
"""
sid = get_sid_from_token(credentials.credentials) or str(uuid.uuid4())
token = sign_token({"sid": sid})
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"token": token},
)
@app.get("/messages")
async def get_messages(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
data = []
sid = get_sid_from_token(credentials.credentials)
if sid != "":
data = message_stack.get_messages(sid)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"messages": data},
)
@app.get("/messages/total")
async def get_message_total(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
sid = get_sid_from_token(credentials.credentials)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"msg_total": message_stack.get_message_total(sid)},
)
@app.delete("/messages")
async def del_messages(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
sid = get_sid_from_token(credentials.credentials)
message_stack.del_messages(sid)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"ok": True},
)
@app.get("/default-model")
def read_default_model():
return config.get_or_error("LLM_MODEL")

View File

@ -0,0 +1,4 @@
from .action import ActionType
from .observation import ObservationType
__all__ = ["ActionType", "ObservationType"]

View File

@ -0,0 +1,6 @@
from .session import Session
from .manager import SessionManager
from .manager import session_manager
from .msg_stack import message_stack
__all__ = ["Session", "SessionManager", "session_manager", "message_stack"]

View File

@ -0,0 +1,90 @@
import os
import json
import atexit
import signal
from typing import Dict, Callable
from fastapi import WebSocket
from .session import Session
from .msg_stack import message_stack
CACHE_DIR = os.getenv("CACHE_DIR", "cache")
SESSION_CACHE_FILE = os.path.join(CACHE_DIR, "sessions.json")
class SessionManager:
_sessions: Dict[str, Session] = {}
def __init__(self):
self._load_sessions()
atexit.register(self.close)
signal.signal(signal.SIGINT, self.handle_signal)
signal.signal(signal.SIGTERM, self.handle_signal)
def add_session(self, sid: str, ws_conn: WebSocket):
if sid not in self._sessions:
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
return
self._sessions[sid].update_connection(ws_conn)
async def loop_recv(self, sid: str, dispatch: Callable):
print(f"Starting loop_recv for sid: {sid}, {sid not in self._sessions}")
"""Starts listening for messages from the client."""
if sid not in self._sessions:
return
await self._sessions[sid].loop_recv(dispatch)
def close(self):
self._save_sessions()
def handle_signal(self, signum, _):
print(f"Received signal {signum}, exiting...")
self.close()
exit(0)
async def send(self, sid: str, data: Dict[str, object]) -> bool:
"""Sends data to the client."""
message_stack.add_message(sid, "assistant", data)
if sid not in self._sessions:
return False
return await self._sessions[sid].send(data)
async def send_error(self, sid: str, message: str) -> bool:
"""Sends an error message to the client."""
return await self.send(sid, {"error": True, "message": message})
async def send_message(self, sid: str, message: str) -> bool:
"""Sends a message to the client."""
return await self.send(sid, {"message": message})
def _save_sessions(self):
data = {}
for sid, conn in self._sessions.items():
data[sid] = {
"sid": conn.sid,
"last_active_ts": conn.last_active_ts,
"is_alive": conn.is_alive,
}
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
with open(SESSION_CACHE_FILE, "w+") as file:
json.dump(data, file)
def _load_sessions(self):
try:
with open(SESSION_CACHE_FILE, "r") as file:
data = json.load(file)
for sid, sdata in data.items():
conn = Session(sid, None)
ok = conn.load_from_data(sdata)
if ok:
self._sessions[sid] = conn
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError:
pass
session_manager = SessionManager()

View File

@ -0,0 +1,99 @@
import os
import json
import atexit
import signal
import uuid
from typing import Dict, List
from opendevin.server.schema.action import ActionType
CACHE_DIR = os.getenv("CACHE_DIR", "cache")
MSG_CACHE_FILE = os.path.join(CACHE_DIR, "messages.json")
class Message:
id: str = str(uuid.uuid4())
role: str # "user"| "assistant"
payload: Dict[str, object]
def __init__(self, role: str, payload: Dict[str, object]):
self.role = role
self.payload = payload
def to_dict(self):
return {"id": self.id, "role": self.role, "payload": self.payload}
@classmethod
def from_dict(cls, data: Dict):
m = cls(data["role"], data["payload"])
m.id = data["id"]
return m
class MessageStack:
_messages: Dict[str, List[Message]] = {}
def __init__(self):
self._load_messages()
atexit.register(self.close)
signal.signal(signal.SIGINT, self.handle_signal)
signal.signal(signal.SIGTERM, self.handle_signal)
def close(self):
self._save_messages()
def handle_signal(self, signum, _):
print(f"Received signal {signum}, exiting...")
self.close()
exit(0)
def add_message(self, sid: str, role: str, message: Dict[str, object]):
if sid not in self._messages:
self._messages[sid] = []
self._messages[sid].append(Message(role, message))
def del_messages(self, sid: str):
if sid not in self._messages:
return
del self._messages[sid]
def get_messages(self, sid: str) -> List[Dict[str, object]]:
if sid not in self._messages:
return []
return [msg.to_dict() for msg in self._messages[sid]]
def get_message_total(self, sid: str) -> int:
if sid not in self._messages:
return 0
cnt = 0
for msg in self._messages[sid]:
# Ignore assistant init message for now.
if "action" in msg.payload and msg.payload["action"] == ActionType.INIT:
continue
cnt += 1
return cnt
def _save_messages(self):
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
data = {}
for sid, msgs in self._messages.items():
data[sid] = [msg.to_dict() for msg in msgs]
with open(MSG_CACHE_FILE, "w+") as file:
json.dump(data, file)
def _load_messages(self):
try:
# TODO: delete useless messages
with open(MSG_CACHE_FILE, "r") as file:
data = json.load(file)
for sid, msgs in data.items():
self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError:
pass
message_stack = MessageStack()

View File

@ -0,0 +1,72 @@
import time
from typing import Dict, Callable
from fastapi import WebSocket, WebSocketDisconnect
from .msg_stack import message_stack
DEL_DELT_SEC = 60 * 60 * 5
class Session:
sid: str
websocket: WebSocket | None
last_active_ts: int = 0
is_alive: bool = True
def __init__(self, sid: str, ws: WebSocket | None):
self.sid = sid
self.websocket = ws
self.last_active_ts = int(time.time())
async def loop_recv(self, dispatch: Callable):
try:
if self.websocket is None:
return
while True:
try:
data = await self.websocket.receive_json()
except ValueError:
await self.send_error("Invalid JSON")
continue
message_stack.add_message(self.sid, "user", data)
action = data.get("action", None)
await dispatch(action, data)
except WebSocketDisconnect:
self.is_alive = False
print(f"WebSocket disconnected, sid: {self.sid}")
except RuntimeError as e:
# WebSocket is not connected
if "WebSocket is not connected" in str(e):
self.is_alive = False
print(f"Error in loop_recv: {e}")
async def send(self, data: Dict[str, object]) -> bool:
try:
if self.websocket is None or not self.is_alive:
return False
await self.websocket.send_json(data)
self.last_active_ts = int(time.time())
return True
except WebSocketDisconnect:
self.is_alive = False
return False
async def send_error(self, message: str) -> bool:
"""Sends an error message to the client."""
return await self.send({"error": True, "message": message})
async def send_message(self, message: str) -> bool:
"""Sends a message to the client."""
return await self.send({"message": message})
def update_connection(self, ws: WebSocket):
self.websocket = ws
self.is_alive = True
self.last_active_ts = int(time.time())
def load_from_data(self, data: Dict) -> bool:
self.last_active_ts = data.get("last_active_ts", 0)
if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
return False
self.is_alive = data.get("is_alive", False)
return True