Implement deserialization for actions and observations (#359)

* action deserializing

* add observation deserialization

* add tests

* refactor agents with serialization

* fix some errors

* fix lint

* fix json parser
This commit is contained in:
Robert Brennan 2024-03-30 10:06:25 -04:00 committed by GitHub
parent f68ee45761
commit effac868c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 318 additions and 328 deletions

View File

@ -7,22 +7,6 @@ import agenthub.langchains_agent.utils.prompts as prompts
from agenthub.langchains_agent.utils.monologue import Monologue
from agenthub.langchains_agent.utils.memory import LongTermMemory
from opendevin.action import (
NullAction,
CmdRunAction,
CmdKillAction,
BrowseURLAction,
FileReadAction,
FileWriteAction,
AgentRecallAction,
AgentThinkAction,
AgentFinishAction,
)
from opendevin.observation import (
CmdOutputObservation,
)
MAX_MONOLOGUE_LENGTH = 20000
MAX_OUTPUT_LENGTH = 5000
@ -81,7 +65,7 @@ class LangchainsAgent(Agent):
self.memory = LongTermMemory()
def _add_event(self, event: dict):
if 'output' in event['args'] and len(event['args']['output']) > MAX_OUTPUT_LENGTH:
if 'args' in event and 'output' in event['args'] and len(event['args']['output']) > MAX_OUTPUT_LENGTH:
event['args']['output'] = event['args']['output'][:MAX_OUTPUT_LENGTH] + "..."
self.monologue.add_event(event)
@ -136,45 +120,9 @@ class LangchainsAgent(Agent):
def step(self, state: State) -> Action:
self._initialize(state.plan.main_goal)
# TODO: make langchains agent use Action & Observation
# completly from ground up
# Translate state to action_dict
for prev_action, obs in state.updated_info:
d = None
if isinstance(obs, CmdOutputObservation):
if obs.error:
d = {"action": "error", "args": {"output": obs.content}}
else:
d = {"action": "output", "args": {"output": obs.content}}
else:
d = {"action": "output", "args": {"output": obs.content}}
if d is not None:
self._add_event(d)
d = None
if isinstance(prev_action, CmdRunAction):
d = {"action": "run", "args": {"command": prev_action.command}}
elif isinstance(prev_action, CmdKillAction):
d = {"action": "kill", "args": {"id": prev_action.id}}
elif isinstance(prev_action, BrowseURLAction):
d = {"action": "browse", "args": {"url": prev_action.url}}
elif isinstance(prev_action, FileReadAction):
d = {"action": "read", "args": {"file": prev_action.path}}
elif isinstance(prev_action, FileWriteAction):
d = {"action": "write", "args": {"file": prev_action.path, "content": prev_action.contents}}
elif isinstance(prev_action, AgentRecallAction):
d = {"action": "recall", "args": {"query": prev_action.query}}
elif isinstance(prev_action, AgentThinkAction):
d = {"action": "think", "args": {"thought": prev_action.thought}}
elif isinstance(prev_action, AgentFinishAction):
d = {"action": "finish"}
elif isinstance(prev_action, NullAction):
d = None
else:
raise ValueError(f"Unknown action type: {prev_action}")
if d is not None:
self._add_event(d)
self._add_event(prev_action.to_dict())
self._add_event(obs.to_dict())
state.updated_info = []

View File

@ -6,3 +6,7 @@ def my_encoder(obj):
def dumps(obj, **kwargs):
return json.dumps(obj, default=my_encoder, **kwargs)
def loads(s, **kwargs):
return json.loads(s, **kwargs)

View File

@ -48,11 +48,20 @@ class LongTermMemory:
self.thought_idx = 0
def add_event(self, event):
id = ""
t = ""
if "action" in event:
t = "action"
id = event["action"]
elif "observation" in event:
t = "observation"
id = event["observation"]
doc = Document(
text=json.dumps(event),
doc_id=str(self.thought_idx),
extra_info={
"type": event["action"],
"type": t,
"id": id,
"idx": self.thought_idx,
},
)

View File

@ -1,8 +1,6 @@
from typing import List, Dict, Type
from typing import List
from langchain_core.pydantic_v1 import BaseModel
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from opendevin import config
@ -13,35 +11,13 @@ if config.get_or_default("DEBUG", False):
from . import json
from opendevin.action import (
action_from_dict,
Action,
CmdRunAction,
CmdKillAction,
BrowseURLAction,
FileReadAction,
FileWriteAction,
AgentRecallAction,
AgentThinkAction,
AgentFinishAction,
AgentSummarizeAction,
)
from opendevin.observation import (
CmdOutputObservation,
)
ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
"run": CmdRunAction,
"kill": CmdKillAction,
"browse": BrowseURLAction,
"read": FileReadAction,
"write": FileWriteAction,
"recall": AgentRecallAction,
"think": AgentThinkAction,
"summarize": AgentSummarizeAction,
"finish": AgentFinishAction,
}
CLASS_TO_ACTION_TYPE: Dict[Type[Action], str] = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}
ACTION_PROMPT = """
You're a thoughtful robot. Your main task is to {task}.
Don't expand the scope of your task--just complete it as written.
@ -116,15 +92,6 @@ You can also use the same action and args from the source monologue.
"""
class _ActionDict(BaseModel):
action: str
args: dict
class NewMonologue(BaseModel):
new_monologue: List[_ActionDict]
def get_summarize_monologue_prompt(thoughts):
prompt = PromptTemplate.from_template(MONOLOGUE_SUMMARY_PROMPT)
return prompt.format(monologue=json.dumps({'old_monologue': thoughts}, indent=2))
@ -137,13 +104,14 @@ def get_request_action_prompt(
hint = ''
if len(thoughts) > 0:
latest_thought = thoughts[-1]
if latest_thought["action"] == 'think':
if latest_thought["args"]['thought'].startswith("OK so my task is"):
hint = "You're just getting started! What should you do first?"
else:
hint = "You've been thinking a lot lately. Maybe it's time to take action?"
elif latest_thought["action"] == 'error':
hint = "Looks like that last command failed. Maybe you need to fix it, or try something else."
if "action" in latest_thought:
if latest_thought["action"] == 'think':
if latest_thought["args"]['thought'].startswith("OK so my task is"):
hint = "You're just getting started! What should you do first?"
else:
hint = "You've been thinking a lot lately. Maybe it's time to take action?"
elif latest_thought["action"] == 'error':
hint = "Looks like that last command failed. Maybe you need to fix it, or try something else."
bg_commands_message = ""
if len(background_commands_obs) > 0:
@ -162,17 +130,15 @@ def get_request_action_prompt(
)
def parse_action_response(response: str) -> Action:
parser = JsonOutputParser(pydantic_object=_ActionDict)
action_dict = parser.parse(response)
json_start = response.find("{")
json_end = response.rfind("}") + 1
response = response[json_start:json_end]
action_dict = json.loads(response)
if 'content' in action_dict:
# The LLM gets confused here. Might as well be robust
action_dict['contents'] = action_dict.pop('content')
return action_from_dict(action_dict)
action = ACTION_TYPE_TO_CLASS[action_dict["action"]](**action_dict["args"])
return action
def parse_summary_response(response: str) -> List[Action]:
parser = JsonOutputParser(pydantic_object=NewMonologue)
parsed = parser.parse(response)
#thoughts = [ACTION_TYPE_TO_CLASS[t['action']](**t['args']) for t in parsed['new_monologue']]
def parse_summary_response(response: str) -> List[dict]:
parsed = json.loads(response)
return parsed['new_monologue']

View File

@ -3,7 +3,7 @@ from typing import List, Tuple, Dict, Type
from opendevin.controller.agent_controller import print_with_indent
from opendevin.plan import Plan
from opendevin.action import Action
from opendevin.action import Action, action_from_dict
from opendevin.observation import Observation
from opendevin.action import (
@ -136,15 +136,10 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]):
latest_action: Action = NullAction()
for action, observation in sub_history:
if not isinstance(action, NullAction):
#if not isinstance(action, ModifyTaskAction) and not isinstance(action, AddTaskAction):
action_dict = action.to_dict()
action_dict["action"] = convert_action(action_dict["action"])
history_dicts.append(action_dict)
history_dicts.append(action.to_dict())
latest_action = action
if not isinstance(observation, NullObservation):
observation_dict = observation.to_dict()
observation_dict["observation"] = convert_observation(observation_dict["observation"])
history_dicts.append(observation_dict)
history_dicts.append(observation.to_dict())
history_str = json.dumps(history_dicts, indent=2)
hint = ""
@ -157,7 +152,7 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]):
plan_status = "You're not currently working on any tasks. Your next action MUST be to mark a task as in_progress."
hint = plan_status
latest_action_id = convert_action(latest_action.to_dict()["action"])
latest_action_id = latest_action.to_dict()['action']
if current_task is not None:
if latest_action_id == "":
@ -200,43 +195,6 @@ def parse_response(response: str) -> Action:
if 'content' in action_dict:
# The LLM gets confused here. Might as well be robust
action_dict['contents'] = action_dict.pop('content')
args_dict = action_dict.get("args", {})
action = ACTION_TYPE_TO_CLASS[action_dict["action"]](**args_dict)
action = action_from_dict(action_dict)
return action
def convert_action(action):
if action == "CmdRunAction":
action = "run"
elif action == "CmdKillAction":
action = "kill"
elif action == "BrowseURLAction":
action = "browse"
elif action == "FileReadAction":
action = "read"
elif action == "FileWriteAction":
action = "write"
elif action == "AgentFinishAction":
action = "finish"
elif action == "AgentRecallAction":
action = "recall"
elif action == "AgentThinkAction":
action = "think"
elif action == "AgentSummarizeAction":
action = "summarize"
elif action == "AddTaskAction":
action = "add_task"
elif action == "ModifyTaskAction":
action = "modify_task"
return action
def convert_observation(observation):
if observation == "UserMessageObservation":
observation = "chat"
elif observation == "AgentMessageObservation":
observation = "chat"
elif observation == "CmdOutputObservation":
observation = "run"
elif observation == "FileReadObservation":
observation = "read"
return observation

