mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Separate agent controller and server via EventStream (#1538)
* move towards event stream * refactor agent state changes * move agent state logic * fix callbacks * break on finish * closer to working * change frontend to accommodate new flow * handle start action * fix locked stream * revert message * logspam * no async on close * get rid of agent_task * fix up closing * better asyncio handling * sleep to give back control * fix key * logspam * update frontend agent state actions * fix pause and cancel * delint * fix map * delint * wait for agent to finish * fix unit test * event stream enums * fix merge issues * fix lint * fix test * fix test * add user message action * add user message action * fix up user messages * fix main.py flow * refactor message waiting * lint * fix test * fix test
This commit is contained in:
parent
4e84aac577
commit
f7e0c6cd06
@ -4,47 +4,39 @@ import { useSelector } from "react-redux";
|
||||
import ArrowIcon from "#/assets/arrow";
|
||||
import PauseIcon from "#/assets/pause";
|
||||
import PlayIcon from "#/assets/play";
|
||||
import { changeTaskState } from "#/services/agentStateService";
|
||||
import { changeAgentState } from "#/services/agentStateService";
|
||||
import { clearMsgs } from "#/services/session";
|
||||
import store, { RootState } from "#/store";
|
||||
import AgentTaskAction from "#/types/AgentTaskAction";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { clearMessages } from "#/state/chatSlice";
|
||||
|
||||
const TaskStateActionMap = {
|
||||
[AgentTaskAction.START]: AgentTaskState.RUNNING,
|
||||
[AgentTaskAction.PAUSE]: AgentTaskState.PAUSED,
|
||||
[AgentTaskAction.RESUME]: AgentTaskState.RUNNING,
|
||||
[AgentTaskAction.STOP]: AgentTaskState.STOPPED,
|
||||
};
|
||||
|
||||
const IgnoreTaskStateMap: { [k: string]: AgentTaskState[] } = {
|
||||
[AgentTaskAction.PAUSE]: [
|
||||
AgentTaskState.INIT,
|
||||
AgentTaskState.PAUSED,
|
||||
AgentTaskState.STOPPED,
|
||||
AgentTaskState.FINISHED,
|
||||
AgentTaskState.AWAITING_USER_INPUT,
|
||||
const IgnoreTaskStateMap: { [k: string]: AgentState[] } = {
|
||||
[AgentState.PAUSED]: [
|
||||
AgentState.INIT,
|
||||
AgentState.PAUSED,
|
||||
AgentState.STOPPED,
|
||||
AgentState.FINISHED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
],
|
||||
[AgentTaskAction.RESUME]: [
|
||||
AgentTaskState.INIT,
|
||||
AgentTaskState.RUNNING,
|
||||
AgentTaskState.STOPPED,
|
||||
AgentTaskState.FINISHED,
|
||||
AgentTaskState.AWAITING_USER_INPUT,
|
||||
[AgentState.RUNNING]: [
|
||||
AgentState.INIT,
|
||||
AgentState.RUNNING,
|
||||
AgentState.STOPPED,
|
||||
AgentState.FINISHED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
],
|
||||
[AgentTaskAction.STOP]: [
|
||||
AgentTaskState.INIT,
|
||||
AgentTaskState.STOPPED,
|
||||
AgentTaskState.FINISHED,
|
||||
[AgentState.STOPPED]: [
|
||||
AgentState.INIT,
|
||||
AgentState.STOPPED,
|
||||
AgentState.FINISHED,
|
||||
],
|
||||
};
|
||||
|
||||
interface ButtonProps {
|
||||
isDisabled: boolean;
|
||||
content: string;
|
||||
action: AgentTaskAction;
|
||||
handleAction: (action: AgentTaskAction) => void;
|
||||
action: AgentState;
|
||||
handleAction: (action: AgentState) => void;
|
||||
large?: boolean;
|
||||
}
|
||||
|
||||
@ -75,53 +67,53 @@ ActionButton.defaultProps = {
|
||||
};
|
||||
|
||||
function AgentControlBar() {
|
||||
const { curTaskState } = useSelector((state: RootState) => state.agent);
|
||||
const [desiredState, setDesiredState] = React.useState(AgentTaskState.INIT);
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const [desiredState, setDesiredState] = React.useState(AgentState.INIT);
|
||||
const [isLoading, setIsLoading] = React.useState(false);
|
||||
|
||||
const handleAction = (action: AgentTaskAction) => {
|
||||
if (IgnoreTaskStateMap[action].includes(curTaskState)) {
|
||||
const handleAction = (action: AgentState) => {
|
||||
if (IgnoreTaskStateMap[action].includes(curAgentState)) {
|
||||
return;
|
||||
}
|
||||
|
||||
let act = action;
|
||||
|
||||
if (act === AgentTaskAction.STOP) {
|
||||
act = AgentTaskAction.STOP;
|
||||
if (act === AgentState.STOPPED) {
|
||||
act = AgentState.STOPPED;
|
||||
clearMsgs().then().catch();
|
||||
store.dispatch(clearMessages());
|
||||
} else {
|
||||
setIsLoading(true);
|
||||
}
|
||||
|
||||
setDesiredState(TaskStateActionMap[act]);
|
||||
changeTaskState(act);
|
||||
setDesiredState(act);
|
||||
changeAgentState(act);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (curTaskState === desiredState) {
|
||||
if (curTaskState === AgentTaskState.STOPPED) {
|
||||
if (curAgentState === desiredState) {
|
||||
if (curAgentState === AgentState.STOPPED) {
|
||||
clearMsgs().then().catch();
|
||||
store.dispatch(clearMessages());
|
||||
}
|
||||
setIsLoading(false);
|
||||
} else if (curTaskState === AgentTaskState.RUNNING) {
|
||||
setDesiredState(AgentTaskState.RUNNING);
|
||||
} else if (curAgentState === AgentState.RUNNING) {
|
||||
setDesiredState(AgentState.RUNNING);
|
||||
}
|
||||
// We only want to run this effect when curTaskState changes
|
||||
// We only want to run this effect when curAgentState changes
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [curTaskState]);
|
||||
}, [curAgentState]);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3">
|
||||
{curTaskState === AgentTaskState.PAUSED ? (
|
||||
{curAgentState === AgentState.PAUSED ? (
|
||||
<ActionButton
|
||||
isDisabled={
|
||||
isLoading ||
|
||||
IgnoreTaskStateMap[AgentTaskAction.RESUME].includes(curTaskState)
|
||||
IgnoreTaskStateMap[AgentState.RUNNING].includes(curAgentState)
|
||||
}
|
||||
content="Resume the agent task"
|
||||
action={AgentTaskAction.RESUME}
|
||||
action={AgentState.RUNNING}
|
||||
handleAction={handleAction}
|
||||
large
|
||||
>
|
||||
@ -131,10 +123,10 @@ function AgentControlBar() {
|
||||
<ActionButton
|
||||
isDisabled={
|
||||
isLoading ||
|
||||
IgnoreTaskStateMap[AgentTaskAction.PAUSE].includes(curTaskState)
|
||||
IgnoreTaskStateMap[AgentState.PAUSED].includes(curAgentState)
|
||||
}
|
||||
content="Pause the agent task"
|
||||
action={AgentTaskAction.PAUSE}
|
||||
action={AgentState.PAUSED}
|
||||
handleAction={handleAction}
|
||||
large
|
||||
>
|
||||
@ -144,7 +136,7 @@ function AgentControlBar() {
|
||||
<ActionButton
|
||||
isDisabled={isLoading}
|
||||
content="Restart a new agent task"
|
||||
action={AgentTaskAction.STOP}
|
||||
action={AgentState.STOPPED}
|
||||
handleAction={handleAction}
|
||||
>
|
||||
<ArrowIcon />
|
||||
|
||||
@ -3,31 +3,31 @@ import { useTranslation } from "react-i18next";
|
||||
import { useSelector } from "react-redux";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { RootState } from "#/store";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import AgentState from "#/types/AgentState";
|
||||
|
||||
const AgentStatusMap: { [k: string]: { message: string; indicator: string } } =
|
||||
{
|
||||
[AgentTaskState.INIT]: {
|
||||
[AgentState.INIT]: {
|
||||
message: "Agent is initialized, waiting for task...",
|
||||
indicator: "bg-blue-500",
|
||||
},
|
||||
[AgentTaskState.RUNNING]: {
|
||||
[AgentState.RUNNING]: {
|
||||
message: "Agent is running task...",
|
||||
indicator: "bg-green-500",
|
||||
},
|
||||
[AgentTaskState.AWAITING_USER_INPUT]: {
|
||||
[AgentState.AWAITING_USER_INPUT]: {
|
||||
message: "Agent is awaiting user input...",
|
||||
indicator: "bg-orange-500",
|
||||
},
|
||||
[AgentTaskState.PAUSED]: {
|
||||
[AgentState.PAUSED]: {
|
||||
message: "Agent has paused.",
|
||||
indicator: "bg-yellow-500",
|
||||
},
|
||||
[AgentTaskState.STOPPED]: {
|
||||
[AgentState.STOPPED]: {
|
||||
message: "Agent has stopped.",
|
||||
indicator: "bg-red-500",
|
||||
},
|
||||
[AgentTaskState.FINISHED]: {
|
||||
[AgentState.FINISHED]: {
|
||||
message: "Agent has finished the task.",
|
||||
indicator: "bg-green-500",
|
||||
},
|
||||
@ -35,8 +35,7 @@ const AgentStatusMap: { [k: string]: { message: string; indicator: string } } =
|
||||
|
||||
function AgentStatusBar() {
|
||||
const { t } = useTranslation();
|
||||
const { initialized } = useSelector((state: RootState) => state.task);
|
||||
const { curTaskState } = useSelector((state: RootState) => state.agent);
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
|
||||
// TODO: Extend the agent status, e.g.:
|
||||
// - Agent is typing
|
||||
@ -46,13 +45,13 @@ function AgentStatusBar() {
|
||||
// - Agent is not available
|
||||
return (
|
||||
<div className="flex items-center">
|
||||
{initialized ? (
|
||||
{curAgentState !== AgentState.LOADING ? (
|
||||
<>
|
||||
<div
|
||||
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curTaskState].indicator}`}
|
||||
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
|
||||
/>
|
||||
<span className="text-sm text-stone-400">
|
||||
{AgentStatusMap[curTaskState].message}
|
||||
{AgentStatusMap[curAgentState].message}
|
||||
</span>
|
||||
</>
|
||||
) : (
|
||||
|
||||
@ -8,7 +8,7 @@ import ChatInterface from "./ChatInterface";
|
||||
import Socket from "#/services/socket";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { addAssistantMessage } from "#/state/chatSlice";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import AgentState from "#/types/AgentState";
|
||||
|
||||
// avoid typing side-effect
|
||||
vi.mock("#/hooks/useTyping", () => ({
|
||||
@ -25,7 +25,6 @@ const renderChatInterface = () =>
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
task: {
|
||||
initialized: true,
|
||||
completed: false,
|
||||
},
|
||||
},
|
||||
@ -38,7 +37,16 @@ describe("ChatInterface", () => {
|
||||
});
|
||||
|
||||
it("should render the new message the user has typed", async () => {
|
||||
renderChatInterface();
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
task: {
|
||||
completed: false,
|
||||
},
|
||||
agent: {
|
||||
curAgentState: AgentState.INIT,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const input = screen.getByRole("textbox");
|
||||
|
||||
@ -73,11 +81,10 @@ describe("ChatInterface", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
task: {
|
||||
initialized: true,
|
||||
completed: false,
|
||||
},
|
||||
agent: {
|
||||
curTaskState: AgentTaskState.INIT,
|
||||
curAgentState: AgentState.INIT,
|
||||
},
|
||||
},
|
||||
});
|
||||
@ -95,11 +102,10 @@ describe("ChatInterface", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
task: {
|
||||
initialized: true,
|
||||
completed: false,
|
||||
},
|
||||
agent: {
|
||||
curTaskState: AgentTaskState.AWAITING_USER_INPUT,
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
},
|
||||
},
|
||||
});
|
||||
@ -110,8 +116,8 @@ describe("ChatInterface", () => {
|
||||
});
|
||||
|
||||
const event = {
|
||||
action: ActionType.USER_MESSAGE,
|
||||
args: { message: "my message" },
|
||||
action: ActionType.MESSAGE,
|
||||
args: { content: "my message" },
|
||||
};
|
||||
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
@ -120,9 +126,11 @@ describe("ChatInterface", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
task: {
|
||||
initialized: false,
|
||||
completed: false,
|
||||
},
|
||||
agent: {
|
||||
curAgentState: AgentState.LOADING,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@ -1,32 +1,19 @@
|
||||
import React from "react";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { useSelector } from "react-redux";
|
||||
import { IoMdChatbubbles } from "react-icons/io";
|
||||
import ChatInput from "./ChatInput";
|
||||
import Chat from "./Chat";
|
||||
import { RootState } from "#/store";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import { addUserMessage } from "#/state/chatSlice";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import Socket from "#/services/socket";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { sendChatMessage } from "#/services/chatService";
|
||||
|
||||
function ChatInterface() {
|
||||
const { initialized } = useSelector((state: RootState) => state.task);
|
||||
const { messages } = useSelector((state: RootState) => state.chat);
|
||||
const { curTaskState } = useSelector((state: RootState) => state.agent);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
|
||||
const handleSendMessage = (content: string) => {
|
||||
dispatch(addUserMessage(content));
|
||||
|
||||
let event;
|
||||
if (curTaskState === AgentTaskState.INIT) {
|
||||
event = { action: ActionType.START, args: { task: content } };
|
||||
} else {
|
||||
event = { action: ActionType.USER_MESSAGE, args: { message: content } };
|
||||
}
|
||||
|
||||
Socket.send(JSON.stringify(event));
|
||||
const isTask = curAgentState === AgentState.INIT;
|
||||
sendChatMessage(content, isTask);
|
||||
};
|
||||
|
||||
return (
|
||||
@ -42,7 +29,10 @@ function ChatInterface() {
|
||||
{/* Fade between messages and input */}
|
||||
<div className="absolute bottom-0 left-0 right-0 h-4 bg-gradient-to-b from-transparent to-neutral-800" />
|
||||
</div>
|
||||
<ChatInput disabled={!initialized} onSendMessage={handleSendMessage} />
|
||||
<ChatInput
|
||||
disabled={curAgentState === AgentState.LOADING}
|
||||
onSendMessage={handleSendMessage}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { act, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import React from "react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { Settings } from "#/services/settings";
|
||||
import SettingsForm from "./SettingsForm";
|
||||
|
||||
@ -80,7 +80,7 @@ describe("SettingsForm", () => {
|
||||
onLanguageChange={onLanguageChangeMock}
|
||||
onAPIKeyChange={onAPIKeyChangeMock}
|
||||
/>,
|
||||
{ preloadedState: { agent: { curTaskState: AgentTaskState.RUNNING } } },
|
||||
{ preloadedState: { agent: { curAgentState: AgentState.RUNNING } } },
|
||||
);
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const agentInput = screen.getByRole("combobox", { name: "agent" });
|
||||
|
||||
@ -6,7 +6,7 @@ import { useSelector } from "react-redux";
|
||||
import { AvailableLanguages } from "../../../i18n";
|
||||
import { I18nKey } from "../../../i18n/declaration";
|
||||
import { RootState } from "../../../store";
|
||||
import AgentTaskState from "../../../types/AgentTaskState";
|
||||
import AgentState from "../../../types/AgentState";
|
||||
import { AutocompleteCombobox } from "./AutocompleteCombobox";
|
||||
import { Settings } from "#/services/settings";
|
||||
|
||||
@ -31,21 +31,21 @@ function SettingsForm({
|
||||
onLanguageChange,
|
||||
}: SettingsFormProps) {
|
||||
const { t } = useTranslation();
|
||||
const { curTaskState } = useSelector((state: RootState) => state.agent);
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const [disabled, setDisabled] = React.useState<boolean>(false);
|
||||
const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
curTaskState === AgentTaskState.RUNNING ||
|
||||
curTaskState === AgentTaskState.PAUSED ||
|
||||
curTaskState === AgentTaskState.AWAITING_USER_INPUT
|
||||
curAgentState === AgentState.RUNNING ||
|
||||
curAgentState === AgentState.PAUSED ||
|
||||
curAgentState === AgentState.AWAITING_USER_INPUT
|
||||
) {
|
||||
setDisabled(true);
|
||||
} else {
|
||||
setDisabled(false);
|
||||
}
|
||||
}, [curTaskState, setDisabled]);
|
||||
}, [curAgentState, setDisabled]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import { changeTaskState } from "#/state/agentSlice";
|
||||
import { setScreenshotSrc, setUrl } from "#/state/browserSlice";
|
||||
import { addAssistantMessage } from "#/state/chatSlice";
|
||||
import { setCode, updatePath } from "#/state/codeSlice";
|
||||
import { appendInput } from "#/state/commandSlice";
|
||||
import { appendJupyterInput } from "#/state/jupyterSlice";
|
||||
import { setPlan } from "#/state/planSlice";
|
||||
import { setInitialized } from "#/state/taskSlice";
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { ActionMessage } from "#/types/Message";
|
||||
@ -14,9 +12,6 @@ import { handleObservationMessage } from "./observations";
|
||||
import { getPlan } from "./planService";
|
||||
|
||||
const messageActions = {
|
||||
[ActionType.INIT]: () => {
|
||||
store.dispatch(setInitialized(true));
|
||||
},
|
||||
[ActionType.BROWSE]: (message: ActionMessage) => {
|
||||
const { url, screenshotSrc } = message.args;
|
||||
store.dispatch(setUrl(url));
|
||||
@ -54,9 +49,6 @@ const messageActions = {
|
||||
[ActionType.MODIFY_TASK]: () => {
|
||||
getPlan().then((fetchedPlan) => store.dispatch(setPlan(fetchedPlan)));
|
||||
},
|
||||
[ActionType.CHANGE_TASK_STATE]: (message: ActionMessage) => {
|
||||
store.dispatch(changeTaskState(message.args.task_state));
|
||||
},
|
||||
};
|
||||
|
||||
export function handleActionMessage(message: ActionMessage) {
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { setInitialized } from "#/state/taskSlice";
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { initializeAgent } from "./agent";
|
||||
import { Settings } from "./settings";
|
||||
import Socket from "./socket";
|
||||
|
||||
const sendSpy = vi.spyOn(Socket, "send");
|
||||
const dispatchSpy = vi.spyOn(store, "dispatch");
|
||||
|
||||
describe("initializeAgent", () => {
|
||||
it("Should initialize the agent with the current settings", () => {
|
||||
@ -27,6 +24,5 @@ describe("initializeAgent", () => {
|
||||
initializeAgent(settings);
|
||||
|
||||
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(setInitialized(false));
|
||||
});
|
||||
});
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import { setInitialized } from "#/state/taskSlice";
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { Settings } from "./settings";
|
||||
import Socket from "./socket";
|
||||
@ -11,7 +9,5 @@ import Socket from "./socket";
|
||||
export const initializeAgent = (settings: Settings) => {
|
||||
const event = { action: ActionType.INIT, args: settings };
|
||||
const eventString = JSON.stringify(event);
|
||||
|
||||
store.dispatch(setInitialized(false));
|
||||
Socket.send(eventString);
|
||||
};
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import ActionType from "#/types/ActionType";
|
||||
import AgentTaskAction from "#/types/AgentTaskAction";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import Socket from "./socket";
|
||||
|
||||
export function changeTaskState(message: AgentTaskAction): void {
|
||||
export function changeAgentState(message: AgentState): void {
|
||||
const eventString = JSON.stringify({
|
||||
action: ActionType.CHANGE_TASK_STATE,
|
||||
args: { task_state_action: message },
|
||||
action: ActionType.CHANGE_AGENT_STATE,
|
||||
args: { agent_state: message },
|
||||
});
|
||||
Socket.send(eventString);
|
||||
}
|
||||
|
||||
@ -11,7 +11,7 @@ export function sendChatMessage(message: string, isTask: boolean = true): void {
|
||||
if (isTask) {
|
||||
event = { action: ActionType.START, args: { task: message } };
|
||||
} else {
|
||||
event = { action: ActionType.USER_MESSAGE, args: { message } };
|
||||
event = { action: ActionType.MESSAGE, args: { content: message } };
|
||||
}
|
||||
const eventString = JSON.stringify(event);
|
||||
Socket.send(eventString);
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { changeAgentState } from "#/state/agentSlice";
|
||||
import { setUrl, setScreenshotSrc } from "#/state/browserSlice";
|
||||
import store from "#/store";
|
||||
import { ObservationMessage } from "#/types/Message";
|
||||
@ -23,6 +24,9 @@ export function handleObservationMessage(message: ObservationMessage) {
|
||||
store.dispatch(setUrl(message.extras.url));
|
||||
}
|
||||
break;
|
||||
case ObservationType.AGENT_STATE_CHANGED:
|
||||
store.dispatch(changeAgentState(message.extras.agent_state));
|
||||
break;
|
||||
default:
|
||||
store.dispatch(addAssistantMessage(message.message));
|
||||
break;
|
||||
|
||||
@ -1,18 +1,18 @@
|
||||
import { createSlice } from "@reduxjs/toolkit";
|
||||
import AgentTaskState from "#/types/AgentTaskState";
|
||||
import AgentState from "#/types/AgentState";
|
||||
|
||||
export const agentSlice = createSlice({
|
||||
name: "agent",
|
||||
initialState: {
|
||||
curTaskState: AgentTaskState.INIT,
|
||||
curAgentState: AgentState.LOADING,
|
||||
},
|
||||
reducers: {
|
||||
changeTaskState: (state, action) => {
|
||||
state.curTaskState = action.payload;
|
||||
changeAgentState: (state, action) => {
|
||||
state.curAgentState = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { changeTaskState } = agentSlice.actions;
|
||||
export const { changeAgentState } = agentSlice.actions;
|
||||
|
||||
export default agentSlice.reducer;
|
||||
|
||||
@ -3,19 +3,15 @@ import { createSlice } from "@reduxjs/toolkit";
|
||||
export const taskSlice = createSlice({
|
||||
name: "task",
|
||||
initialState: {
|
||||
initialized: false,
|
||||
completed: false,
|
||||
},
|
||||
reducers: {
|
||||
setInitialized: (state, action) => {
|
||||
state.initialized = action.payload;
|
||||
},
|
||||
setCompleted: (state, action) => {
|
||||
state.completed = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { setInitialized, setCompleted } = taskSlice.actions;
|
||||
export const { setCompleted } = taskSlice.actions;
|
||||
|
||||
export default taskSlice.reducer;
|
||||
|
||||
@ -2,12 +2,12 @@ enum ActionType {
|
||||
// Initializes the agent. Only sent by client.
|
||||
INIT = "initialize",
|
||||
|
||||
// Sends a message from the user
|
||||
USER_MESSAGE = "user_message",
|
||||
|
||||
// Starts a new development task
|
||||
// Starts a new development task.
|
||||
START = "start",
|
||||
|
||||
// Represents a message from the user or agent.
|
||||
MESSAGE = "message",
|
||||
|
||||
// Reads the contents of a file.
|
||||
READ = "read",
|
||||
|
||||
@ -45,7 +45,8 @@ enum ActionType {
|
||||
// Updates a task in the plan.
|
||||
MODIFY_TASK = "modify_task",
|
||||
|
||||
CHANGE_TASK_STATE = "change_task_state",
|
||||
// Changes the state of the agent, e.g. to paused or running
|
||||
CHANGE_AGENT_STATE = "change_agent_state",
|
||||
}
|
||||
|
||||
export default ActionType;
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
enum AgentTaskState {
|
||||
enum AgentState {
|
||||
LOADING = "loading",
|
||||
INIT = "init",
|
||||
RUNNING = "running",
|
||||
AWAITING_USER_INPUT = "awaiting_user_input",
|
||||
@ -8,4 +9,4 @@ enum AgentTaskState {
|
||||
ERROR = "error",
|
||||
}
|
||||
|
||||
export default AgentTaskState;
|
||||
export default AgentState;
|
||||
@ -1,15 +0,0 @@
|
||||
enum AgentTaskAction {
|
||||
// Starts the task.
|
||||
START = "start",
|
||||
|
||||
// Pauses the task.
|
||||
PAUSE = "pause",
|
||||
|
||||
// Resumes the task.
|
||||
RESUME = "resume",
|
||||
|
||||
// Stops the task.
|
||||
STOP = "stop",
|
||||
}
|
||||
|
||||
export default AgentTaskAction;
|
||||
@ -16,6 +16,9 @@ enum ObservationType {
|
||||
|
||||
// A message from the user
|
||||
CHAT = "chat",
|
||||
|
||||
// Agent state has changed
|
||||
AGENT_STATE_CHANGED = "agent_state_changed",
|
||||
}
|
||||
|
||||
export default ObservationType;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Callable, List, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from opendevin.controller.action_manager import ActionManager
|
||||
@ -14,23 +14,27 @@ from opendevin.core.exceptions import (
|
||||
MaxCharsExceedError,
|
||||
)
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import TaskState
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.core.schema.config import ConfigType
|
||||
from opendevin.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentTalkAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
TaskStateChangedAction,
|
||||
)
|
||||
from opendevin.events.event import Event
|
||||
from opendevin.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
AgentErrorObservation,
|
||||
AgentStateChangedObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserMessageObservation,
|
||||
)
|
||||
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
|
||||
from opendevin.runtime import DockerSSHBox
|
||||
from opendevin.runtime.browser.browser_env import BrowserEnv
|
||||
|
||||
@ -43,22 +47,21 @@ class AgentController:
|
||||
agent: Agent
|
||||
max_iterations: int
|
||||
action_manager: ActionManager
|
||||
callbacks: List[Callable]
|
||||
browser: BrowserEnv
|
||||
|
||||
event_stream: EventStream
|
||||
agent_task: Optional[asyncio.Task] = None
|
||||
delegate: 'AgentController | None' = None
|
||||
state: State | None = None
|
||||
|
||||
_task_state: TaskState = TaskState.INIT
|
||||
_agent_state: AgentState = AgentState.LOADING
|
||||
_cur_step: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: Agent,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
max_iterations: int = MAX_ITERATIONS,
|
||||
max_chars: int = MAX_CHARS,
|
||||
callbacks: List[Callable] = [],
|
||||
):
|
||||
"""Initializes a new instance of the AgentController class.
|
||||
|
||||
@ -67,14 +70,16 @@ class AgentController:
|
||||
sid: The session ID of the agent.
|
||||
max_iterations: The maximum number of iterations the agent can run.
|
||||
max_chars: The maximum number of characters the agent can output.
|
||||
callbacks: A list of callback functions to run after each action.
|
||||
"""
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
self.action_manager = ActionManager(self.id)
|
||||
self.max_chars = max_chars
|
||||
self.callbacks = callbacks
|
||||
# Initialize agent-required plugins for sandbox (if any)
|
||||
self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
|
||||
# Initialize browser environment
|
||||
@ -87,7 +92,12 @@ class AgentController:
|
||||
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
|
||||
)
|
||||
|
||||
self._await_user_message_queue: asyncio.Queue = asyncio.Queue()
|
||||
async def close(self):
|
||||
if self.agent_task is not None:
|
||||
self.agent_task.cancel()
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
|
||||
self.action_manager.sandbox.close()
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
|
||||
def update_state_for_step(self, i):
|
||||
if self.state is None:
|
||||
@ -100,9 +110,14 @@ class AgentController:
|
||||
return
|
||||
self.state.updated_info = []
|
||||
|
||||
def add_history(self, action: Action, observation: Observation):
|
||||
async def add_error_to_history(self, message: str):
|
||||
await self.add_history(NullAction(), AgentErrorObservation(message))
|
||||
|
||||
async def add_history(
|
||||
self, action: Action, observation: Observation, add_to_stream=True
|
||||
):
|
||||
if self.state is None:
|
||||
return
|
||||
raise ValueError('Added history while state was None')
|
||||
if not isinstance(action, Action):
|
||||
raise TypeError(
|
||||
f'action must be an instance of Action, got {type(action).__name__} instead'
|
||||
@ -113,12 +128,15 @@ class AgentController:
|
||||
)
|
||||
self.state.history.append((action, observation))
|
||||
self.state.updated_info.append((action, observation))
|
||||
if add_to_stream:
|
||||
await self.event_stream.add_event(action, EventSource.AGENT)
|
||||
await self.event_stream.add_event(observation, EventSource.AGENT)
|
||||
|
||||
async def _run(self):
|
||||
if self.state is None:
|
||||
return
|
||||
|
||||
if self._task_state != TaskState.RUNNING:
|
||||
if self._agent_state != AgentState.RUNNING:
|
||||
raise ValueError('Task is not in running state')
|
||||
|
||||
for i in range(self._cur_step, self.max_iterations):
|
||||
@ -126,117 +144,79 @@ class AgentController:
|
||||
try:
|
||||
finished = await self.step(i)
|
||||
if finished:
|
||||
self._task_state = TaskState.FINISHED
|
||||
await self.set_agent_state_to(AgentState.FINISHED)
|
||||
break
|
||||
except Exception:
|
||||
logger.error('Error in loop', exc_info=True)
|
||||
await self._run_callbacks(
|
||||
AgentErrorObservation(
|
||||
'Oops! Something went wrong while completing your task. You can check the logs for more info.'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
await self.add_error_to_history(
|
||||
'Oops! Something went wrong while completing your task. You can check the logs for more info.'
|
||||
)
|
||||
await self.set_task_state_to(TaskState.STOPPED)
|
||||
break
|
||||
|
||||
if self._task_state == TaskState.FINISHED:
|
||||
logger.info('Task finished by agent')
|
||||
await self.reset_task()
|
||||
break
|
||||
elif self._task_state == TaskState.STOPPED:
|
||||
logger.info('Task stopped by user')
|
||||
await self.reset_task()
|
||||
break
|
||||
elif self._task_state == TaskState.PAUSED:
|
||||
logger.info('Task paused')
|
||||
self._cur_step = i + 1
|
||||
await self.notify_task_state_changed()
|
||||
break
|
||||
|
||||
if self._is_stuck():
|
||||
logger.info('Loop detected, stopping task')
|
||||
observation = AgentErrorObservation(
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
await self.add_error_to_history(
|
||||
'I got stuck into a loop, the task has stopped.'
|
||||
)
|
||||
await self._run_callbacks(observation)
|
||||
await self.set_task_state_to(TaskState.STOPPED)
|
||||
break
|
||||
await asyncio.sleep(
|
||||
0.001
|
||||
) # Give back control for a tick, so other async stuff can run
|
||||
|
||||
async def setup_task(self, task: str, inputs: dict = {}):
|
||||
"""Sets up the agent controller with a task."""
|
||||
self._task_state = TaskState.RUNNING
|
||||
await self.notify_task_state_changed()
|
||||
await self.set_agent_state_to(AgentState.INIT)
|
||||
self.state = State(Plan(task))
|
||||
self.state.inputs = inputs
|
||||
|
||||
async def start(self, task: str):
|
||||
"""Starts the agent controller with a task.
|
||||
If task already run before, it will continue from the last step.
|
||||
"""
|
||||
await self.setup_task(task)
|
||||
await self._run()
|
||||
|
||||
async def resume(self):
|
||||
if self.state is None:
|
||||
raise ValueError('No task to resume')
|
||||
|
||||
self._task_state = TaskState.RUNNING
|
||||
await self.notify_task_state_changed()
|
||||
|
||||
await self._run()
|
||||
async def on_event(self, event: Event):
|
||||
if isinstance(event, ChangeAgentStateAction):
|
||||
await self.set_agent_state_to(event.agent_state) # type: ignore
|
||||
elif isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
# FIXME: we're hacking a message action into a user message observation, for the benefit of CodeAct
|
||||
await self.add_history(
|
||||
self._pending_talk_action,
|
||||
UserMessageObservation(event.content),
|
||||
add_to_stream=False,
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
|
||||
async def reset_task(self):
|
||||
if self.agent_task is not None:
|
||||
self.agent_task.cancel()
|
||||
self.state = None
|
||||
self._cur_step = 0
|
||||
self._task_state = TaskState.INIT
|
||||
self.agent.reset()
|
||||
await self.notify_task_state_changed()
|
||||
|
||||
async def set_task_state_to(self, state: TaskState):
|
||||
self._task_state = state
|
||||
if state == TaskState.STOPPED:
|
||||
await self.reset_task()
|
||||
logger.info(f'Task state set to {state}')
|
||||
|
||||
def get_task_state(self):
|
||||
"""Returns the current state of the agent task."""
|
||||
return self._task_state
|
||||
|
||||
async def notify_task_state_changed(self):
|
||||
await self._run_callbacks(TaskStateChangedAction(self._task_state))
|
||||
|
||||
async def add_user_message(self, message: UserMessageObservation):
|
||||
if self.state is None:
|
||||
async def set_agent_state_to(self, new_state: AgentState):
|
||||
logger.info(f'Setting agent state from {self._agent_state} to {new_state}')
|
||||
if new_state == self._agent_state:
|
||||
return
|
||||
|
||||
if self._task_state == TaskState.AWAITING_USER_INPUT:
|
||||
self._await_user_message_queue.put_nowait(message)
|
||||
self._agent_state = new_state
|
||||
if new_state == AgentState.RUNNING:
|
||||
self.agent_task = asyncio.create_task(self._run())
|
||||
elif (
|
||||
new_state == AgentState.PAUSED
|
||||
or new_state == AgentState.AWAITING_USER_INPUT
|
||||
):
|
||||
self._cur_step += 1
|
||||
if self.agent_task is not None:
|
||||
self.agent_task.cancel()
|
||||
elif new_state == AgentState.STOPPED:
|
||||
await self.reset_task()
|
||||
elif new_state == AgentState.FINISHED:
|
||||
await self.reset_task()
|
||||
|
||||
# set the task state to running
|
||||
self._task_state = TaskState.RUNNING
|
||||
await self.notify_task_state_changed()
|
||||
await self.event_stream.add_event(
|
||||
AgentStateChangedObservation('', self._agent_state), EventSource.AGENT
|
||||
)
|
||||
|
||||
elif self._task_state == TaskState.RUNNING:
|
||||
self.add_history(NullAction(), message)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Task (state: {self._task_state}) is not in a state to add user message'
|
||||
)
|
||||
|
||||
async def wait_for_user_input(self) -> UserMessageObservation:
|
||||
self._task_state = TaskState.AWAITING_USER_INPUT
|
||||
await self.notify_task_state_changed()
|
||||
# wait for the next user message
|
||||
if len(self.callbacks) == 0:
|
||||
logger.info(
|
||||
'Use STDIN to request user message as no callbacks are registered',
|
||||
extra={'msg_type': 'INFO'},
|
||||
)
|
||||
message = input('Request user input [type /exit to stop interaction] >> ')
|
||||
user_message_observation = UserMessageObservation(message)
|
||||
else:
|
||||
user_message_observation = await self._await_user_message_queue.get()
|
||||
self._await_user_message_queue.task_done()
|
||||
return user_message_observation
|
||||
def get_agent_state(self):
|
||||
"""Returns the current state of the agent task."""
|
||||
return self._agent_state
|
||||
|
||||
async def start_delegate(self, action: AgentDelegateAction):
|
||||
AgentCls: Type[Agent] = Agent.get_cls(action.agent)
|
||||
@ -244,9 +224,9 @@ class AgentController:
|
||||
self.delegate = AgentController(
|
||||
sid=self.id + '-delegate',
|
||||
agent=agent,
|
||||
event_stream=self.event_stream,
|
||||
max_iterations=self.max_iterations,
|
||||
max_chars=self.max_chars,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
task = action.inputs.get('task') or ''
|
||||
await self.delegate.setup_task(task, action.inputs)
|
||||
@ -259,7 +239,7 @@ class AgentController:
|
||||
if delegate_done:
|
||||
outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
|
||||
self.add_history(NullAction(), obs)
|
||||
await self.add_history(NullAction(), obs)
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
return False
|
||||
@ -272,8 +252,7 @@ class AgentController:
|
||||
|
||||
log_obs = self.action_manager.get_background_obs()
|
||||
for obs in log_obs:
|
||||
self.add_history(NullAction(), obs)
|
||||
await self._run_callbacks(obs)
|
||||
await self.add_history(NullAction(), obs)
|
||||
logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
|
||||
|
||||
self.update_state_for_step(i)
|
||||
@ -289,14 +268,10 @@ class AgentController:
|
||||
|
||||
self.update_state_after_step()
|
||||
|
||||
await self._run_callbacks(action)
|
||||
|
||||
# whether to await for user messages
|
||||
if isinstance(action, AgentTalkAction):
|
||||
# await for the next user messages
|
||||
user_message_observation = await self.wait_for_user_input()
|
||||
logger.info(user_message_observation, extra={'msg_type': 'OBSERVATION'})
|
||||
self.add_history(action, user_message_observation)
|
||||
self._pending_talk_action = action
|
||||
await self.event_stream.add_event(action, EventSource.AGENT)
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
return False
|
||||
|
||||
finished = isinstance(action, AgentFinishAction)
|
||||
@ -311,23 +286,9 @@ class AgentController:
|
||||
if not isinstance(observation, NullObservation):
|
||||
logger.info(observation, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
self.add_history(action, observation)
|
||||
await self._run_callbacks(observation)
|
||||
await self.add_history(action, observation)
|
||||
return False
|
||||
|
||||
async def _run_callbacks(self, event):
|
||||
if event is None:
|
||||
return
|
||||
for callback in self.callbacks:
|
||||
idx = self.callbacks.index(callback)
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception as e:
|
||||
logger.exception(f'Callback error: {e}, idx: {idx}')
|
||||
await asyncio.sleep(
|
||||
0.001
|
||||
) # Give back control for a tick, so we can await in callbacks
|
||||
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
|
||||
@ -6,6 +6,11 @@ import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.core.config import args
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.events.action import ChangeAgentStateAction, MessageAction
|
||||
from opendevin.events.event import Event
|
||||
from opendevin.events.observation import AgentStateChangedObservation
|
||||
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
|
||||
from opendevin.llm.llm import LLM
|
||||
|
||||
|
||||
@ -41,11 +46,33 @@ async def main(task_str: str = ''):
|
||||
llm = LLM(args.model_name)
|
||||
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
|
||||
agent = AgentCls(llm=llm)
|
||||
event_stream = EventStream()
|
||||
controller = AgentController(
|
||||
agent=agent, max_iterations=args.max_iterations, max_chars=args.max_chars
|
||||
agent=agent,
|
||||
max_iterations=args.max_iterations,
|
||||
max_chars=args.max_chars,
|
||||
event_stream=event_stream,
|
||||
)
|
||||
|
||||
await controller.start(task)
|
||||
await controller.setup_task(task)
|
||||
await event_stream.add_event(
|
||||
ChangeAgentStateAction(agent_state=AgentState.RUNNING), EventSource.USER
|
||||
)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.AWAITING_USER_INPUT:
|
||||
message = input('Request user input >> ')
|
||||
action = MessageAction(content=message)
|
||||
await event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
while controller.get_agent_state() not in [
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
AgentState.STOPPED,
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
from .action import ActionType
|
||||
from .agent import AgentState
|
||||
from .config import ConfigType
|
||||
from .observation import ObservationType
|
||||
from .task import TaskState, TaskStateAction
|
||||
|
||||
__all__ = [
|
||||
'ActionType',
|
||||
'ObservationType',
|
||||
'ConfigType',
|
||||
'TaskState',
|
||||
'TaskStateAction',
|
||||
'AgentState',
|
||||
]
|
||||
|
||||
@ -8,8 +8,8 @@ class ActionTypeSchema(BaseModel):
|
||||
"""Initializes the agent. Only sent by client.
|
||||
"""
|
||||
|
||||
USER_MESSAGE: str = Field(default='user_message')
|
||||
"""Sends a message from the user. Only sent by the client.
|
||||
MESSAGE: str = Field(default='message')
|
||||
"""Represents a message.
|
||||
"""
|
||||
|
||||
START: str = Field(default='start')
|
||||
@ -81,7 +81,7 @@ class ActionTypeSchema(BaseModel):
|
||||
"""Stops the task. Must send a start action to restart a new task.
|
||||
"""
|
||||
|
||||
CHANGE_TASK_STATE: str = Field(default='change_task_state')
|
||||
CHANGE_AGENT_STATE: str = Field(default='change_agent_state')
|
||||
|
||||
PUSH: str = Field(default='push')
|
||||
"""Push a branch to github."""
|
||||
|
||||
35
opendevin/core/schema/agent.py
Normal file
35
opendevin/core/schema/agent.py
Normal file
@ -0,0 +1,35 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentState(str, Enum):
|
||||
LOADING = 'loading'
|
||||
"""The agent is loading.
|
||||
"""
|
||||
|
||||
INIT = 'init'
|
||||
"""The agent is initialized.
|
||||
"""
|
||||
|
||||
RUNNING = 'running'
|
||||
"""The agent is running.
|
||||
"""
|
||||
|
||||
AWAITING_USER_INPUT = 'awaiting_user_input'
|
||||
"""The agent is awaiting user input.
|
||||
"""
|
||||
|
||||
PAUSED = 'paused'
|
||||
"""The agent is paused.
|
||||
"""
|
||||
|
||||
STOPPED = 'stopped'
|
||||
"""The agent is stopped.
|
||||
"""
|
||||
|
||||
FINISHED = 'finished'
|
||||
"""The agent is finished with the current task.
|
||||
"""
|
||||
|
||||
ERROR = 'error'
|
||||
"""An error occurred during the task.
|
||||
"""
|
||||
@ -40,5 +40,7 @@ class ObservationTypeSchema(BaseModel):
|
||||
|
||||
NULL: str = Field(default='null')
|
||||
|
||||
AGENT_STATE_CHANGED: str = Field(default='agent_state_changed')
|
||||
|
||||
|
||||
ObservationType = ObservationTypeSchema()
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
INIT = 'init'
|
||||
"""Initial state of the task.
|
||||
"""
|
||||
|
||||
RUNNING = 'running'
|
||||
"""The task is running.
|
||||
"""
|
||||
|
||||
AWAITING_USER_INPUT = 'awaiting_user_input'
|
||||
"""The task is awaiting user input.
|
||||
"""
|
||||
|
||||
PAUSED = 'paused'
|
||||
"""The task is paused.
|
||||
"""
|
||||
|
||||
STOPPED = 'stopped'
|
||||
"""The task is stopped.
|
||||
"""
|
||||
|
||||
FINISHED = 'finished'
|
||||
"""The task is finished.
|
||||
"""
|
||||
|
||||
ERROR = 'error'
|
||||
"""An error occurred during the task.
|
||||
"""
|
||||
|
||||
|
||||
class TaskStateAction(str, Enum):
|
||||
START = 'start'
|
||||
"""Starts the task.
|
||||
"""
|
||||
|
||||
PAUSE = 'pause'
|
||||
"""Pauses the task.
|
||||
"""
|
||||
|
||||
RESUME = 'resume'
|
||||
"""Resumes the task.
|
||||
"""
|
||||
|
||||
STOP = 'stop'
|
||||
"""Stops the task.
|
||||
"""
|
||||
@ -9,13 +9,15 @@ from .agent import (
|
||||
AgentSummarizeAction,
|
||||
AgentTalkAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
)
|
||||
from .browse import BrowseURLAction
|
||||
from .commands import CmdKillAction, CmdRunAction, IPythonRunCellAction
|
||||
from .empty import NullAction
|
||||
from .files import FileReadAction, FileWriteAction
|
||||
from .github import GitHubPushAction
|
||||
from .tasks import AddTaskAction, ModifyTaskAction, TaskStateChangedAction
|
||||
from .message import MessageAction
|
||||
from .tasks import AddTaskAction, ModifyTaskAction
|
||||
|
||||
actions = (
|
||||
CmdKillAction,
|
||||
@ -31,8 +33,9 @@ actions = (
|
||||
AgentDelegateAction,
|
||||
AddTaskAction,
|
||||
ModifyTaskAction,
|
||||
TaskStateChangedAction,
|
||||
ChangeAgentStateAction,
|
||||
GitHubPushAction,
|
||||
MessageAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
@ -74,6 +77,7 @@ __all__ = [
|
||||
'AgentSummarizeAction',
|
||||
'AddTaskAction',
|
||||
'ModifyTaskAction',
|
||||
'TaskStateChangedAction',
|
||||
'ChangeAgentStateAction',
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
]
|
||||
|
||||
@ -15,6 +15,19 @@ if TYPE_CHECKING:
|
||||
from opendevin.controller import AgentController
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeAgentStateAction(Action):
|
||||
"""Action requesting that the agent be switched to the given agent state."""
|
||||
|
||||
agent_state: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.CHANGE_AGENT_STATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Agent state changed to {self.agent_state}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRecallAction(Action):
|
||||
query: str
|
||||
|
||||
15
opendevin/events/action/message.py
Normal file
15
opendevin/events/action/message.py
Normal file
@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from opendevin.core.schema import ActionType
|
||||
|
||||
from .action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAction(Action):
|
||||
content: str
|
||||
action: str = ActionType.MESSAGE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
@ -43,16 +43,3 @@ class ModifyTaskAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Set task {self.id} to {self.state}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskStateChangedAction(Action):
|
||||
"""Fake action, just to notify the client that a task state has changed."""
|
||||
|
||||
task_state: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.CHANGE_TASK_STATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Task state changed to {self.task_state}'
|
||||
|
||||
@ -13,4 +13,8 @@ class Event:
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.message
|
||||
return self._message # type: ignore [attr-defined]
|
||||
|
||||
@property
|
||||
def source(self) -> str:
|
||||
return self._source # type: ignore [attr-defined]
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .agent import AgentStateChangedObservation
|
||||
from .browse import BrowserOutputObservation
|
||||
from .commands import CmdOutputObservation, IPythonRunCellObservation
|
||||
from .delegate import AgentDelegateObservation
|
||||
@ -18,9 +19,13 @@ observations = (
|
||||
AgentRecallObservation,
|
||||
AgentDelegateObservation,
|
||||
AgentErrorObservation,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {observation_class.observation: observation_class for observation_class in observations} # type: ignore[attr-defined]
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
observation_class.observation: observation_class # type: ignore[attr-defined]
|
||||
for observation_class in observations
|
||||
}
|
||||
|
||||
|
||||
def observation_from_dict(observation: dict) -> Observation:
|
||||
@ -29,7 +34,9 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
raise KeyError(f"'observation' key is not found in {observation=}")
|
||||
observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation'])
|
||||
if observation_class is None:
|
||||
raise KeyError(f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}")
|
||||
raise KeyError(
|
||||
f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}"
|
||||
)
|
||||
observation.pop('observation')
|
||||
observation.pop('message', None)
|
||||
content = observation.pop('content', '')
|
||||
@ -49,4 +56,5 @@ __all__ = [
|
||||
'AgentMessageObservation',
|
||||
'AgentRecallObservation',
|
||||
'AgentErrorObservation',
|
||||
'AgentStateChangedObservation',
|
||||
]
|
||||
|
||||
19
opendevin/events/observation/agent.py
Normal file
19
opendevin/events/observation/agent.py
Normal file
@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from opendevin.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStateChangedObservation(Observation):
|
||||
"""
|
||||
This data class represents a change in the agent's state
|
||||
"""
|
||||
|
||||
agent_state: str
|
||||
observation: str = ObservationType.AGENT_STATE_CHANGED
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
48
opendevin/events/stream.py
Normal file
48
opendevin/events/stream.py
Normal file
@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
|
||||
from .event import Event
|
||||
|
||||
|
||||
class EventStreamSubscriber(str, Enum):
|
||||
AGENT_CONTROLLER = 'agent_controller'
|
||||
SERVER = 'server'
|
||||
RUNTIME = 'runtime'
|
||||
MAIN = 'main'
|
||||
|
||||
|
||||
class EventSource(str, Enum):
|
||||
AGENT = 'agent'
|
||||
USER = 'user'
|
||||
|
||||
|
||||
class EventStream:
    """In-process publish/subscribe channel for events.

    Each added event is stamped with a sequential id, a timestamp and its
    source, appended to the stream's history, and then passed to every
    registered subscriber callback.
    """

    def __init__(self):
        # State is per instance. Declaring these at class level would make
        # every EventStream share one subscriber table, one history and one
        # lock, so a second stream could never register the same subscriber id.
        self._subscribers: Dict[str, Callable] = {}
        self._events: List[Event] = []
        # Created per instance rather than at import time.
        self._lock = asyncio.Lock()

    def subscribe(self, id: EventStreamSubscriber, callback: Callable):
        """Register `callback` under `id`.

        A second subscription with the same id is ignored and a warning is
        logged; the first callback stays registered.
        """
        if id in self._subscribers:
            logger.warning('Subscriber subscribed multiple times: ' + id)
        else:
            self._subscribers[id] = callback

    def unsubscribe(self, id: EventStreamSubscriber):
        """Remove the callback registered under `id`; warn if there is none."""
        if id not in self._subscribers:
            logger.warning('Subscriber not found during unsubscribe: ' + id)
        else:
            del self._subscribers[id]

    # TODO: make this not async
    async def add_event(self, event: Event, source: EventSource):
        """Stamp `event`, append it to the history and notify all subscribers.

        Args:
            event: The event to publish. Its `_id`, `_timestamp` and `_source`
                attributes are set here.
            source: Who produced the event.
        """
        async with self._lock:
            event._id = len(self._events)  # type: ignore [attr-defined]
            event._timestamp = datetime.now()  # type: ignore [attr-defined]
            event._source = source  # type: ignore [attr-defined]
            self._events.append(event)
        # Notify after releasing the lock: a subscriber may publish a follow-up
        # event from inside its callback, and asyncio.Lock is not reentrant.
        # Iterate over a snapshot so a callback may subscribe or unsubscribe
        # without invalidating the iteration.
        for fn in list(self._subscribers.values()):
            await fn(event)
|
||||
@ -1,74 +1,43 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from opendevin.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.core import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import ActionType, ConfigType, TaskState, TaskStateAction
|
||||
from opendevin.core.schema import ActionType, AgentState, ConfigType
|
||||
from opendevin.events.action import (
|
||||
Action,
|
||||
ChangeAgentStateAction,
|
||||
NullAction,
|
||||
action_from_dict,
|
||||
)
|
||||
from opendevin.events.event import Event
|
||||
from opendevin.events.observation import (
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserMessageObservation,
|
||||
)
|
||||
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.server.session import session_manager
|
||||
|
||||
# new task state to valid old task states
|
||||
VALID_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
|
||||
TaskStateAction.PAUSE: [TaskState.RUNNING],
|
||||
TaskStateAction.RESUME: [TaskState.PAUSED],
|
||||
TaskStateAction.STOP: [
|
||||
TaskState.RUNNING,
|
||||
TaskState.PAUSED,
|
||||
TaskState.AWAITING_USER_INPUT,
|
||||
],
|
||||
}
|
||||
IGNORED_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
|
||||
TaskStateAction.PAUSE: [
|
||||
TaskState.INIT,
|
||||
TaskState.PAUSED,
|
||||
TaskState.STOPPED,
|
||||
TaskState.FINISHED,
|
||||
TaskState.AWAITING_USER_INPUT,
|
||||
],
|
||||
TaskStateAction.RESUME: [
|
||||
TaskState.INIT,
|
||||
TaskState.RUNNING,
|
||||
TaskState.STOPPED,
|
||||
TaskState.FINISHED,
|
||||
TaskState.AWAITING_USER_INPUT,
|
||||
],
|
||||
TaskStateAction.STOP: [TaskState.INIT, TaskState.STOPPED, TaskState.FINISHED],
|
||||
}
|
||||
TASK_STATE_ACTION_MAP: Dict[TaskStateAction, TaskState] = {
|
||||
TaskStateAction.START: TaskState.RUNNING,
|
||||
TaskStateAction.PAUSE: TaskState.PAUSED,
|
||||
TaskStateAction.RESUME: TaskState.RUNNING,
|
||||
TaskStateAction.STOP: TaskState.STOPPED,
|
||||
}
|
||||
|
||||
|
||||
class AgentUnit:
|
||||
"""Represents a session with an agent.
|
||||
|
||||
Attributes:
|
||||
controller: The AgentController instance for controlling the agent.
|
||||
agent_task: The task representing the agent's execution.
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
controller: Optional[AgentController] = None
|
||||
agent_task: Optional[asyncio.Task] = None
|
||||
# TODO: we will add the runtime here
|
||||
# runtime: Optional[Runtime] = None
|
||||
|
||||
def __init__(self, sid):
|
||||
"""Initializes a new instance of the AgentUnit class."""
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream()
|
||||
self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event)
|
||||
|
||||
async def send_error(self, message):
|
||||
"""Sends an error message to the client.
|
||||
@ -100,28 +69,27 @@ class AgentUnit:
|
||||
await self.send_error('Invalid action')
|
||||
return
|
||||
|
||||
match action:
|
||||
case ActionType.INIT:
|
||||
await self.create_controller(data)
|
||||
case ActionType.START:
|
||||
await self.start_task(data)
|
||||
case ActionType.USER_MESSAGE:
|
||||
await self.send_user_message(data)
|
||||
case ActionType.CHANGE_TASK_STATE:
|
||||
task_state_action = data.get('args', {}).get('task_state_action', None)
|
||||
if task_state_action is None:
|
||||
await self.send_error('No task state action specified.')
|
||||
return
|
||||
await self.set_task_state(TaskStateAction(task_state_action))
|
||||
case ActionType.CHAT:
|
||||
if self.controller is None:
|
||||
await self.send_error('No agent started. Please wait a second...')
|
||||
return
|
||||
self.controller.add_history(
|
||||
NullAction(), UserMessageObservation(data['message'])
|
||||
)
|
||||
case _:
|
||||
await self.send_error("I didn't recognize this action:" + action)
|
||||
if action == ActionType.INIT:
|
||||
await self.create_controller(data)
|
||||
await self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
)
|
||||
return
|
||||
elif action == ActionType.START:
|
||||
if self.controller is None:
|
||||
await self.send_error('No agent started.')
|
||||
return
|
||||
task = data['args']['task']
|
||||
await self.controller.setup_task(task)
|
||||
await self.event_stream.add_event(
|
||||
ChangeAgentStateAction(agent_state=AgentState.RUNNING), EventSource.USER
|
||||
)
|
||||
return
|
||||
|
||||
action_dict = data.copy()
|
||||
action_dict['action'] = action
|
||||
action_obj = action_from_dict(action_dict)
|
||||
await self.event_stream.add_event(action_obj, EventSource.USER)
|
||||
|
||||
def get_arg_or_default(self, _args: dict, key: ConfigType) -> str:
|
||||
"""Gets an argument from the args dictionary or the default value.
|
||||
@ -156,13 +124,15 @@ class AgentUnit:
|
||||
|
||||
logger.info(f'Creating agent {agent_cls} using LLM {model}')
|
||||
llm = LLM(model=model, api_key=api_key, base_url=api_base)
|
||||
if self.controller is not None:
|
||||
await self.controller.close()
|
||||
try:
|
||||
self.controller = AgentController(
|
||||
sid=self.sid,
|
||||
event_stream=self.event_stream,
|
||||
agent=Agent.get_cls(agent_cls)(llm),
|
||||
max_iterations=int(max_iterations),
|
||||
max_chars=int(max_chars),
|
||||
callbacks=[self.on_agent_event],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
@ -170,77 +140,8 @@ class AgentUnit:
|
||||
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
||||
)
|
||||
return
|
||||
await self.init_done()
|
||||
|
||||
async def init_done(self):
|
||||
if self.controller is None:
|
||||
await self.send_error('No agent started.')
|
||||
return
|
||||
await self.send(
|
||||
{
|
||||
'action': ActionType.INIT,
|
||||
'message': 'Control loop started.',
|
||||
}
|
||||
)
|
||||
await self.controller.notify_task_state_changed()
|
||||
|
||||
async def start_task(self, start_event):
|
||||
"""Starts a task for the agent.
|
||||
|
||||
Args:
|
||||
start_event: The start event data.
|
||||
"""
|
||||
task = start_event['args']['task']
|
||||
if self.controller is None:
|
||||
await self.send_error('No agent started. Please wait a second...')
|
||||
return
|
||||
try:
|
||||
if self.agent_task:
|
||||
self.agent_task.cancel()
|
||||
self.agent_task = asyncio.create_task(
|
||||
self.controller.start(task), name='agent start task loop'
|
||||
)
|
||||
except Exception as e:
|
||||
await self.send_error(f'Error during task loop: {e}')
|
||||
|
||||
async def send_user_message(self, data: dict):
|
||||
if not self.agent_task or not self.controller:
|
||||
await self.send_error('No agent started.')
|
||||
return
|
||||
|
||||
await self.controller.add_user_message(
|
||||
UserMessageObservation(data['args']['message'])
|
||||
)
|
||||
|
||||
async def set_task_state(self, new_state_action: TaskStateAction):
|
||||
"""Sets the state of the agent task."""
|
||||
if self.controller is None:
|
||||
await self.send_error('No agent started.')
|
||||
return
|
||||
|
||||
cur_state = self.controller.get_task_state()
|
||||
new_state = TASK_STATE_ACTION_MAP.get(new_state_action)
|
||||
if new_state is None:
|
||||
await self.send_error('Invalid task state action.')
|
||||
return
|
||||
if cur_state in VALID_TASK_STATE_MAP.get(new_state_action, []):
|
||||
await self.controller.set_task_state_to(new_state)
|
||||
elif cur_state in IGNORED_TASK_STATE_MAP.get(new_state_action, []):
|
||||
# notify once again.
|
||||
await self.controller.notify_task_state_changed()
|
||||
return
|
||||
else:
|
||||
await self.send_error('Current task state not recognized.')
|
||||
return
|
||||
|
||||
if new_state_action == TaskStateAction.RESUME:
|
||||
if self.agent_task:
|
||||
self.agent_task.cancel()
|
||||
self.agent_task = asyncio.create_task(
|
||||
self.controller.resume(), name='agent resume task loop'
|
||||
)
|
||||
|
||||
async def on_agent_event(self, event: Observation | Action):
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback function for agent events.
|
||||
|
||||
Args:
|
||||
@ -250,10 +151,10 @@ class AgentUnit:
|
||||
return
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
await self.send(event.to_dict())
|
||||
if event.source == 'agent':
|
||||
await self.send(event.to_dict())
|
||||
return
|
||||
|
||||
def close(self):
|
||||
if self.agent_task:
|
||||
self.agent_task.cancel()
|
||||
if self.controller is not None:
|
||||
self.controller.action_manager.sandbox.close()
|
||||
self.controller.close()
|
||||
|
||||
@ -64,7 +64,7 @@ class MessageStack:
|
||||
# Ignore assistant init message for now.
|
||||
if 'action' in msg.payload and msg.payload['action'] in [
|
||||
ActionType.INIT,
|
||||
ActionType.CHANGE_TASK_STATE,
|
||||
ActionType.CHANGE_AGENT_STATE,
|
||||
]:
|
||||
continue
|
||||
cnt += 1
|
||||
|
||||
@ -10,6 +10,7 @@ from opendevin.events.action.github import GitHubPushAction, GitHubSendPRAction
|
||||
from opendevin.events.observation.commands import CmdOutputObservation
|
||||
from opendevin.events.observation.error import AgentErrorObservation
|
||||
from opendevin.events.observation.message import AgentMessageObservation
|
||||
from opendevin.events.stream import EventStream
|
||||
from opendevin.llm.llm import LLM
|
||||
|
||||
|
||||
@ -19,7 +20,8 @@ def agent_controller():
|
||||
config.config[ConfigType.SANDBOX_TYPE] = 'local'
|
||||
llm = LLM()
|
||||
agent = DummyAgent(llm=llm)
|
||||
controller = AgentController(agent)
|
||||
event_stream = EventStream()
|
||||
controller = AgentController(agent, event_stream)
|
||||
yield controller
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user