From 32a3a0259acf6b8b241f83462c8a46a521686041 Mon Sep 17 00:00:00 2001 From: Robert Brennan Date: Fri, 29 Mar 2024 10:49:40 -0400 Subject: [PATCH] Serialization of Actions and Observations (#314) * checkout geohotstan work * merge session.py changes * add observation ids * ignore null actions and obs * add back action messages * fix lint --- opendevin/action/__init__.py | 21 +++++++++ opendevin/action/agent.py | 10 ++-- opendevin/action/base.py | 14 ++++-- opendevin/action/bash.py | 4 +- opendevin/action/browse.py | 3 +- opendevin/action/fileop.py | 4 +- opendevin/observation.py | 14 +++++- opendevin/server/session.py | 88 ++++++------------------------------ 8 files changed, 72 insertions(+), 86 deletions(-) diff --git a/opendevin/action/__init__.py b/opendevin/action/__init__.py index 9fe6f12354..b0b32f693b 100644 --- a/opendevin/action/__init__.py +++ b/opendevin/action/__init__.py @@ -4,6 +4,27 @@ from .browse import BrowseURLAction from .fileop import FileReadAction, FileWriteAction from .agent import AgentRecallAction, AgentThinkAction, AgentFinishAction, AgentEchoAction, AgentSummarizeAction +actions = ( + CmdKillAction, + CmdRunAction, + BrowseURLAction, + FileReadAction, + FileWriteAction, + AgentRecallAction, + AgentThinkAction, + AgentFinishAction +) + +ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined] + +def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action: + action_class = ACTION_TYPE_TO_CLASS.get(action) + if action_class is None: + raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}") + return action_class(*args, **kwargs) + +CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()} + __all__ = [ "Action", "NullAction", diff --git a/opendevin/action/agent.py b/opendevin/action/agent.py index 13a4d2ec7d..b0ea63d09c 100644 --- a/opendevin/action/agent.py +++ b/opendevin/action/agent.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: @dataclass class AgentRecallAction(ExecutableAction): query: str + action: str = "recall" def run(self, controller: "AgentController") -> AgentRecallObservation: return AgentRecallObservation( @@ -21,12 +22,11 @@ class AgentRecallAction(ExecutableAction): def message(self) -> str: return f"Let me dive into my memories to find what you're looking for! Searching for: '{self.query}'. This might take a moment." - - @dataclass class AgentThinkAction(NotExecutableAction): thought: str runnable: bool = False + action: str = "think" def run(self, controller: "AgentController") -> "Observation": raise NotImplementedError @@ -35,11 +35,11 @@ class AgentThinkAction(NotExecutableAction): def message(self) -> str: return self.thought - @dataclass class AgentEchoAction(ExecutableAction): content: str runnable: bool = True + action: str = "echo" def run(self, controller: "AgentController") -> "Observation": return AgentMessageObservation(self.content) @@ -52,6 +52,8 @@ class AgentEchoAction(ExecutableAction): class AgentSummarizeAction(NotExecutableAction): summary: str + action: str = "summarize" + @property def message(self) -> str: return self.summary @@ -59,6 +61,7 @@ class AgentSummarizeAction(NotExecutableAction): @dataclass class AgentFinishAction(NotExecutableAction): runnable: bool = False + action: str = "finish" def run(self, controller: "AgentController") -> "Observation": raise NotImplementedError @@ -66,4 +69,3 @@ class AgentFinishAction(NotExecutableAction): @property def message(self) -> str: return "All done! What's next on the agenda?" - diff --git a/opendevin/action/base.py b/opendevin/action/base.py index 4932a43f53..0d5d0765e1 100644 --- a/opendevin/action/base.py +++ b/opendevin/action/base.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass - +from dataclasses import dataclass, asdict from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -12,7 +11,12 @@ class Action: raise NotImplementedError def to_dict(self): - return {"action": self.__class__.__name__, "args": self.__dict__, "message": self.message} + d = asdict(self) + try: + v = d.pop('action') + except KeyError: + raise NotImplementedError(f'{self=} does not have action attribute set') + return {'action': v, "args": d, "message": self.message} @property def executable(self) -> bool: @@ -24,21 +28,25 @@ class Action: +@dataclass class ExecutableAction(Action): @property def executable(self) -> bool: return True +@dataclass class NotExecutableAction(Action): @property def executable(self) -> bool: return False +@dataclass class NullAction(NotExecutableAction): """An action that does nothing. This is used when the agent need to receive user follow-up messages from the frontend. """ + action: str = "null" @property def message(self) -> str: diff --git a/opendevin/action/bash.py b/opendevin/action/bash.py index 4342e97078..c2a1e54fea 100644 --- a/opendevin/action/bash.py +++ b/opendevin/action/bash.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: class CmdRunAction(ExecutableAction): command: str background: bool = False + action: str = "run" def run(self, controller: "AgentController") -> "CmdOutputObservation": return controller.command_manager.run_command(self.command, self.background) @@ -23,10 +24,11 @@ class CmdRunAction(ExecutableAction): @dataclass class CmdKillAction(ExecutableAction): id: int + action: str = "kill" def run(self, controller: "AgentController") -> "CmdOutputObservation": return controller.command_manager.kill_command(self.id) @property def message(self) -> str: - return f"Killing command: {self.id}" + return f"Killing command: {self.id}" \ No newline at end of file diff --git a/opendevin/action/browse.py b/opendevin/action/browse.py index 57fd6f8246..7f60ab73d6 100644 --- a/opendevin/action/browse.py +++ b/opendevin/action/browse.py @@ -8,6 +8,7 @@ from .base import ExecutableAction @dataclass class BrowseURLAction(ExecutableAction): url: str + action: str = "browse" def run(self, *args, **kwargs) -> BrowserOutputObservation: try: @@ -26,4 +27,4 @@ class BrowseURLAction(ExecutableAction): @property def message(self) -> str: - return f"Browsing URL: {self.url}" + return f"Browsing URL: {self.url}" \ No newline at end of file diff --git a/opendevin/action/fileop.py b/opendevin/action/fileop.py index b3ee9a91d3..6c8280f3aa 100644 --- a/opendevin/action/fileop.py +++ b/opendevin/action/fileop.py @@ -17,6 +17,7 @@ def resolve_path(base_path, file_path): @dataclass class FileReadAction(ExecutableAction): path: str + action: str = "read" def run(self, controller) -> FileReadObservation: path = resolve_path(controller.workdir, self.path) @@ -29,11 +30,11 @@ class FileReadAction(ExecutableAction): def message(self) -> str: return f"Reading file: {self.path}" - @dataclass class FileWriteAction(ExecutableAction): path: str contents: str + action: str = "write" def run(self, controller) -> FileWriteObservation: path = resolve_path(controller.workdir, self.path) @@ -44,4 +45,3 @@ class FileWriteAction(ExecutableAction): @property def message(self) -> str: return f"Writing file: {self.path}" - diff --git a/opendevin/observation.py b/opendevin/observation.py index bdc149648e..7cee626048 100644 --- a/opendevin/observation.py +++ b/opendevin/observation.py @@ -18,8 +18,11 @@ class Observation: """Converts the observation to a dictionary.""" extras = copy.deepcopy(self.__dict__) extras.pop("content", None) + observation = "observation" + if hasattr(self, "observation"): + observation = self.observation return { - "observation": self.__class__.__name__, + "observation": observation, "content": self.content, "extras": extras, "message": self.message, @@ -40,6 +43,7 @@ class CmdOutputObservation(Observation): command_id: int command: str exit_code: int = 0 + observation : str = "run" @property def error(self) -> bool: @@ -56,6 +60,7 @@ class FileReadObservation(Observation): """ path: str + observation : str = "read" @property def message(self) -> str: @@ -68,6 +73,7 @@ class FileWriteObservation(Observation): """ path: str + observation : str = "write" @property def message(self) -> str: @@ -82,6 +88,7 @@ class BrowserOutputObservation(Observation): url: str status_code: int = 200 error: bool = False + observation : str = "browse" @property def message(self) -> str: @@ -95,6 +102,7 @@ class UserMessageObservation(Observation): """ role: str = "user" + observation : str = "message" @property def message(self) -> str: @@ -108,6 +116,7 @@ class AgentMessageObservation(Observation): """ role: str = "assistant" + observation : str = "message" @property def message(self) -> str: @@ -122,6 +131,7 @@ class AgentRecallObservation(Observation): memories: List[str] role: str = "assistant" + observation : str = "recall" @property def message(self) -> str: @@ -133,6 +143,7 @@ class AgentErrorObservation(Observation): """ This data class represents an error encountered by the agent. """ + observation : str = "error" @property def message(self) -> str: @@ -144,6 +155,7 @@ class NullObservation(Observation): This data class represents a null observation. This is used when the produced action is NOT executable. """ + observation : str = "null" @property def message(self) -> str: diff --git a/opendevin/server/session.py b/opendevin/server/session.py index 77651753a1..145d4cce05 100644 --- a/opendevin/server/session.py +++ b/opendevin/server/session.py @@ -1,58 +1,22 @@ import asyncio import os -from typing import Dict, Optional, Type +from typing import Optional from fastapi import WebSocketDisconnect from opendevin.action import ( Action, - AgentFinishAction, - AgentRecallAction, - AgentThinkAction, - BrowseURLAction, - CmdKillAction, - CmdRunAction, - FileReadAction, - FileWriteAction, NullAction, ) +from opendevin.observation import NullObservation from opendevin.agent import Agent from opendevin.controller import AgentController from opendevin.llm.llm import LLM from opendevin.observation import Observation, UserMessageObservation -# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase -ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = { - "run": CmdRunAction, - "kill": CmdKillAction, - "browse": BrowseURLAction, - "read": FileReadAction, - "write": FileWriteAction, - "recall": AgentRecallAction, - "think": AgentThinkAction, - "finish": AgentFinishAction, -} - - DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace")) LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview") -def parse_event(data): - if "action" not in data: - return None - action = data["action"] - args = {} - if "args" in data: - args = data["args"] - message = None - if "message" in data: - message = data["message"] - return { - "action": action, - "args": args, - "message": message, - } - class Session: def __init__(self, websocket): self.websocket = websocket @@ -84,20 +48,20 @@ class Session: await self.send_error("Invalid JSON") continue - event = parse_event(data) - if event is None: + action = data.get("action", None) + if action is None: await self.send_error("Invalid event") continue - if event["action"] == "initialize": - await self.create_controller(event) - elif event["action"] == "start": - await self.start_task(event) + if action == "initialize": + await self.create_controller(data) + elif action == "start": + await self.start_task(data) else: if self.controller is None: await self.send_error("No agent started. Please wait a second...") - elif event["action"] == "chat": - self.controller.add_history(NullAction(), UserMessageObservation(event["message"])) + elif action == "chat": + self.controller.add_history(NullAction(), UserMessageObservation(data["message"])) else: # TODO: we only need to implement user message for now # since even Devin does not support having the user taking other @@ -147,33 +111,9 @@ class Session: self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop") def on_agent_event(self, event: Observation | Action): - # FIXME: we need better serialization + if isinstance(event, NullAction): + return + if isinstance(event, NullObservation): + return event_dict = event.to_dict() - if "action" in event_dict: - if event_dict["action"] == "CmdRunAction": - event_dict["action"] = "run" - elif event_dict["action"] == "CmdKillAction": - event_dict["action"] = "kill" - elif event_dict["action"] == "BrowseURLAction": - event_dict["action"] = "browse" - elif event_dict["action"] == "FileReadAction": - event_dict["action"] = "read" - elif event_dict["action"] == "FileWriteAction": - event_dict["action"] = "write" - elif event_dict["action"] == "AgentFinishAction": - event_dict["action"] = "finish" - elif event_dict["action"] == "AgentRecallAction": - event_dict["action"] = "recall" - elif event_dict["action"] == "AgentThinkAction": - event_dict["action"] = "think" - if "observation" in event_dict: - if event_dict["observation"] == "UserMessageObservation": - event_dict["observation"] = "chat" - elif event_dict["observation"] == "AgentMessageObservation": - event_dict["observation"] = "chat" - elif event_dict["observation"] == "CmdOutputObservation": - event_dict["observation"] = "run" - elif event_dict["observation"] == "FileReadObservation": - event_dict["observation"] = "read" - asyncio.create_task(self.send(event_dict), name="send event in callback")