View File

@ -13,18 +13,22 @@ actions = (
FileWriteAction,
AgentRecallAction,
AgentThinkAction,
AgentFinishAction
AgentFinishAction,
AddTaskAction,
ModifyTaskAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined]
def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action:
action_class = ACTION_TYPE_TO_CLASS.get(action)
def action_from_dict(action: dict) -> Action:
action = action.copy()
if "action" not in action:
raise KeyError(f"'action' key is not found in {action=}")
action_class = ACTION_TYPE_TO_CLASS.get(action["action"])
if action_class is None:
raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}")
return action_class(*args, **kwargs)
CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}
raise KeyError(f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}")
args = action.get("args", {})
return action_class(**args)
__all__ = [
"Action",

View File

@ -25,7 +25,6 @@ class AgentRecallAction(ExecutableAction):
@dataclass
class AgentThinkAction(NotExecutableAction):
thought: str
runnable: bool = False
action: str = "think"
def run(self, controller: "AgentController") -> "Observation":
@ -38,7 +37,6 @@ class AgentThinkAction(NotExecutableAction):
@dataclass
class AgentEchoAction(ExecutableAction):
content: str
runnable: bool = True
action: str = "echo"
def run(self, controller: "AgentController") -> "Observation":
@ -60,7 +58,6 @@ class AgentSummarizeAction(NotExecutableAction):
@dataclass
class AgentFinishAction(NotExecutableAction):
runnable: bool = False
action: str = "finish"
def run(self, controller: "AgentController") -> "Observation":

View File

@ -26,8 +26,6 @@ class Action:
def message(self) -> str:
raise NotImplementedError
@dataclass
class ExecutableAction(Action):
@property

View File

@ -1,162 +0,0 @@
import copy
from typing import List
from dataclasses import dataclass
@dataclass
class Observation:
"""
This data class represents an observation of the environment.
"""
content: str
def __str__(self) -> str:
return self.content
def to_dict(self) -> dict:
"""Converts the observation to a dictionary."""
extras = copy.deepcopy(self.__dict__)
extras.pop("content", None)
observation = "observation"
if hasattr(self, "observation"):
observation = self.observation
return {
"observation": observation,
"content": self.content,
"extras": extras,
"message": self.message,
}
@property
def message(self) -> str:
"""Returns a message describing the observation."""
return ""
@dataclass
class CmdOutputObservation(Observation):
"""
This data class represents the output of a command.
"""
command_id: int
command: str
exit_code: int = 0
observation : str = "run"
@property
def error(self) -> bool:
return self.exit_code != 0
@property
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
@dataclass
class FileReadObservation(Observation):
"""
This data class represents the content of a file.
"""
path: str
observation : str = "read"
@property
def message(self) -> str:
return f"I read the file {self.path}."
@dataclass
class FileWriteObservation(Observation):
"""
This data class represents a file write operation
"""
path: str
observation : str = "write"
@property
def message(self) -> str:
return f"I wrote to the file {self.path}."
@dataclass
class BrowserOutputObservation(Observation):
"""
This data class represents the output of a browser.
"""
url: str
status_code: int = 200
error: bool = False
observation : str = "browse"
@property
def message(self) -> str:
return "Visited " + self.url
@dataclass
class UserMessageObservation(Observation):
"""
This data class represents a message sent by the user.
"""
role: str = "user"
observation : str = "message"
@property
def message(self) -> str:
return ""
@dataclass
class AgentMessageObservation(Observation):
"""
This data class represents a message sent by the agent.
"""
role: str = "assistant"
observation : str = "message"
@property
def message(self) -> str:
return ""
@dataclass
class AgentRecallObservation(Observation):
"""
This data class represents a list of memories recalled by the agent.
"""
memories: List[str]
role: str = "assistant"
observation : str = "recall"
@property
def message(self) -> str:
return "The agent recalled memories."
@dataclass
class AgentErrorObservation(Observation):
"""
This data class represents an error encountered by the agent.
"""
observation : str = "error"
@property
def message(self) -> str:
return "Oops. Something went wrong: " + self.content
@dataclass
class NullObservation(Observation):
"""
This data class represents a null observation.
This is used when the produced action is NOT executable.
"""
observation : str = "null"
@property
def message(self) -> str:
return ""

View File

@ -0,0 +1,46 @@
from .base import Observation, NullObservation
from .run import CmdOutputObservation
from .browse import BrowserOutputObservation
from .files import FileReadObservation, FileWriteObservation
from .message import UserMessageObservation, AgentMessageObservation
from .recall import AgentRecallObservation
from .error import AgentErrorObservation
observations = (
CmdOutputObservation,
BrowserOutputObservation,
FileReadObservation,
FileWriteObservation,
UserMessageObservation,
AgentMessageObservation,
AgentRecallObservation,
AgentErrorObservation,
)
OBSERVATION_TYPE_TO_CLASS = {observation_class.observation:observation_class for observation_class in observations} # type: ignore[attr-defined]
def observation_from_dict(observation: dict) -> Observation:
observation = observation.copy()
if "observation" not in observation:
raise KeyError(f"'observation' key is not found in {observation=}")
observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation["observation"])
if observation_class is None:
raise KeyError(f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}")
observation.pop("observation")
observation.pop("message", None)
content = observation.pop("content", "")
extras = observation.pop("extras", {})
return observation_class(content=content, **extras)
__all__ = [
"Observation",
"NullObservation",
"CmdOutputObservation",
"BrowserOutputObservation",
"FileReadObservation",
"FileWriteObservation",
"UserMessageObservation",
"AgentMessageObservation",
"AgentRecallObservation",
"AgentErrorObservation",
]

