mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Serialization of Actions and Observations (#314)
* checkout geohotstan work * merge session.py changes * add observation ids * ignore null actions and obs * add back action messages * fix lint
This commit is contained in:
@@ -4,6 +4,27 @@ from .browse import BrowseURLAction
|
||||
from .fileop import FileReadAction, FileWriteAction
|
||||
from .agent import AgentRecallAction, AgentThinkAction, AgentFinishAction, AgentEchoAction, AgentSummarizeAction
|
||||
|
||||
actions = (
|
||||
CmdKillAction,
|
||||
CmdRunAction,
|
||||
BrowseURLAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
AgentRecallAction,
|
||||
AgentThinkAction,
|
||||
AgentFinishAction
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action:
|
||||
action_class = ACTION_TYPE_TO_CLASS.get(action)
|
||||
if action_class is None:
|
||||
raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}")
|
||||
return action_class(*args, **kwargs)
|
||||
|
||||
CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}
|
||||
|
||||
__all__ = [
|
||||
"Action",
|
||||
"NullAction",
|
||||
|
||||
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
|
||||
@dataclass
|
||||
class AgentRecallAction(ExecutableAction):
|
||||
query: str
|
||||
action: str = "recall"
|
||||
|
||||
def run(self, controller: "AgentController") -> AgentRecallObservation:
|
||||
return AgentRecallObservation(
|
||||
@@ -21,12 +22,11 @@ class AgentRecallAction(ExecutableAction):
|
||||
def message(self) -> str:
|
||||
return f"Let me dive into my memories to find what you're looking for! Searching for: '{self.query}'. This might take a moment."
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentThinkAction(NotExecutableAction):
|
||||
thought: str
|
||||
runnable: bool = False
|
||||
action: str = "think"
|
||||
|
||||
def run(self, controller: "AgentController") -> "Observation":
|
||||
raise NotImplementedError
|
||||
@@ -35,11 +35,11 @@ class AgentThinkAction(NotExecutableAction):
|
||||
def message(self) -> str:
|
||||
return self.thought
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentEchoAction(ExecutableAction):
|
||||
content: str
|
||||
runnable: bool = True
|
||||
action: str = "echo"
|
||||
|
||||
def run(self, controller: "AgentController") -> "Observation":
|
||||
return AgentMessageObservation(self.content)
|
||||
@@ -52,6 +52,8 @@ class AgentEchoAction(ExecutableAction):
|
||||
class AgentSummarizeAction(NotExecutableAction):
|
||||
summary: str
|
||||
|
||||
action: str = "summarize"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.summary
|
||||
@@ -59,6 +61,7 @@ class AgentSummarizeAction(NotExecutableAction):
|
||||
@dataclass
|
||||
class AgentFinishAction(NotExecutableAction):
|
||||
runnable: bool = False
|
||||
action: str = "finish"
|
||||
|
||||
def run(self, controller: "AgentController") -> "Observation":
|
||||
raise NotImplementedError
|
||||
@@ -66,4 +69,3 @@ class AgentFinishAction(NotExecutableAction):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return "All done! What's next on the agenda?"
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -12,7 +11,12 @@ class Action:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_dict(self):
|
||||
return {"action": self.__class__.__name__, "args": self.__dict__, "message": self.message}
|
||||
d = asdict(self)
|
||||
try:
|
||||
v = d.pop('action')
|
||||
except KeyError:
|
||||
raise NotImplementedError(f'{self=} does not have action attribute set')
|
||||
return {'action': v, "args": d, "message": self.message}
|
||||
|
||||
@property
|
||||
def executable(self) -> bool:
|
||||
@@ -24,21 +28,25 @@ class Action:
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutableAction(Action):
|
||||
@property
|
||||
def executable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotExecutableAction(Action):
|
||||
@property
|
||||
def executable(self) -> bool:
|
||||
return False
|
||||
|
||||
@dataclass
|
||||
class NullAction(NotExecutableAction):
|
||||
"""An action that does nothing.
|
||||
This is used when the agent need to receive user follow-up messages from the frontend.
|
||||
"""
|
||||
action: str = "null"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
|
||||
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
|
||||
class CmdRunAction(ExecutableAction):
|
||||
command: str
|
||||
background: bool = False
|
||||
action: str = "run"
|
||||
|
||||
def run(self, controller: "AgentController") -> "CmdOutputObservation":
|
||||
return controller.command_manager.run_command(self.command, self.background)
|
||||
@@ -23,10 +24,11 @@ class CmdRunAction(ExecutableAction):
|
||||
@dataclass
|
||||
class CmdKillAction(ExecutableAction):
|
||||
id: int
|
||||
action: str = "kill"
|
||||
|
||||
def run(self, controller: "AgentController") -> "CmdOutputObservation":
|
||||
return controller.command_manager.kill_command(self.id)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"Killing command: {self.id}"
|
||||
return f"Killing command: {self.id}"
|
||||
@@ -8,6 +8,7 @@ from .base import ExecutableAction
|
||||
@dataclass
|
||||
class BrowseURLAction(ExecutableAction):
|
||||
url: str
|
||||
action: str = "browse"
|
||||
|
||||
def run(self, *args, **kwargs) -> BrowserOutputObservation:
|
||||
try:
|
||||
@@ -26,4 +27,4 @@ class BrowseURLAction(ExecutableAction):
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"Browsing URL: {self.url}"
|
||||
return f"Browsing URL: {self.url}"
|
||||
@@ -17,6 +17,7 @@ def resolve_path(base_path, file_path):
|
||||
@dataclass
|
||||
class FileReadAction(ExecutableAction):
|
||||
path: str
|
||||
action: str = "read"
|
||||
|
||||
def run(self, controller) -> FileReadObservation:
|
||||
path = resolve_path(controller.workdir, self.path)
|
||||
@@ -29,11 +30,11 @@ class FileReadAction(ExecutableAction):
|
||||
def message(self) -> str:
|
||||
return f"Reading file: {self.path}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileWriteAction(ExecutableAction):
|
||||
path: str
|
||||
contents: str
|
||||
action: str = "write"
|
||||
|
||||
def run(self, controller) -> FileWriteObservation:
|
||||
path = resolve_path(controller.workdir, self.path)
|
||||
@@ -44,4 +45,3 @@ class FileWriteAction(ExecutableAction):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"Writing file: {self.path}"
|
||||
|
||||
|
||||
@@ -18,8 +18,11 @@ class Observation:
|
||||
"""Converts the observation to a dictionary."""
|
||||
extras = copy.deepcopy(self.__dict__)
|
||||
extras.pop("content", None)
|
||||
observation = "observation"
|
||||
if hasattr(self, "observation"):
|
||||
observation = self.observation
|
||||
return {
|
||||
"observation": self.__class__.__name__,
|
||||
"observation": observation,
|
||||
"content": self.content,
|
||||
"extras": extras,
|
||||
"message": self.message,
|
||||
@@ -40,6 +43,7 @@ class CmdOutputObservation(Observation):
|
||||
command_id: int
|
||||
command: str
|
||||
exit_code: int = 0
|
||||
observation : str = "run"
|
||||
|
||||
@property
|
||||
def error(self) -> bool:
|
||||
@@ -56,6 +60,7 @@ class FileReadObservation(Observation):
|
||||
"""
|
||||
|
||||
path: str
|
||||
observation : str = "read"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -68,6 +73,7 @@ class FileWriteObservation(Observation):
|
||||
"""
|
||||
|
||||
path: str
|
||||
observation : str = "write"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -82,6 +88,7 @@ class BrowserOutputObservation(Observation):
|
||||
url: str
|
||||
status_code: int = 200
|
||||
error: bool = False
|
||||
observation : str = "browse"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -95,6 +102,7 @@ class UserMessageObservation(Observation):
|
||||
"""
|
||||
|
||||
role: str = "user"
|
||||
observation : str = "message"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -108,6 +116,7 @@ class AgentMessageObservation(Observation):
|
||||
"""
|
||||
|
||||
role: str = "assistant"
|
||||
observation : str = "message"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -122,6 +131,7 @@ class AgentRecallObservation(Observation):
|
||||
|
||||
memories: List[str]
|
||||
role: str = "assistant"
|
||||
observation : str = "recall"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -133,6 +143,7 @@ class AgentErrorObservation(Observation):
|
||||
"""
|
||||
This data class represents an error encountered by the agent.
|
||||
"""
|
||||
observation : str = "error"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -144,6 +155,7 @@ class NullObservation(Observation):
|
||||
This data class represents a null observation.
|
||||
This is used when the produced action is NOT executable.
|
||||
"""
|
||||
observation : str = "null"
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
|
||||
@@ -1,58 +1,22 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import WebSocketDisconnect
|
||||
|
||||
from opendevin.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
AgentRecallAction,
|
||||
AgentThinkAction,
|
||||
BrowseURLAction,
|
||||
CmdKillAction,
|
||||
CmdRunAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
NullAction,
|
||||
)
|
||||
from opendevin.observation import NullObservation
|
||||
from opendevin.agent import Agent
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.observation import Observation, UserMessageObservation
|
||||
|
||||
# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
|
||||
ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
|
||||
"run": CmdRunAction,
|
||||
"kill": CmdKillAction,
|
||||
"browse": BrowseURLAction,
|
||||
"read": FileReadAction,
|
||||
"write": FileWriteAction,
|
||||
"recall": AgentRecallAction,
|
||||
"think": AgentThinkAction,
|
||||
"finish": AgentFinishAction,
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
|
||||
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview")
|
||||
|
||||
def parse_event(data):
|
||||
if "action" not in data:
|
||||
return None
|
||||
action = data["action"]
|
||||
args = {}
|
||||
if "args" in data:
|
||||
args = data["args"]
|
||||
message = None
|
||||
if "message" in data:
|
||||
message = data["message"]
|
||||
return {
|
||||
"action": action,
|
||||
"args": args,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
class Session:
|
||||
def __init__(self, websocket):
|
||||
self.websocket = websocket
|
||||
@@ -84,20 +48,20 @@ class Session:
|
||||
await self.send_error("Invalid JSON")
|
||||
continue
|
||||
|
||||
event = parse_event(data)
|
||||
if event is None:
|
||||
action = data.get("action", None)
|
||||
if action is None:
|
||||
await self.send_error("Invalid event")
|
||||
continue
|
||||
if event["action"] == "initialize":
|
||||
await self.create_controller(event)
|
||||
elif event["action"] == "start":
|
||||
await self.start_task(event)
|
||||
if action == "initialize":
|
||||
await self.create_controller(data)
|
||||
elif action == "start":
|
||||
await self.start_task(data)
|
||||
else:
|
||||
if self.controller is None:
|
||||
await self.send_error("No agent started. Please wait a second...")
|
||||
|
||||
elif event["action"] == "chat":
|
||||
self.controller.add_history(NullAction(), UserMessageObservation(event["message"]))
|
||||
elif action == "chat":
|
||||
self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
|
||||
else:
|
||||
# TODO: we only need to implement user message for now
|
||||
# since even Devin does not support having the user taking other
|
||||
@@ -147,33 +111,9 @@ class Session:
|
||||
self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
|
||||
|
||||
def on_agent_event(self, event: Observation | Action):
|
||||
# FIXME: we need better serialization
|
||||
if isinstance(event, NullAction):
|
||||
return
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
event_dict = event.to_dict()
|
||||
if "action" in event_dict:
|
||||
if event_dict["action"] == "CmdRunAction":
|
||||
event_dict["action"] = "run"
|
||||
elif event_dict["action"] == "CmdKillAction":
|
||||
event_dict["action"] = "kill"
|
||||
elif event_dict["action"] == "BrowseURLAction":
|
||||
event_dict["action"] = "browse"
|
||||
elif event_dict["action"] == "FileReadAction":
|
||||
event_dict["action"] = "read"
|
||||
elif event_dict["action"] == "FileWriteAction":
|
||||
event_dict["action"] = "write"
|
||||
elif event_dict["action"] == "AgentFinishAction":
|
||||
event_dict["action"] = "finish"
|
||||
elif event_dict["action"] == "AgentRecallAction":
|
||||
event_dict["action"] = "recall"
|
||||
elif event_dict["action"] == "AgentThinkAction":
|
||||
event_dict["action"] = "think"
|
||||
if "observation" in event_dict:
|
||||
if event_dict["observation"] == "UserMessageObservation":
|
||||
event_dict["observation"] = "chat"
|
||||
elif event_dict["observation"] == "AgentMessageObservation":
|
||||
event_dict["observation"] = "chat"
|
||||
elif event_dict["observation"] == "CmdOutputObservation":
|
||||
event_dict["observation"] = "run"
|
||||
elif event_dict["observation"] == "FileReadObservation":
|
||||
event_dict["observation"] = "read"
|
||||
|
||||
asyncio.create_task(self.send(event_dict), name="send event in callback")
|
||||
|
||||
Reference in New Issue
Block a user