Separate agent controller and server via EventStream (#1538)

* move towards event stream

* refactor agent state changes

* move agent state logic

* fix callbacks

* break on finish

* closer to working

* change frontend to accomodate new flow

* handle start action

* fix locked stream

* revert message

* logspam

* no async on close

* get rid of agent_task

* fix up closing

* better asyncio handling

* sleep to give back control

* fix key

* logspam

* update frontend agent state actions

* fix pause and cancel

* delint

* fix map

* delint

* wait for agent to finish

* fix unit test

* event stream enums

* fix merge issues

* fix lint

* fix test

* fix test

* add user message action

* add user message action

* fix up user messages

* fix main.py flow

* refactor message waiting

* lint

* fix test

* fix test
This commit is contained in:
Robert Brennan 2024-05-05 15:20:01 -04:00 committed by GitHub
parent 4e84aac577
commit f7e0c6cd06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 433 additions and 494 deletions

View File

@ -4,47 +4,39 @@ import { useSelector } from "react-redux";
import ArrowIcon from "#/assets/arrow";
import PauseIcon from "#/assets/pause";
import PlayIcon from "#/assets/play";
import { changeTaskState } from "#/services/agentStateService";
import { changeAgentState } from "#/services/agentStateService";
import { clearMsgs } from "#/services/session";
import store, { RootState } from "#/store";
import AgentTaskAction from "#/types/AgentTaskAction";
import AgentTaskState from "#/types/AgentTaskState";
import AgentState from "#/types/AgentState";
import { clearMessages } from "#/state/chatSlice";
const TaskStateActionMap = {
[AgentTaskAction.START]: AgentTaskState.RUNNING,
[AgentTaskAction.PAUSE]: AgentTaskState.PAUSED,
[AgentTaskAction.RESUME]: AgentTaskState.RUNNING,
[AgentTaskAction.STOP]: AgentTaskState.STOPPED,
};
const IgnoreTaskStateMap: { [k: string]: AgentTaskState[] } = {
[AgentTaskAction.PAUSE]: [
AgentTaskState.INIT,
AgentTaskState.PAUSED,
AgentTaskState.STOPPED,
AgentTaskState.FINISHED,
AgentTaskState.AWAITING_USER_INPUT,
const IgnoreTaskStateMap: { [k: string]: AgentState[] } = {
[AgentState.PAUSED]: [
AgentState.INIT,
AgentState.PAUSED,
AgentState.STOPPED,
AgentState.FINISHED,
AgentState.AWAITING_USER_INPUT,
],
[AgentTaskAction.RESUME]: [
AgentTaskState.INIT,
AgentTaskState.RUNNING,
AgentTaskState.STOPPED,
AgentTaskState.FINISHED,
AgentTaskState.AWAITING_USER_INPUT,
[AgentState.RUNNING]: [
AgentState.INIT,
AgentState.RUNNING,
AgentState.STOPPED,
AgentState.FINISHED,
AgentState.AWAITING_USER_INPUT,
],
[AgentTaskAction.STOP]: [
AgentTaskState.INIT,
AgentTaskState.STOPPED,
AgentTaskState.FINISHED,
[AgentState.STOPPED]: [
AgentState.INIT,
AgentState.STOPPED,
AgentState.FINISHED,
],
};
interface ButtonProps {
isDisabled: boolean;
content: string;
action: AgentTaskAction;
handleAction: (action: AgentTaskAction) => void;
action: AgentState;
handleAction: (action: AgentState) => void;
large?: boolean;
}
@ -75,53 +67,53 @@ ActionButton.defaultProps = {
};
function AgentControlBar() {
const { curTaskState } = useSelector((state: RootState) => state.agent);
const [desiredState, setDesiredState] = React.useState(AgentTaskState.INIT);
const { curAgentState } = useSelector((state: RootState) => state.agent);
const [desiredState, setDesiredState] = React.useState(AgentState.INIT);
const [isLoading, setIsLoading] = React.useState(false);
const handleAction = (action: AgentTaskAction) => {
if (IgnoreTaskStateMap[action].includes(curTaskState)) {
const handleAction = (action: AgentState) => {
if (IgnoreTaskStateMap[action].includes(curAgentState)) {
return;
}
let act = action;
if (act === AgentTaskAction.STOP) {
act = AgentTaskAction.STOP;
if (act === AgentState.STOPPED) {
act = AgentState.STOPPED;
clearMsgs().then().catch();
store.dispatch(clearMessages());
} else {
setIsLoading(true);
}
setDesiredState(TaskStateActionMap[act]);
changeTaskState(act);
setDesiredState(act);
changeAgentState(act);
};
useEffect(() => {
if (curTaskState === desiredState) {
if (curTaskState === AgentTaskState.STOPPED) {
if (curAgentState === desiredState) {
if (curAgentState === AgentState.STOPPED) {
clearMsgs().then().catch();
store.dispatch(clearMessages());
}
setIsLoading(false);
} else if (curTaskState === AgentTaskState.RUNNING) {
setDesiredState(AgentTaskState.RUNNING);
} else if (curAgentState === AgentState.RUNNING) {
setDesiredState(AgentState.RUNNING);
}
// We only want to run this effect when curTaskState changes
// We only want to run this effect when curAgentState changes
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [curTaskState]);
}, [curAgentState]);
return (
<div className="flex items-center gap-3">
{curTaskState === AgentTaskState.PAUSED ? (
{curAgentState === AgentState.PAUSED ? (
<ActionButton
isDisabled={
isLoading ||
IgnoreTaskStateMap[AgentTaskAction.RESUME].includes(curTaskState)
IgnoreTaskStateMap[AgentState.RUNNING].includes(curAgentState)
}
content="Resume the agent task"
action={AgentTaskAction.RESUME}
action={AgentState.RUNNING}
handleAction={handleAction}
large
>
@ -131,10 +123,10 @@ function AgentControlBar() {
<ActionButton
isDisabled={
isLoading ||
IgnoreTaskStateMap[AgentTaskAction.PAUSE].includes(curTaskState)
IgnoreTaskStateMap[AgentState.PAUSED].includes(curAgentState)
}
content="Pause the agent task"
action={AgentTaskAction.PAUSE}
action={AgentState.PAUSED}
handleAction={handleAction}
large
>
@ -144,7 +136,7 @@ function AgentControlBar() {
<ActionButton
isDisabled={isLoading}
content="Restart a new agent task"
action={AgentTaskAction.STOP}
action={AgentState.STOPPED}
handleAction={handleAction}
>
<ArrowIcon />

View File

@ -3,31 +3,31 @@ import { useTranslation } from "react-i18next";
import { useSelector } from "react-redux";
import { I18nKey } from "#/i18n/declaration";
import { RootState } from "#/store";
import AgentTaskState from "#/types/AgentTaskState";
import AgentState from "#/types/AgentState";
const AgentStatusMap: { [k: string]: { message: string; indicator: string } } =
{
[AgentTaskState.INIT]: {
[AgentState.INIT]: {
message: "Agent is initialized, waiting for task...",
indicator: "bg-blue-500",
},
[AgentTaskState.RUNNING]: {
[AgentState.RUNNING]: {
message: "Agent is running task...",
indicator: "bg-green-500",
},
[AgentTaskState.AWAITING_USER_INPUT]: {
[AgentState.AWAITING_USER_INPUT]: {
message: "Agent is awaiting user input...",
indicator: "bg-orange-500",
},
[AgentTaskState.PAUSED]: {
[AgentState.PAUSED]: {
message: "Agent has paused.",
indicator: "bg-yellow-500",
},
[AgentTaskState.STOPPED]: {
[AgentState.STOPPED]: {
message: "Agent has stopped.",
indicator: "bg-red-500",
},
[AgentTaskState.FINISHED]: {
[AgentState.FINISHED]: {
message: "Agent has finished the task.",
indicator: "bg-green-500",
},
@ -35,8 +35,7 @@ const AgentStatusMap: { [k: string]: { message: string; indicator: string } } =
function AgentStatusBar() {
const { t } = useTranslation();
const { initialized } = useSelector((state: RootState) => state.task);
const { curTaskState } = useSelector((state: RootState) => state.agent);
const { curAgentState } = useSelector((state: RootState) => state.agent);
// TODO: Extend the agent status, e.g.:
// - Agent is typing
@ -46,13 +45,13 @@ function AgentStatusBar() {
// - Agent is not available
return (
<div className="flex items-center">
{initialized ? (
{curAgentState !== AgentState.LOADING ? (
<>
<div
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curTaskState].indicator}`}
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
/>
<span className="text-sm text-stone-400">
{AgentStatusMap[curTaskState].message}
{AgentStatusMap[curAgentState].message}
</span>
</>
) : (

View File

@ -8,7 +8,7 @@ import ChatInterface from "./ChatInterface";
import Socket from "#/services/socket";
import ActionType from "#/types/ActionType";
import { addAssistantMessage } from "#/state/chatSlice";
import AgentTaskState from "#/types/AgentTaskState";
import AgentState from "#/types/AgentState";
// avoid typing side-effect
vi.mock("#/hooks/useTyping", () => ({
@ -25,7 +25,6 @@ const renderChatInterface = () =>
renderWithProviders(<ChatInterface />, {
preloadedState: {
task: {
initialized: true,
completed: false,
},
},
@ -38,7 +37,16 @@ describe("ChatInterface", () => {
});
it("should render the new message the user has typed", async () => {
renderChatInterface();
renderWithProviders(<ChatInterface />, {
preloadedState: {
task: {
completed: false,
},
agent: {
curAgentState: AgentState.INIT,
},
},
});
const input = screen.getByRole("textbox");
@ -73,11 +81,10 @@ describe("ChatInterface", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
task: {
initialized: true,
completed: false,
},
agent: {
curTaskState: AgentTaskState.INIT,
curAgentState: AgentState.INIT,
},
},
});
@ -95,11 +102,10 @@ describe("ChatInterface", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
task: {
initialized: true,
completed: false,
},
agent: {
curTaskState: AgentTaskState.AWAITING_USER_INPUT,
curAgentState: AgentState.AWAITING_USER_INPUT,
},
},
});
@ -110,8 +116,8 @@ describe("ChatInterface", () => {
});
const event = {
action: ActionType.USER_MESSAGE,
args: { message: "my message" },
action: ActionType.MESSAGE,
args: { content: "my message" },
};
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
@ -120,9 +126,11 @@ describe("ChatInterface", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
task: {
initialized: false,
completed: false,
},
agent: {
curAgentState: AgentState.LOADING,
},
},
});

View File

@ -1,32 +1,19 @@
import React from "react";
import { useDispatch, useSelector } from "react-redux";
import { useSelector } from "react-redux";
import { IoMdChatbubbles } from "react-icons/io";
import ChatInput from "./ChatInput";
import Chat from "./Chat";
import { RootState } from "#/store";
import AgentTaskState from "#/types/AgentTaskState";
import { addUserMessage } from "#/state/chatSlice";
import ActionType from "#/types/ActionType";
import Socket from "#/services/socket";
import AgentState from "#/types/AgentState";
import { sendChatMessage } from "#/services/chatService";
function ChatInterface() {
const { initialized } = useSelector((state: RootState) => state.task);
const { messages } = useSelector((state: RootState) => state.chat);
const { curTaskState } = useSelector((state: RootState) => state.agent);
const dispatch = useDispatch();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const handleSendMessage = (content: string) => {
dispatch(addUserMessage(content));
let event;
if (curTaskState === AgentTaskState.INIT) {
event = { action: ActionType.START, args: { task: content } };
} else {
event = { action: ActionType.USER_MESSAGE, args: { message: content } };
}
Socket.send(JSON.stringify(event));
const isTask = curAgentState === AgentState.INIT;
sendChatMessage(content, isTask);
};
return (
@ -42,7 +29,10 @@ function ChatInterface() {
{/* Fade between messages and input */}
<div className="absolute bottom-0 left-0 right-0 h-4 bg-gradient-to-b from-transparent to-neutral-800" />
</div>
<ChatInput disabled={!initialized} onSendMessage={handleSendMessage} />
<ChatInput
disabled={curAgentState === AgentState.LOADING}
onSendMessage={handleSendMessage}
/>
</div>
);
}

View File

@ -2,7 +2,7 @@ import { act, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import React from "react";
import { renderWithProviders } from "test-utils";
import AgentTaskState from "#/types/AgentTaskState";
import AgentState from "#/types/AgentState";
import { Settings } from "#/services/settings";
import SettingsForm from "./SettingsForm";
@ -80,7 +80,7 @@ describe("SettingsForm", () => {
onLanguageChange={onLanguageChangeMock}
onAPIKeyChange={onAPIKeyChangeMock}
/>,
{ preloadedState: { agent: { curTaskState: AgentTaskState.RUNNING } } },
{ preloadedState: { agent: { curAgentState: AgentState.RUNNING } } },
);
const modelInput = screen.getByRole("combobox", { name: "model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });

View File

@ -6,7 +6,7 @@ import { useSelector } from "react-redux";
import { AvailableLanguages } from "../../../i18n";
import { I18nKey } from "../../../i18n/declaration";
import { RootState } from "../../../store";
import AgentTaskState from "../../../types/AgentTaskState";
import AgentState from "../../../types/AgentState";
import { AutocompleteCombobox } from "./AutocompleteCombobox";
import { Settings } from "#/services/settings";
@ -31,21 +31,21 @@ function SettingsForm({
onLanguageChange,
}: SettingsFormProps) {
const { t } = useTranslation();
const { curTaskState } = useSelector((state: RootState) => state.agent);
const { curAgentState } = useSelector((state: RootState) => state.agent);
const [disabled, setDisabled] = React.useState<boolean>(false);
const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
useEffect(() => {
if (
curTaskState === AgentTaskState.RUNNING ||
curTaskState === AgentTaskState.PAUSED ||
curTaskState === AgentTaskState.AWAITING_USER_INPUT
curAgentState === AgentState.RUNNING ||
curAgentState === AgentState.PAUSED ||
curAgentState === AgentState.AWAITING_USER_INPUT
) {
setDisabled(true);
} else {
setDisabled(false);
}
}, [curTaskState, setDisabled]);
}, [curAgentState, setDisabled]);
return (
<>

View File

@ -1,11 +1,9 @@
import { changeTaskState } from "#/state/agentSlice";
import { setScreenshotSrc, setUrl } from "#/state/browserSlice";
import { addAssistantMessage } from "#/state/chatSlice";
import { setCode, updatePath } from "#/state/codeSlice";
import { appendInput } from "#/state/commandSlice";
import { appendJupyterInput } from "#/state/jupyterSlice";
import { setPlan } from "#/state/planSlice";
import { setInitialized } from "#/state/taskSlice";
import store from "#/store";
import ActionType from "#/types/ActionType";
import { ActionMessage } from "#/types/Message";
@ -14,9 +12,6 @@ import { handleObservationMessage } from "./observations";
import { getPlan } from "./planService";
const messageActions = {
[ActionType.INIT]: () => {
store.dispatch(setInitialized(true));
},
[ActionType.BROWSE]: (message: ActionMessage) => {
const { url, screenshotSrc } = message.args;
store.dispatch(setUrl(url));
@ -54,9 +49,6 @@ const messageActions = {
[ActionType.MODIFY_TASK]: () => {
getPlan().then((fetchedPlan) => store.dispatch(setPlan(fetchedPlan)));
},
[ActionType.CHANGE_TASK_STATE]: (message: ActionMessage) => {
store.dispatch(changeTaskState(message.args.task_state));
},
};
export function handleActionMessage(message: ActionMessage) {

View File

@ -1,14 +1,11 @@
import { describe, expect, it, vi } from "vitest";
import { setInitialized } from "#/state/taskSlice";
import store from "#/store";
import ActionType from "#/types/ActionType";
import { initializeAgent } from "./agent";
import { Settings } from "./settings";
import Socket from "./socket";
const sendSpy = vi.spyOn(Socket, "send");
const dispatchSpy = vi.spyOn(store, "dispatch");
describe("initializeAgent", () => {
it("Should initialize the agent with the current settings", () => {
@ -27,6 +24,5 @@ describe("initializeAgent", () => {
initializeAgent(settings);
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
expect(dispatchSpy).toHaveBeenCalledWith(setInitialized(false));
});
});

View File

@ -1,5 +1,3 @@
import { setInitialized } from "#/state/taskSlice";
import store from "#/store";
import ActionType from "#/types/ActionType";
import { Settings } from "./settings";
import Socket from "./socket";
@ -11,7 +9,5 @@ import Socket from "./socket";
export const initializeAgent = (settings: Settings) => {
const event = { action: ActionType.INIT, args: settings };
const eventString = JSON.stringify(event);
store.dispatch(setInitialized(false));
Socket.send(eventString);
};

View File

@ -1,11 +1,11 @@
import ActionType from "#/types/ActionType";
import AgentTaskAction from "#/types/AgentTaskAction";
import AgentState from "#/types/AgentState";
import Socket from "./socket";
export function changeTaskState(message: AgentTaskAction): void {
export function changeAgentState(message: AgentState): void {
const eventString = JSON.stringify({
action: ActionType.CHANGE_TASK_STATE,
args: { task_state_action: message },
action: ActionType.CHANGE_AGENT_STATE,
args: { agent_state: message },
});
Socket.send(eventString);
}

View File

@ -11,7 +11,7 @@ export function sendChatMessage(message: string, isTask: boolean = true): void {
if (isTask) {
event = { action: ActionType.START, args: { task: message } };
} else {
event = { action: ActionType.USER_MESSAGE, args: { message } };
event = { action: ActionType.MESSAGE, args: { content: message } };
}
const eventString = JSON.stringify(event);
Socket.send(eventString);

View File

@ -1,3 +1,4 @@
import { changeAgentState } from "#/state/agentSlice";
import { setUrl, setScreenshotSrc } from "#/state/browserSlice";
import store from "#/store";
import { ObservationMessage } from "#/types/Message";
@ -23,6 +24,9 @@ export function handleObservationMessage(message: ObservationMessage) {
store.dispatch(setUrl(message.extras.url));
}
break;
case ObservationType.AGENT_STATE_CHANGED:
store.dispatch(changeAgentState(message.extras.agent_state));
break;
default:
store.dispatch(addAssistantMessage(message.message));
break;

View File

@ -1,18 +1,18 @@
import { createSlice } from "@reduxjs/toolkit";
import AgentTaskState from "#/types/AgentTaskState";
import AgentState from "#/types/AgentState";
export const agentSlice = createSlice({
name: "agent",
initialState: {
curTaskState: AgentTaskState.INIT,
curAgentState: AgentState.LOADING,
},
reducers: {
changeTaskState: (state, action) => {
state.curTaskState = action.payload;
changeAgentState: (state, action) => {
state.curAgentState = action.payload;
},
},
});
export const { changeTaskState } = agentSlice.actions;
export const { changeAgentState } = agentSlice.actions;
export default agentSlice.reducer;

View File

@ -3,19 +3,15 @@ import { createSlice } from "@reduxjs/toolkit";
export const taskSlice = createSlice({
name: "task",
initialState: {
initialized: false,
completed: false,
},
reducers: {
setInitialized: (state, action) => {
state.initialized = action.payload;
},
setCompleted: (state, action) => {
state.completed = action.payload;
},
},
});
export const { setInitialized, setCompleted } = taskSlice.actions;
export const { setCompleted } = taskSlice.actions;
export default taskSlice.reducer;

View File

@ -2,12 +2,12 @@ enum ActionType {
// Initializes the agent. Only sent by client.
INIT = "initialize",
// Sends a message from the user
USER_MESSAGE = "user_message",
// Starts a new development task
// Starts a new development task.
START = "start",
// Represents a message from the user or agent.
MESSAGE = "message",
// Reads the contents of a file.
READ = "read",
@ -45,7 +45,8 @@ enum ActionType {
// Updates a task in the plan.
MODIFY_TASK = "modify_task",
CHANGE_TASK_STATE = "change_task_state",
// Changes the state of the agent, e.g. to paused or running
CHANGE_AGENT_STATE = "change_agent_state",
}
export default ActionType;

View File

@ -1,4 +1,5 @@
enum AgentTaskState {
enum AgentState {
LOADING = "loading",
INIT = "init",
RUNNING = "running",
AWAITING_USER_INPUT = "awaiting_user_input",
@ -8,4 +9,4 @@ enum AgentTaskState {
ERROR = "error",
}
export default AgentTaskState;
export default AgentState;

View File

@ -1,15 +0,0 @@
enum AgentTaskAction {
// Starts the task.
START = "start",
// Pauses the task.
PAUSE = "pause",
// Resumes the task.
RESUME = "resume",
// Stops the task.
STOP = "stop",
}
export default AgentTaskAction;

View File

@ -16,6 +16,9 @@ enum ObservationType {
// A message from the user
CHAT = "chat",
// Agent state has changed
AGENT_STATE_CHANGED = "agent_state_changed",
}
export default ObservationType;

View File

@ -1,5 +1,5 @@
import asyncio
from typing import Callable, List, Type
from typing import Optional, Type
from agenthub.codeact_agent.codeact_agent import CodeActAgent
from opendevin.controller.action_manager import ActionManager
@ -14,23 +14,27 @@ from opendevin.core.exceptions import (
MaxCharsExceedError,
)
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import TaskState
from opendevin.core.schema import AgentState
from opendevin.core.schema.config import ConfigType
from opendevin.events.action import (
Action,
AgentDelegateAction,
AgentFinishAction,
AgentTalkAction,
ChangeAgentStateAction,
MessageAction,
NullAction,
TaskStateChangedAction,
)
from opendevin.events.event import Event
from opendevin.events.observation import (
AgentDelegateObservation,
AgentErrorObservation,
AgentStateChangedObservation,
NullObservation,
Observation,
UserMessageObservation,
)
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
from opendevin.runtime import DockerSSHBox
from opendevin.runtime.browser.browser_env import BrowserEnv
@ -43,22 +47,21 @@ class AgentController:
agent: Agent
max_iterations: int
action_manager: ActionManager
callbacks: List[Callable]
browser: BrowserEnv
event_stream: EventStream
agent_task: Optional[asyncio.Task] = None
delegate: 'AgentController | None' = None
state: State | None = None
_task_state: TaskState = TaskState.INIT
_agent_state: AgentState = AgentState.LOADING
_cur_step: int = 0
def __init__(
self,
agent: Agent,
event_stream: EventStream,
sid: str = 'default',
max_iterations: int = MAX_ITERATIONS,
max_chars: int = MAX_CHARS,
callbacks: List[Callable] = [],
):
"""Initializes a new instance of the AgentController class.
@ -67,14 +70,16 @@ class AgentController:
sid: The session ID of the agent.
max_iterations: The maximum number of iterations the agent can run.
max_chars: The maximum number of characters the agent can output.
callbacks: A list of callback functions to run after each action.
"""
self.id = sid
self.agent = agent
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
)
self.max_iterations = max_iterations
self.action_manager = ActionManager(self.id)
self.max_chars = max_chars
self.callbacks = callbacks
# Initialize agent-required plugins for sandbox (if any)
self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
# Initialize browser environment
@ -87,7 +92,12 @@ class AgentController:
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
)
self._await_user_message_queue: asyncio.Queue = asyncio.Queue()
async def close(self):
if self.agent_task is not None:
self.agent_task.cancel()
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
self.action_manager.sandbox.close()
await self.set_agent_state_to(AgentState.STOPPED)
def update_state_for_step(self, i):
if self.state is None:
@ -100,9 +110,14 @@ class AgentController:
return
self.state.updated_info = []
def add_history(self, action: Action, observation: Observation):
async def add_error_to_history(self, message: str):
await self.add_history(NullAction(), AgentErrorObservation(message))
async def add_history(
self, action: Action, observation: Observation, add_to_stream=True
):
if self.state is None:
return
raise ValueError('Added history while state was None')
if not isinstance(action, Action):
raise TypeError(
f'action must be an instance of Action, got {type(action).__name__} instead'
@ -113,12 +128,15 @@ class AgentController:
)
self.state.history.append((action, observation))
self.state.updated_info.append((action, observation))
if add_to_stream:
await self.event_stream.add_event(action, EventSource.AGENT)
await self.event_stream.add_event(observation, EventSource.AGENT)
async def _run(self):
if self.state is None:
return
if self._task_state != TaskState.RUNNING:
if self._agent_state != AgentState.RUNNING:
raise ValueError('Task is not in running state')
for i in range(self._cur_step, self.max_iterations):
@ -126,117 +144,79 @@ class AgentController:
try:
finished = await self.step(i)
if finished:
self._task_state = TaskState.FINISHED
await self.set_agent_state_to(AgentState.FINISHED)
break
except Exception:
logger.error('Error in loop', exc_info=True)
await self._run_callbacks(
AgentErrorObservation(
'Oops! Something went wrong while completing your task. You can check the logs for more info.'
)
await self.set_agent_state_to(AgentState.ERROR)
await self.add_error_to_history(
'Oops! Something went wrong while completing your task. You can check the logs for more info.'
)
await self.set_task_state_to(TaskState.STOPPED)
break
if self._task_state == TaskState.FINISHED:
logger.info('Task finished by agent')
await self.reset_task()
break
elif self._task_state == TaskState.STOPPED:
logger.info('Task stopped by user')
await self.reset_task()
break
elif self._task_state == TaskState.PAUSED:
logger.info('Task paused')
self._cur_step = i + 1
await self.notify_task_state_changed()
break
if self._is_stuck():
logger.info('Loop detected, stopping task')
observation = AgentErrorObservation(
await self.set_agent_state_to(AgentState.ERROR)
await self.add_error_to_history(
'I got stuck into a loop, the task has stopped.'
)
await self._run_callbacks(observation)
await self.set_task_state_to(TaskState.STOPPED)
break
await asyncio.sleep(
0.001
) # Give back control for a tick, so other async stuff can run
async def setup_task(self, task: str, inputs: dict = {}):
"""Sets up the agent controller with a task."""
self._task_state = TaskState.RUNNING
await self.notify_task_state_changed()
await self.set_agent_state_to(AgentState.INIT)
self.state = State(Plan(task))
self.state.inputs = inputs
async def start(self, task: str):
"""Starts the agent controller with a task.
If task already run before, it will continue from the last step.
"""
await self.setup_task(task)
await self._run()
async def resume(self):
if self.state is None:
raise ValueError('No task to resume')
self._task_state = TaskState.RUNNING
await self.notify_task_state_changed()
await self._run()
async def on_event(self, event: Event):
if isinstance(event, ChangeAgentStateAction):
await self.set_agent_state_to(event.agent_state) # type: ignore
elif isinstance(event, MessageAction) and event.source == EventSource.USER:
# FIXME: we're hacking a message action into a user message observation, for the benefit of CodeAct
await self.add_history(
self._pending_talk_action,
UserMessageObservation(event.content),
add_to_stream=False,
)
await self.set_agent_state_to(AgentState.RUNNING)
async def reset_task(self):
if self.agent_task is not None:
self.agent_task.cancel()
self.state = None
self._cur_step = 0
self._task_state = TaskState.INIT
self.agent.reset()
await self.notify_task_state_changed()
async def set_task_state_to(self, state: TaskState):
self._task_state = state
if state == TaskState.STOPPED:
await self.reset_task()
logger.info(f'Task state set to {state}')
def get_task_state(self):
"""Returns the current state of the agent task."""
return self._task_state
async def notify_task_state_changed(self):
await self._run_callbacks(TaskStateChangedAction(self._task_state))
async def add_user_message(self, message: UserMessageObservation):
if self.state is None:
async def set_agent_state_to(self, new_state: AgentState):
logger.info(f'Setting agent state from {self._agent_state} to {new_state}')
if new_state == self._agent_state:
return
if self._task_state == TaskState.AWAITING_USER_INPUT:
self._await_user_message_queue.put_nowait(message)
self._agent_state = new_state
if new_state == AgentState.RUNNING:
self.agent_task = asyncio.create_task(self._run())
elif (
new_state == AgentState.PAUSED
or new_state == AgentState.AWAITING_USER_INPUT
):
self._cur_step += 1
if self.agent_task is not None:
self.agent_task.cancel()
elif new_state == AgentState.STOPPED:
await self.reset_task()
elif new_state == AgentState.FINISHED:
await self.reset_task()
# set the task state to running
self._task_state = TaskState.RUNNING
await self.notify_task_state_changed()
await self.event_stream.add_event(
AgentStateChangedObservation('', self._agent_state), EventSource.AGENT
)
elif self._task_state == TaskState.RUNNING:
self.add_history(NullAction(), message)
else:
raise ValueError(
f'Task (state: {self._task_state}) is not in a state to add user message'
)
async def wait_for_user_input(self) -> UserMessageObservation:
self._task_state = TaskState.AWAITING_USER_INPUT
await self.notify_task_state_changed()
# wait for the next user message
if len(self.callbacks) == 0:
logger.info(
'Use STDIN to request user message as no callbacks are registered',
extra={'msg_type': 'INFO'},
)
message = input('Request user input [type /exit to stop interaction] >> ')
user_message_observation = UserMessageObservation(message)
else:
user_message_observation = await self._await_user_message_queue.get()
self._await_user_message_queue.task_done()
return user_message_observation
def get_agent_state(self):
"""Returns the current state of the agent task."""
return self._agent_state
async def start_delegate(self, action: AgentDelegateAction):
AgentCls: Type[Agent] = Agent.get_cls(action.agent)
@ -244,9 +224,9 @@ class AgentController:
self.delegate = AgentController(
sid=self.id + '-delegate',
agent=agent,
event_stream=self.event_stream,
max_iterations=self.max_iterations,
max_chars=self.max_chars,
callbacks=self.callbacks,
)
task = action.inputs.get('task') or ''
await self.delegate.setup_task(task, action.inputs)
@ -259,7 +239,7 @@ class AgentController:
if delegate_done:
outputs = self.delegate.state.outputs if self.delegate.state else {}
obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
self.add_history(NullAction(), obs)
await self.add_history(NullAction(), obs)
self.delegate = None
self.delegateAction = None
return False
@ -272,8 +252,7 @@ class AgentController:
log_obs = self.action_manager.get_background_obs()
for obs in log_obs:
self.add_history(NullAction(), obs)
await self._run_callbacks(obs)
await self.add_history(NullAction(), obs)
logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
self.update_state_for_step(i)
@ -289,14 +268,10 @@ class AgentController:
self.update_state_after_step()
await self._run_callbacks(action)
# whether to await for user messages
if isinstance(action, AgentTalkAction):
# await for the next user messages
user_message_observation = await self.wait_for_user_input()
logger.info(user_message_observation, extra={'msg_type': 'OBSERVATION'})
self.add_history(action, user_message_observation)
self._pending_talk_action = action
await self.event_stream.add_event(action, EventSource.AGENT)
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
return False
finished = isinstance(action, AgentFinishAction)
@ -311,23 +286,9 @@ class AgentController:
if not isinstance(observation, NullObservation):
logger.info(observation, extra={'msg_type': 'OBSERVATION'})
self.add_history(action, observation)
await self._run_callbacks(observation)
await self.add_history(action, observation)
return False
async def _run_callbacks(self, event):
if event is None:
return
for callback in self.callbacks:
idx = self.callbacks.index(callback)
try:
await callback(event)
except Exception as e:
logger.exception(f'Callback error: {e}, idx: {idx}')
await asyncio.sleep(
0.001
) # Give back control for a tick, so we can await in callbacks
def get_state(self):
return self.state

View File

@ -6,6 +6,11 @@ import agenthub # noqa F401 (we import this to get the agents registered)
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.core.config import args
from opendevin.core.schema import AgentState
from opendevin.events.action import ChangeAgentStateAction, MessageAction
from opendevin.events.event import Event
from opendevin.events.observation import AgentStateChangedObservation
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
from opendevin.llm.llm import LLM
@ -41,11 +46,33 @@ async def main(task_str: str = ''):
llm = LLM(args.model_name)
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
agent = AgentCls(llm=llm)
event_stream = EventStream()
controller = AgentController(
agent=agent, max_iterations=args.max_iterations, max_chars=args.max_chars
agent=agent,
max_iterations=args.max_iterations,
max_chars=args.max_chars,
event_stream=event_stream,
)
await controller.start(task)
await controller.setup_task(task)
await event_stream.add_event(
ChangeAgentStateAction(agent_state=AgentState.RUNNING), EventSource.USER
)
async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.AWAITING_USER_INPUT:
message = input('Request user input >> ')
action = MessageAction(content=message)
await event_stream.add_event(action, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
while controller.get_agent_state() not in [
AgentState.FINISHED,
AgentState.ERROR,
AgentState.STOPPED,
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
if __name__ == '__main__':

View File

@ -1,12 +1,11 @@
from .action import ActionType
from .agent import AgentState
from .config import ConfigType
from .observation import ObservationType
from .task import TaskState, TaskStateAction
__all__ = [
'ActionType',
'ObservationType',
'ConfigType',
'TaskState',
'TaskStateAction',
'AgentState',
]

View File

@ -8,8 +8,8 @@ class ActionTypeSchema(BaseModel):
"""Initializes the agent. Only sent by client.
"""
USER_MESSAGE: str = Field(default='user_message')
"""Sends a message from the user. Only sent by the client.
MESSAGE: str = Field(default='message')
"""Represents a message.
"""
START: str = Field(default='start')
@ -81,7 +81,7 @@ class ActionTypeSchema(BaseModel):
"""Stops the task. Must send a start action to restart a new task.
"""
CHANGE_TASK_STATE: str = Field(default='change_task_state')
CHANGE_AGENT_STATE: str = Field(default='change_agent_state')
PUSH: str = Field(default='push')
"""Push a branch to github."""

View File

@ -0,0 +1,35 @@
from enum import Enum
class AgentState(str, Enum):
LOADING = 'loading'
"""The agent is loading.
"""
INIT = 'init'
"""The agent is initialized.
"""
RUNNING = 'running'
"""The agent is running.
"""
AWAITING_USER_INPUT = 'awaiting_user_input'
"""The agent is awaiting user input.
"""
PAUSED = 'paused'
"""The agent is paused.
"""
STOPPED = 'stopped'
"""The agent is stopped.
"""
FINISHED = 'finished'
"""The agent is finished with the current task.
"""
ERROR = 'error'
"""An error occurred during the task.
"""

View File

@ -40,5 +40,7 @@ class ObservationTypeSchema(BaseModel):
NULL: str = Field(default='null')
AGENT_STATE_CHANGED: str = Field(default='agent_state_changed')
ObservationType = ObservationTypeSchema()

View File

@ -1,49 +0,0 @@
from enum import Enum
class TaskState(str, Enum):
INIT = 'init'
"""Initial state of the task.
"""
RUNNING = 'running'
"""The task is running.
"""
AWAITING_USER_INPUT = 'awaiting_user_input'
"""The task is awaiting user input.
"""
PAUSED = 'paused'
"""The task is paused.
"""
STOPPED = 'stopped'
"""The task is stopped.
"""
FINISHED = 'finished'
"""The task is finished.
"""
ERROR = 'error'
"""An error occurred during the task.
"""
class TaskStateAction(str, Enum):
START = 'start'
"""Starts the task.
"""
PAUSE = 'pause'
"""Pauses the task.
"""
RESUME = 'resume'
"""Resumes the task.
"""
STOP = 'stop'
"""Stops the task.
"""

View File

@ -9,13 +9,15 @@ from .agent import (
AgentSummarizeAction,
AgentTalkAction,
AgentThinkAction,
ChangeAgentStateAction,
)
from .browse import BrowseURLAction
from .commands import CmdKillAction, CmdRunAction, IPythonRunCellAction
from .empty import NullAction
from .files import FileReadAction, FileWriteAction
from .github import GitHubPushAction
from .tasks import AddTaskAction, ModifyTaskAction, TaskStateChangedAction
from .message import MessageAction
from .tasks import AddTaskAction, ModifyTaskAction
actions = (
CmdKillAction,
@ -31,8 +33,9 @@ actions = (
AgentDelegateAction,
AddTaskAction,
ModifyTaskAction,
TaskStateChangedAction,
ChangeAgentStateAction,
GitHubPushAction,
MessageAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
@ -74,6 +77,7 @@ __all__ = [
'AgentSummarizeAction',
'AddTaskAction',
'ModifyTaskAction',
'TaskStateChangedAction',
'ChangeAgentStateAction',
'IPythonRunCellAction',
'MessageAction',
]

View File

@ -15,6 +15,19 @@ if TYPE_CHECKING:
from opendevin.controller import AgentController
@dataclass
class ChangeAgentStateAction(Action):
"""Fake action, just to notify the client that a task state has changed."""
agent_state: str
thought: str = ''
action: str = ActionType.CHANGE_AGENT_STATE
@property
def message(self) -> str:
return f'Agent state changed to {self.agent_state}'
@dataclass
class AgentRecallAction(Action):
query: str

View File

@ -0,0 +1,15 @@
from dataclasses import dataclass
from opendevin.core.schema import ActionType
from .action import Action
@dataclass
class MessageAction(Action):
content: str
action: str = ActionType.MESSAGE
@property
def message(self) -> str:
return self.content

View File

@ -43,16 +43,3 @@ class ModifyTaskAction(Action):
@property
def message(self) -> str:
return f'Set task {self.id} to {self.state}'
@dataclass
class TaskStateChangedAction(Action):
"""Fake action, just to notify the client that a task state has changed."""
task_state: str
thought: str = ''
action: str = ActionType.CHANGE_TASK_STATE
@property
def message(self) -> str:
return f'Task state changed to {self.task_state}'

View File

@ -13,4 +13,8 @@ class Event:
@property
def message(self) -> str:
return self.message
return self._message # type: ignore [attr-defined]
@property
def source(self) -> str:
return self._source # type: ignore [attr-defined]

View File

@ -1,3 +1,4 @@
from .agent import AgentStateChangedObservation
from .browse import BrowserOutputObservation
from .commands import CmdOutputObservation, IPythonRunCellObservation
from .delegate import AgentDelegateObservation
@ -18,9 +19,13 @@ observations = (
AgentRecallObservation,
AgentDelegateObservation,
AgentErrorObservation,
AgentStateChangedObservation,
)
OBSERVATION_TYPE_TO_CLASS = {observation_class.observation: observation_class for observation_class in observations} # type: ignore[attr-defined]
OBSERVATION_TYPE_TO_CLASS = {
observation_class.observation: observation_class # type: ignore[attr-defined]
for observation_class in observations
}
def observation_from_dict(observation: dict) -> Observation:
@ -29,7 +34,9 @@ def observation_from_dict(observation: dict) -> Observation:
raise KeyError(f"'observation' key is not found in {observation=}")
observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation'])
if observation_class is None:
raise KeyError(f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}")
raise KeyError(
f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}"
)
observation.pop('observation')
observation.pop('message', None)
content = observation.pop('content', '')
@ -49,4 +56,5 @@ __all__ = [
'AgentMessageObservation',
'AgentRecallObservation',
'AgentErrorObservation',
'AgentStateChangedObservation',
]

View File

@ -0,0 +1,19 @@
from dataclasses import dataclass
from opendevin.core.schema import ObservationType
from .observation import Observation
@dataclass
class AgentStateChangedObservation(Observation):
"""
This data class represents the result from delegating to another agent
"""
agent_state: str
observation: str = ObservationType.AGENT_STATE_CHANGED
@property
def message(self) -> str:
return ''

View File

@ -0,0 +1,48 @@
import asyncio
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List
from opendevin.core.logger import opendevin_logger as logger
from .event import Event
class EventStreamSubscriber(str, Enum):
AGENT_CONTROLLER = 'agent_controller'
SERVER = 'server'
RUNTIME = 'runtime'
MAIN = 'main'
class EventSource(str, Enum):
AGENT = 'agent'
USER = 'user'
class EventStream:
_subscribers: Dict[str, Callable] = {}
_events: List[Event] = []
_lock = asyncio.Lock()
def subscribe(self, id: EventStreamSubscriber, callback: Callable):
if id in self._subscribers:
logger.warning('Subscriber subscribed multiple times: ' + id)
else:
self._subscribers[id] = callback
def unsubscribe(self, id: EventStreamSubscriber):
if id not in self._subscribers:
logger.warning('Subscriber not found during unsubscribe: ' + id)
else:
del self._subscribers[id]
# TODO: make this not async
async def add_event(self, event: Event, source: EventSource):
async with self._lock:
event._id = len(self._events) # type: ignore [attr-defined]
event._timestamp = datetime.now() # type: ignore [attr-defined]
event._source = source # type: ignore [attr-defined]
self._events.append(event)
for key, fn in self._subscribers.items():
await fn(event)

View File

@ -1,74 +1,43 @@
import asyncio
from typing import Dict, List, Optional
from typing import Optional
from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.core import config
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import ActionType, ConfigType, TaskState, TaskStateAction
from opendevin.core.schema import ActionType, AgentState, ConfigType
from opendevin.events.action import (
Action,
ChangeAgentStateAction,
NullAction,
action_from_dict,
)
from opendevin.events.event import Event
from opendevin.events.observation import (
NullObservation,
Observation,
UserMessageObservation,
)
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
from opendevin.llm.llm import LLM
from opendevin.server.session import session_manager
# new task state to valid old task states
VALID_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
TaskStateAction.PAUSE: [TaskState.RUNNING],
TaskStateAction.RESUME: [TaskState.PAUSED],
TaskStateAction.STOP: [
TaskState.RUNNING,
TaskState.PAUSED,
TaskState.AWAITING_USER_INPUT,
],
}
IGNORED_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
TaskStateAction.PAUSE: [
TaskState.INIT,
TaskState.PAUSED,
TaskState.STOPPED,
TaskState.FINISHED,
TaskState.AWAITING_USER_INPUT,
],
TaskStateAction.RESUME: [
TaskState.INIT,
TaskState.RUNNING,
TaskState.STOPPED,
TaskState.FINISHED,
TaskState.AWAITING_USER_INPUT,
],
TaskStateAction.STOP: [TaskState.INIT, TaskState.STOPPED, TaskState.FINISHED],
}
TASK_STATE_ACTION_MAP: Dict[TaskStateAction, TaskState] = {
TaskStateAction.START: TaskState.RUNNING,
TaskStateAction.PAUSE: TaskState.PAUSED,
TaskStateAction.RESUME: TaskState.RUNNING,
TaskStateAction.STOP: TaskState.STOPPED,
}
class AgentUnit:
"""Represents a session with an agent.
Attributes:
controller: The AgentController instance for controlling the agent.
agent_task: The task representing the agent's execution.
"""
sid: str
event_stream: EventStream
controller: Optional[AgentController] = None
agent_task: Optional[asyncio.Task] = None
# TODO: we will add the runtime here
# runtime: Optional[Runtime] = None
def __init__(self, sid):
"""Initializes a new instance of the Session class."""
self.sid = sid
self.event_stream = EventStream()
self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event)
async def send_error(self, message):
"""Sends an error message to the client.
@ -100,28 +69,27 @@ class AgentUnit:
await self.send_error('Invalid action')
return
match action:
case ActionType.INIT:
await self.create_controller(data)
case ActionType.START:
await self.start_task(data)
case ActionType.USER_MESSAGE:
await self.send_user_message(data)
case ActionType.CHANGE_TASK_STATE:
task_state_action = data.get('args', {}).get('task_state_action', None)
if task_state_action is None:
await self.send_error('No task state action specified.')
return
await self.set_task_state(TaskStateAction(task_state_action))
case ActionType.CHAT:
if self.controller is None:
await self.send_error('No agent started. Please wait a second...')
return
self.controller.add_history(
NullAction(), UserMessageObservation(data['message'])
)
case _:
await self.send_error("I didn't recognize this action:" + action)
if action == ActionType.INIT:
await self.create_controller(data)
await self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
)
return
elif action == ActionType.START:
if self.controller is None:
await self.send_error('No agent started.')
return
task = data['args']['task']
await self.controller.setup_task(task)
await self.event_stream.add_event(
ChangeAgentStateAction(agent_state=AgentState.RUNNING), EventSource.USER
)
return
action_dict = data.copy()
action_dict['action'] = action
action_obj = action_from_dict(action_dict)
await self.event_stream.add_event(action_obj, EventSource.USER)
def get_arg_or_default(self, _args: dict, key: ConfigType) -> str:
"""Gets an argument from the args dictionary or the default value.
@ -156,13 +124,15 @@ class AgentUnit:
logger.info(f'Creating agent {agent_cls} using LLM {model}')
llm = LLM(model=model, api_key=api_key, base_url=api_base)
if self.controller is not None:
await self.controller.close()
try:
self.controller = AgentController(
sid=self.sid,
event_stream=self.event_stream,
agent=Agent.get_cls(agent_cls)(llm),
max_iterations=int(max_iterations),
max_chars=int(max_chars),
callbacks=[self.on_agent_event],
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
@ -170,77 +140,8 @@ class AgentUnit:
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
)
return
await self.init_done()
async def init_done(self):
if self.controller is None:
await self.send_error('No agent started.')
return
await self.send(
{
'action': ActionType.INIT,
'message': 'Control loop started.',
}
)
await self.controller.notify_task_state_changed()
async def start_task(self, start_event):
"""Starts a task for the agent.
Args:
start_event: The start event data.
"""
task = start_event['args']['task']
if self.controller is None:
await self.send_error('No agent started. Please wait a second...')
return
try:
if self.agent_task:
self.agent_task.cancel()
self.agent_task = asyncio.create_task(
self.controller.start(task), name='agent start task loop'
)
except Exception as e:
await self.send_error(f'Error during task loop: {e}')
async def send_user_message(self, data: dict):
if not self.agent_task or not self.controller:
await self.send_error('No agent started.')
return
await self.controller.add_user_message(
UserMessageObservation(data['args']['message'])
)
async def set_task_state(self, new_state_action: TaskStateAction):
"""Sets the state of the agent task."""
if self.controller is None:
await self.send_error('No agent started.')
return
cur_state = self.controller.get_task_state()
new_state = TASK_STATE_ACTION_MAP.get(new_state_action)
if new_state is None:
await self.send_error('Invalid task state action.')
return
if cur_state in VALID_TASK_STATE_MAP.get(new_state_action, []):
await self.controller.set_task_state_to(new_state)
elif cur_state in IGNORED_TASK_STATE_MAP.get(new_state_action, []):
# notify once again.
await self.controller.notify_task_state_changed()
return
else:
await self.send_error('Current task state not recognized.')
return
if new_state_action == TaskStateAction.RESUME:
if self.agent_task:
self.agent_task.cancel()
self.agent_task = asyncio.create_task(
self.controller.resume(), name='agent resume task loop'
)
async def on_agent_event(self, event: Observation | Action):
async def on_event(self, event: Event):
"""Callback function for agent events.
Args:
@ -250,10 +151,10 @@ class AgentUnit:
return
if isinstance(event, NullObservation):
return
await self.send(event.to_dict())
if event.source == 'agent':
await self.send(event.to_dict())
return
def close(self):
if self.agent_task:
self.agent_task.cancel()
if self.controller is not None:
self.controller.action_manager.sandbox.close()
self.controller.close()

View File

@ -64,7 +64,7 @@ class MessageStack:
# Ignore assistant init message for now.
if 'action' in msg.payload and msg.payload['action'] in [
ActionType.INIT,
ActionType.CHANGE_TASK_STATE,
ActionType.CHANGE_AGENT_STATE,
]:
continue
cnt += 1

View File

@ -10,6 +10,7 @@ from opendevin.events.action.github import GitHubPushAction, GitHubSendPRAction
from opendevin.events.observation.commands import CmdOutputObservation
from opendevin.events.observation.error import AgentErrorObservation
from opendevin.events.observation.message import AgentMessageObservation
from opendevin.events.stream import EventStream
from opendevin.llm.llm import LLM
@ -19,7 +20,8 @@ def agent_controller():
config.config[ConfigType.SANDBOX_TYPE] = 'local'
llm = LLM()
agent = DummyAgent(llm=llm)
controller = AgentController(agent)
event_stream = EventStream()
controller = AgentController(agent, event_stream)
yield controller