View File

@ -0,0 +1,43 @@
import copy
from dataclasses import dataclass
@dataclass
class Observation:
"""
This data class represents an observation of the environment.
"""
content: str
def __str__(self) -> str:
return self.content
def to_dict(self) -> dict:
"""Converts the observation to a dictionary."""
extras = copy.deepcopy(self.__dict__)
content = extras.pop("content", "")
observation = extras.pop("observation", "")
return {
"observation": observation,
"content": content,
"extras": extras,
"message": self.message,
}
@property
def message(self) -> str:
"""Returns a message describing the observation."""
return ""
@dataclass
class NullObservation(Observation):
"""
This data class represents a null observation.
This is used when the produced action is NOT executable.
"""
observation : str = "null"
@property
def message(self) -> str:
return ""

View File

@ -0,0 +1,21 @@
from dataclasses import dataclass
from .base import Observation
@dataclass
class BrowserOutputObservation(Observation):
"""
This data class represents the output of a browser.
"""
url: str
status_code: int = 200
error: bool = False
observation : str = "browse"
@property
def message(self) -> str:
return "Visited " + self.url

View File

@ -0,0 +1,16 @@
from dataclasses import dataclass
from .base import Observation
@dataclass
class AgentErrorObservation(Observation):
"""
This data class represents an error encountered by the agent.
"""
observation : str = "error"
@property
def message(self) -> str:
return "Oops. Something went wrong: " + self.content

