Serialization of Actions and Observations (#314)

* checkout geohotstan work

* merge session.py changes

* add observation ids

* ignore null actions and obs

* add back action messages

* fix lint
This commit is contained in:
Robert Brennan
2024-03-29 10:49:40 -04:00
committed by GitHub
parent bbc51c858d
commit 32a3a0259a
8 changed files with 72 additions and 86 deletions

View File

@@ -4,6 +4,27 @@ from .browse import BrowseURLAction
from .fileop import FileReadAction, FileWriteAction
from .agent import AgentRecallAction, AgentThinkAction, AgentFinishAction, AgentEchoAction, AgentSummarizeAction
actions = (
CmdKillAction,
CmdRunAction,
BrowseURLAction,
FileReadAction,
FileWriteAction,
AgentRecallAction,
AgentThinkAction,
AgentFinishAction
)
ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined]
def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action:
action_class = ACTION_TYPE_TO_CLASS.get(action)
if action_class is None:
raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}")
return action_class(*args, **kwargs)
CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}
__all__ = [
"Action",
"NullAction",

View File

@@ -10,6 +10,7 @@ if TYPE_CHECKING:
@dataclass
class AgentRecallAction(ExecutableAction):
query: str
action: str = "recall"
def run(self, controller: "AgentController") -> AgentRecallObservation:
return AgentRecallObservation(
@@ -21,12 +22,11 @@ class AgentRecallAction(ExecutableAction):
def message(self) -> str:
return f"Let me dive into my memories to find what you're looking for! Searching for: '{self.query}'. This might take a moment."
@dataclass
class AgentThinkAction(NotExecutableAction):
thought: str
runnable: bool = False
action: str = "think"
def run(self, controller: "AgentController") -> "Observation":
raise NotImplementedError
@@ -35,11 +35,11 @@ class AgentThinkAction(NotExecutableAction):
def message(self) -> str:
return self.thought
@dataclass
class AgentEchoAction(ExecutableAction):
content: str
runnable: bool = True
action: str = "echo"
def run(self, controller: "AgentController") -> "Observation":
return AgentMessageObservation(self.content)
@@ -52,6 +52,8 @@ class AgentEchoAction(ExecutableAction):
class AgentSummarizeAction(NotExecutableAction):
summary: str
action: str = "summarize"
@property
def message(self) -> str:
return self.summary
@@ -59,6 +61,7 @@ class AgentSummarizeAction(NotExecutableAction):
@dataclass
class AgentFinishAction(NotExecutableAction):
runnable: bool = False
action: str = "finish"
def run(self, controller: "AgentController") -> "Observation":
raise NotImplementedError
@@ -66,4 +69,3 @@ class AgentFinishAction(NotExecutableAction):
@property
def message(self) -> str:
return "All done! What's next on the agenda?"

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -12,7 +11,12 @@ class Action:
raise NotImplementedError
def to_dict(self):
return {"action": self.__class__.__name__, "args": self.__dict__, "message": self.message}
d = asdict(self)
try:
v = d.pop('action')
except KeyError:
raise NotImplementedError(f'{self=} does not have action attribute set')
return {'action': v, "args": d, "message": self.message}
@property
def executable(self) -> bool:
@@ -24,21 +28,25 @@ class Action:
@dataclass
class ExecutableAction(Action):
@property
def executable(self) -> bool:
return True
@dataclass
class NotExecutableAction(Action):
@property
def executable(self) -> bool:
return False
@dataclass
class NullAction(NotExecutableAction):
"""An action that does nothing.
This is used when the agent need to receive user follow-up messages from the frontend.
"""
action: str = "null"
@property
def message(self) -> str:

View File

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
class CmdRunAction(ExecutableAction):
command: str
background: bool = False
action: str = "run"
def run(self, controller: "AgentController") -> "CmdOutputObservation":
return controller.command_manager.run_command(self.command, self.background)
@@ -23,10 +24,11 @@ class CmdRunAction(ExecutableAction):
@dataclass
class CmdKillAction(ExecutableAction):
id: int
action: str = "kill"
def run(self, controller: "AgentController") -> "CmdOutputObservation":
return controller.command_manager.kill_command(self.id)
@property
def message(self) -> str:
return f"Killing command: {self.id}"
return f"Killing command: {self.id}"

View File

@@ -8,6 +8,7 @@ from .base import ExecutableAction
@dataclass
class BrowseURLAction(ExecutableAction):
url: str
action: str = "browse"
def run(self, *args, **kwargs) -> BrowserOutputObservation:
try:
@@ -26,4 +27,4 @@ class BrowseURLAction(ExecutableAction):
@property
def message(self) -> str:
return f"Browsing URL: {self.url}"
return f"Browsing URL: {self.url}"

View File

@@ -17,6 +17,7 @@ def resolve_path(base_path, file_path):
@dataclass
class FileReadAction(ExecutableAction):
path: str
action: str = "read"
def run(self, controller) -> FileReadObservation:
path = resolve_path(controller.workdir, self.path)
@@ -29,11 +30,11 @@ class FileReadAction(ExecutableAction):
def message(self) -> str:
return f"Reading file: {self.path}"
@dataclass
class FileWriteAction(ExecutableAction):
path: str
contents: str
action: str = "write"
def run(self, controller) -> FileWriteObservation:
path = resolve_path(controller.workdir, self.path)
@@ -44,4 +45,3 @@ class FileWriteAction(ExecutableAction):
@property
def message(self) -> str:
return f"Writing file: {self.path}"

View File

@@ -18,8 +18,11 @@ class Observation:
"""Converts the observation to a dictionary."""
extras = copy.deepcopy(self.__dict__)
extras.pop("content", None)
observation = "observation"
if hasattr(self, "observation"):
observation = self.observation
return {
"observation": self.__class__.__name__,
"observation": observation,
"content": self.content,
"extras": extras,
"message": self.message,
@@ -40,6 +43,7 @@ class CmdOutputObservation(Observation):
command_id: int
command: str
exit_code: int = 0
observation : str = "run"
@property
def error(self) -> bool:
@@ -56,6 +60,7 @@ class FileReadObservation(Observation):
"""
path: str
observation : str = "read"
@property
def message(self) -> str:
@@ -68,6 +73,7 @@ class FileWriteObservation(Observation):
"""
path: str
observation : str = "write"
@property
def message(self) -> str:
@@ -82,6 +88,7 @@ class BrowserOutputObservation(Observation):
url: str
status_code: int = 200
error: bool = False
observation : str = "browse"
@property
def message(self) -> str:
@@ -95,6 +102,7 @@ class UserMessageObservation(Observation):
"""
role: str = "user"
observation : str = "message"
@property
def message(self) -> str:
@@ -108,6 +116,7 @@ class AgentMessageObservation(Observation):
"""
role: str = "assistant"
observation : str = "message"
@property
def message(self) -> str:
@@ -122,6 +131,7 @@ class AgentRecallObservation(Observation):
memories: List[str]
role: str = "assistant"
observation : str = "recall"
@property
def message(self) -> str:
@@ -133,6 +143,7 @@ class AgentErrorObservation(Observation):
"""
This data class represents an error encountered by the agent.
"""
observation : str = "error"
@property
def message(self) -> str:
@@ -144,6 +155,7 @@ class NullObservation(Observation):
This data class represents a null observation.
This is used when the produced action is NOT executable.
"""
observation : str = "null"
@property
def message(self) -> str:

View File

@@ -1,58 +1,22 @@
import asyncio
import os
from typing import Dict, Optional, Type
from typing import Optional
from fastapi import WebSocketDisconnect
from opendevin.action import (
Action,
AgentFinishAction,
AgentRecallAction,
AgentThinkAction,
BrowseURLAction,
CmdKillAction,
CmdRunAction,
FileReadAction,
FileWriteAction,
NullAction,
)
from opendevin.observation import NullObservation
from opendevin.agent import Agent
from opendevin.controller import AgentController
from opendevin.llm.llm import LLM
from opendevin.observation import Observation, UserMessageObservation
# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
"run": CmdRunAction,
"kill": CmdKillAction,
"browse": BrowseURLAction,
"read": FileReadAction,
"write": FileWriteAction,
"recall": AgentRecallAction,
"think": AgentThinkAction,
"finish": AgentFinishAction,
}
DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview")
def parse_event(data):
if "action" not in data:
return None
action = data["action"]
args = {}
if "args" in data:
args = data["args"]
message = None
if "message" in data:
message = data["message"]
return {
"action": action,
"args": args,
"message": message,
}
class Session:
def __init__(self, websocket):
self.websocket = websocket
@@ -84,20 +48,20 @@ class Session:
await self.send_error("Invalid JSON")
continue
event = parse_event(data)
if event is None:
action = data.get("action", None)
if action is None:
await self.send_error("Invalid event")
continue
if event["action"] == "initialize":
await self.create_controller(event)
elif event["action"] == "start":
await self.start_task(event)
if action == "initialize":
await self.create_controller(data)
elif action == "start":
await self.start_task(data)
else:
if self.controller is None:
await self.send_error("No agent started. Please wait a second...")
elif event["action"] == "chat":
self.controller.add_history(NullAction(), UserMessageObservation(event["message"]))
elif action == "chat":
self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
else:
# TODO: we only need to implement user message for now
# since even Devin does not support having the user taking other
@@ -147,33 +111,9 @@ class Session:
self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
def on_agent_event(self, event: Observation | Action):
# FIXME: we need better serialization
if isinstance(event, NullAction):
return
if isinstance(event, NullObservation):
return
event_dict = event.to_dict()
if "action" in event_dict:
if event_dict["action"] == "CmdRunAction":
event_dict["action"] = "run"
elif event_dict["action"] == "CmdKillAction":
event_dict["action"] = "kill"
elif event_dict["action"] == "BrowseURLAction":
event_dict["action"] = "browse"
elif event_dict["action"] == "FileReadAction":
event_dict["action"] = "read"
elif event_dict["action"] == "FileWriteAction":
event_dict["action"] = "write"
elif event_dict["action"] == "AgentFinishAction":
event_dict["action"] = "finish"
elif event_dict["action"] == "AgentRecallAction":
event_dict["action"] = "recall"
elif event_dict["action"] == "AgentThinkAction":
event_dict["action"] = "think"
if "observation" in event_dict:
if event_dict["observation"] == "UserMessageObservation":
event_dict["observation"] = "chat"
elif event_dict["observation"] == "AgentMessageObservation":
event_dict["observation"] = "chat"
elif event_dict["observation"] == "CmdOutputObservation":
event_dict["observation"] = "run"
elif event_dict["observation"] == "FileReadObservation":
event_dict["observation"] = "read"
asyncio.create_task(self.send(event_dict), name="send event in callback")