View File

@ -0,0 +1,31 @@
from dataclasses import dataclass
from .base import Observation
@dataclass
class FileReadObservation(Observation):
"""
This data class represents the content of a file.
"""
path: str
observation : str = "read"
@property
def message(self) -> str:
return f"I read the file {self.path}."
@dataclass
class FileWriteObservation(Observation):
"""
This data class represents a file write operation
"""
path: str
observation : str = "write"
@property
def message(self) -> str:
return f"I wrote to the file {self.path}."

View File

@ -0,0 +1,33 @@
from dataclasses import dataclass
from .base import Observation
@dataclass
class UserMessageObservation(Observation):
"""
This data class represents a message sent by the user.
"""
role: str = "user"
observation : str = "message"
@property
def message(self) -> str:
return ""
@dataclass
class AgentMessageObservation(Observation):
"""
This data class represents a message sent by the agent.
"""
role: str = "assistant"
observation : str = "message"
@property
def message(self) -> str:
return ""

View File

@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List
from .base import Observation
@dataclass
class AgentRecallObservation(Observation):
"""
This data class represents a list of memories recalled by the agent.
"""
memories: List[str]
role: str = "assistant"
observation : str = "recall"
@property
def message(self) -> str:
return "The agent recalled memories."

View File

@ -0,0 +1,24 @@
from dataclasses import dataclass
from .base import Observation
@dataclass
class CmdOutputObservation(Observation):
"""
This data class represents the output of a command.
"""
command_id: int
command: str
exit_code: int = 0
observation : str = "run"
@property
def error(self) -> bool:
return self.exit_code != 0
@property
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'

View File

@ -0,0 +1,16 @@
import pytest
from opendevin.action import action_from_dict, Action, AgentThinkAction
def test_action_serialization_deserialization():
original_action_dict = {
'action': 'think',
'args': {'thought': 'This is a test.'}
}
action_instance = action_from_dict(original_action_dict)
assert isinstance(action_instance, Action), 'The action instance should be an instance of Action.'
assert isinstance(action_instance, AgentThinkAction), 'The action instance should be an instance of AgentThinkAction.'
serialized_action_dict = action_instance.to_dict()
serialized_action_dict.pop('message')
assert serialized_action_dict == original_action_dict, 'The serialized action should match the original action dict.'
# Additional tests for various action subclasses can be included here

View File

@ -0,0 +1,17 @@
import pytest
from opendevin.observation import observation_from_dict, Observation, CmdOutputObservation
def test_observation_serialization_deserialization():
original_observation_dict = {
'observation': 'run',
'extras': {'exit_code': 0, 'command': 'ls -l', 'command_id': 3},
'message': 'Command `ls -l` executed with exit code 0.',
'content': 'foo.txt',
}
observation_instance = observation_from_dict(original_observation_dict)
assert isinstance(observation_instance, Observation), 'The observation instance should be an instance of Action.'
assert isinstance(observation_instance, CmdOutputObservation), 'The observation instance should be an instance of AgentThinkAction.'
serialized_observation_dict = observation_instance.to_dict()
assert serialized_observation_dict == original_observation_dict, 'The serialized observation should match the original observation dict.'
# Additional tests for various observation subclasses can be included here