From effac868c116f27bc971cbd4579e33b6311bdafd Mon Sep 17 00:00:00 2001 From: Robert Brennan Date: Sat, 30 Mar 2024 10:06:25 -0400 Subject: [PATCH] Implement deserialization for actions and observations (#359) * action deserializing * add observation deserialization * add tests * refactor agents with serialization * fix some errors * fix lint * fix json parser --- agenthub/langchains_agent/langchains_agent.py | 58 +------ agenthub/langchains_agent/utils/json.py | 4 + agenthub/langchains_agent/utils/memory.py | 11 +- agenthub/langchains_agent/utils/prompts.py | 68 ++------ agenthub/planner_agent/prompt.py | 52 +----- opendevin/action/__init__.py | 18 +- opendevin/action/agent.py | 3 - opendevin/action/base.py | 2 - opendevin/observation.py | 162 ------------------ opendevin/observation/__init__.py | 46 +++++ opendevin/observation/base.py | 43 +++++ opendevin/observation/browse.py | 21 +++ opendevin/observation/error.py | 16 ++ opendevin/observation/files.py | 31 ++++ opendevin/observation/message.py | 33 ++++ opendevin/observation/recall.py | 21 +++ opendevin/observation/run.py | 24 +++ tests/test_action_serialization.py | 16 ++ tests/test_observation_serialization.py | 17 ++ 19 files changed, 318 insertions(+), 328 deletions(-) delete mode 100644 opendevin/observation.py create mode 100644 opendevin/observation/__init__.py create mode 100644 opendevin/observation/base.py create mode 100644 opendevin/observation/browse.py create mode 100644 opendevin/observation/error.py create mode 100644 opendevin/observation/files.py create mode 100644 opendevin/observation/message.py create mode 100644 opendevin/observation/recall.py create mode 100644 opendevin/observation/run.py create mode 100644 tests/test_action_serialization.py create mode 100644 tests/test_observation_serialization.py diff --git a/agenthub/langchains_agent/langchains_agent.py b/agenthub/langchains_agent/langchains_agent.py index 48e392dbba..20f45a15bc 100644 --- a/agenthub/langchains_agent/langchains_agent.py +++ b/agenthub/langchains_agent/langchains_agent.py @@ -7,22 +7,6 @@ import agenthub.langchains_agent.utils.prompts as prompts from agenthub.langchains_agent.utils.monologue import Monologue from agenthub.langchains_agent.utils.memory import LongTermMemory -from opendevin.action import ( - NullAction, - CmdRunAction, - CmdKillAction, - BrowseURLAction, - FileReadAction, - FileWriteAction, - AgentRecallAction, - AgentThinkAction, - AgentFinishAction, -) -from opendevin.observation import ( - CmdOutputObservation, -) - - MAX_MONOLOGUE_LENGTH = 20000 MAX_OUTPUT_LENGTH = 5000 @@ -81,7 +65,7 @@ class LangchainsAgent(Agent): self.memory = LongTermMemory() def _add_event(self, event: dict): - if 'output' in event['args'] and len(event['args']['output']) > MAX_OUTPUT_LENGTH: + if 'args' in event and 'output' in event['args'] and len(event['args']['output']) > MAX_OUTPUT_LENGTH: event['args']['output'] = event['args']['output'][:MAX_OUTPUT_LENGTH] + "..." self.monologue.add_event(event) @@ -136,45 +120,9 @@ class LangchainsAgent(Agent): def step(self, state: State) -> Action: self._initialize(state.plan.main_goal) - # TODO: make langchains agent use Action & Observation - # completly from ground up - - # Translate state to action_dict for prev_action, obs in state.updated_info: - d = None - if isinstance(obs, CmdOutputObservation): - if obs.error: - d = {"action": "error", "args": {"output": obs.content}} - else: - d = {"action": "output", "args": {"output": obs.content}} - else: - d = {"action": "output", "args": {"output": obs.content}} - if d is not None: - self._add_event(d) - - d = None - if isinstance(prev_action, CmdRunAction): - d = {"action": "run", "args": {"command": prev_action.command}} - elif isinstance(prev_action, CmdKillAction): - d = {"action": "kill", "args": {"id": prev_action.id}} - elif isinstance(prev_action, BrowseURLAction): - d = {"action": "browse", "args": {"url": prev_action.url}} - elif isinstance(prev_action, FileReadAction): - d = {"action": "read", "args": {"file": prev_action.path}} - elif isinstance(prev_action, FileWriteAction): - d = {"action": "write", "args": {"file": prev_action.path, "content": prev_action.contents}} - elif isinstance(prev_action, AgentRecallAction): - d = {"action": "recall", "args": {"query": prev_action.query}} - elif isinstance(prev_action, AgentThinkAction): - d = {"action": "think", "args": {"thought": prev_action.thought}} - elif isinstance(prev_action, AgentFinishAction): - d = {"action": "finish"} - elif isinstance(prev_action, NullAction): - d = None - else: - raise ValueError(f"Unknown action type: {prev_action}") - if d is not None: - self._add_event(d) + self._add_event(prev_action.to_dict()) + self._add_event(obs.to_dict()) state.updated_info = [] diff --git a/agenthub/langchains_agent/utils/json.py b/agenthub/langchains_agent/utils/json.py index 3c99c7f94e..c0f2dee02b 100644 --- a/agenthub/langchains_agent/utils/json.py +++ b/agenthub/langchains_agent/utils/json.py @@ -6,3 +6,7 @@ def my_encoder(obj): def dumps(obj, **kwargs): return json.dumps(obj, default=my_encoder, **kwargs) + +def loads(s, **kwargs): + return json.loads(s, **kwargs) + diff --git a/agenthub/langchains_agent/utils/memory.py b/agenthub/langchains_agent/utils/memory.py index 4b2a09a63e..6212a54095 100644 --- a/agenthub/langchains_agent/utils/memory.py +++ b/agenthub/langchains_agent/utils/memory.py @@ -48,11 +48,20 @@ class LongTermMemory: self.thought_idx = 0 def add_event(self, event): + id = "" + t = "" + if "action" in event: + t = "action" + id = event["action"] + elif "observation" in event: + t = "observation" + id = event["observation"] doc = Document( text=json.dumps(event), doc_id=str(self.thought_idx), extra_info={ - "type": event["action"], + "type": t, + "id": id, "idx": self.thought_idx, }, ) diff --git a/agenthub/langchains_agent/utils/prompts.py b/agenthub/langchains_agent/utils/prompts.py index 3d30c221ce..69a7622156 100644 --- a/agenthub/langchains_agent/utils/prompts.py +++ b/agenthub/langchains_agent/utils/prompts.py @@ -1,8 +1,6 @@ -from typing import List, Dict, Type +from typing import List -from langchain_core.pydantic_v1 import BaseModel from langchain.prompts import PromptTemplate -from langchain_core.output_parsers import JsonOutputParser from opendevin import config @@ -13,35 +11,13 @@ if config.get_or_default("DEBUG", False): from . import json from opendevin.action import ( + action_from_dict, Action, - CmdRunAction, - CmdKillAction, - BrowseURLAction, - FileReadAction, - FileWriteAction, - AgentRecallAction, - AgentThinkAction, - AgentFinishAction, - AgentSummarizeAction, ) from opendevin.observation import ( CmdOutputObservation, ) - -ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = { - "run": CmdRunAction, - "kill": CmdKillAction, - "browse": BrowseURLAction, - "read": FileReadAction, - "write": FileWriteAction, - "recall": AgentRecallAction, - "think": AgentThinkAction, - "summarize": AgentSummarizeAction, - "finish": AgentFinishAction, -} -CLASS_TO_ACTION_TYPE: Dict[Type[Action], str] = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()} - ACTION_PROMPT = """ You're a thoughtful robot. Your main task is to {task}. Don't expand the scope of your task--just complete it as written. @@ -116,15 +92,6 @@ You can also use the same action and args from the source monologue. """ -class _ActionDict(BaseModel): - action: str - args: dict - - -class NewMonologue(BaseModel): - new_monologue: List[_ActionDict] - - def get_summarize_monologue_prompt(thoughts): prompt = PromptTemplate.from_template(MONOLOGUE_SUMMARY_PROMPT) return prompt.format(monologue=json.dumps({'old_monologue': thoughts}, indent=2)) @@ -137,13 +104,14 @@ def get_request_action_prompt( hint = '' if len(thoughts) > 0: latest_thought = thoughts[-1] - if latest_thought["action"] == 'think': - if latest_thought["args"]['thought'].startswith("OK so my task is"): - hint = "You're just getting started! What should you do first?" - else: - hint = "You've been thinking a lot lately. Maybe it's time to take action?" - elif latest_thought["action"] == 'error': - hint = "Looks like that last command failed. Maybe you need to fix it, or try something else." + if "action" in latest_thought: + if latest_thought["action"] == 'think': + if latest_thought["args"]['thought'].startswith("OK so my task is"): + hint = "You're just getting started! What should you do first?" + else: + hint = "You've been thinking a lot lately. Maybe it's time to take action?" + elif latest_thought["action"] == 'error': + hint = "Looks like that last command failed. Maybe you need to fix it, or try something else." bg_commands_message = "" if len(background_commands_obs) > 0: @@ -162,17 +130,15 @@ def get_request_action_prompt( ) def parse_action_response(response: str) -> Action: - parser = JsonOutputParser(pydantic_object=_ActionDict) - action_dict = parser.parse(response) + json_start = response.find("{") + json_end = response.rfind("}") + 1 + response = response[json_start:json_end] + action_dict = json.loads(response) if 'content' in action_dict: # The LLM gets confused here. Might as well be robust action_dict['contents'] = action_dict.pop('content') + return action_from_dict(action_dict) - action = ACTION_TYPE_TO_CLASS[action_dict["action"]](**action_dict["args"]) - return action - -def parse_summary_response(response: str) -> List[Action]: - parser = JsonOutputParser(pydantic_object=NewMonologue) - parsed = parser.parse(response) - #thoughts = [ACTION_TYPE_TO_CLASS[t['action']](**t['args']) for t in parsed['new_monologue']] +def parse_summary_response(response: str) -> List[dict]: + parsed = json.loads(response) return parsed['new_monologue'] diff --git a/agenthub/planner_agent/prompt.py b/agenthub/planner_agent/prompt.py index 9032e150f8..33eeb48ab3 100644 --- a/agenthub/planner_agent/prompt.py +++ b/agenthub/planner_agent/prompt.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Dict, Type from opendevin.controller.agent_controller import print_with_indent from opendevin.plan import Plan -from opendevin.action import Action +from opendevin.action import Action, action_from_dict from opendevin.observation import Observation from opendevin.action import ( @@ -136,15 +136,10 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]): latest_action: Action = NullAction() for action, observation in sub_history: if not isinstance(action, NullAction): - #if not isinstance(action, ModifyTaskAction) and not isinstance(action, AddTaskAction): - action_dict = action.to_dict() - action_dict["action"] = convert_action(action_dict["action"]) - history_dicts.append(action_dict) + history_dicts.append(action.to_dict()) latest_action = action if not isinstance(observation, NullObservation): - observation_dict = observation.to_dict() - observation_dict["observation"] = convert_observation(observation_dict["observation"]) - history_dicts.append(observation_dict) + history_dicts.append(observation.to_dict()) history_str = json.dumps(history_dicts, indent=2) hint = "" @@ -157,7 +152,7 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]): plan_status = "You're not currently working on any tasks. Your next action MUST be to mark a task as in_progress." hint = plan_status - latest_action_id = convert_action(latest_action.to_dict()["action"]) + latest_action_id = latest_action.to_dict()['action'] if current_task is not None: if latest_action_id == "": @@ -200,43 +195,6 @@ def parse_response(response: str) -> Action: if 'content' in action_dict: # The LLM gets confused here. Might as well be robust action_dict['contents'] = action_dict.pop('content') - - args_dict = action_dict.get("args", {}) - action = ACTION_TYPE_TO_CLASS[action_dict["action"]](**args_dict) + action = action_from_dict(action_dict) return action -def convert_action(action): - if action == "CmdRunAction": - action = "run" - elif action == "CmdKillAction": - action = "kill" - elif action == "BrowseURLAction": - action = "browse" - elif action == "FileReadAction": - action = "read" - elif action == "FileWriteAction": - action = "write" - elif action == "AgentFinishAction": - action = "finish" - elif action == "AgentRecallAction": - action = "recall" - elif action == "AgentThinkAction": - action = "think" - elif action == "AgentSummarizeAction": - action = "summarize" - elif action == "AddTaskAction": - action = "add_task" - elif action == "ModifyTaskAction": - action = "modify_task" - return action - -def convert_observation(observation): - if observation == "UserMessageObservation": - observation = "chat" - elif observation == "AgentMessageObservation": - observation = "chat" - elif observation == "CmdOutputObservation": - observation = "run" - elif observation == "FileReadObservation": - observation = "read" - return observation diff --git a/opendevin/action/__init__.py b/opendevin/action/__init__.py index 37ad3b8439..83fd9a406e 100644 --- a/opendevin/action/__init__.py +++ b/opendevin/action/__init__.py @@ -13,18 +13,22 @@ actions = ( FileWriteAction, AgentRecallAction, AgentThinkAction, - AgentFinishAction + AgentFinishAction, + AddTaskAction, + ModifyTaskAction, ) ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined] -def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action: - action_class = ACTION_TYPE_TO_CLASS.get(action) +def action_from_dict(action: dict) -> Action: + action = action.copy() + if "action" not in action: + raise KeyError(f"'action' key is not found in {action=}") + action_class = ACTION_TYPE_TO_CLASS.get(action["action"]) if action_class is None: - raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}") - return action_class(*args, **kwargs) - -CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()} + raise KeyError(f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}") + args = action.get("args", {}) + return action_class(**args) __all__ = [ "Action", diff --git a/opendevin/action/agent.py b/opendevin/action/agent.py index b0ea63d09c..74edab945c 100644 --- a/opendevin/action/agent.py +++ b/opendevin/action/agent.py @@ -25,7 +25,6 @@ class AgentRecallAction(ExecutableAction): @dataclass class AgentThinkAction(NotExecutableAction): thought: str - runnable: bool = False action: str = "think" def run(self, controller: "AgentController") -> "Observation": @@ -38,7 +37,6 @@ class AgentThinkAction(NotExecutableAction): @dataclass class AgentEchoAction(ExecutableAction): content: str - runnable: bool = True action: str = "echo" def run(self, controller: "AgentController") -> "Observation": @@ -60,7 +58,6 @@ class AgentSummarizeAction(NotExecutableAction): @dataclass class AgentFinishAction(NotExecutableAction): - runnable: bool = False action: str = "finish" def run(self, controller: "AgentController") -> "Observation": diff --git a/opendevin/action/base.py b/opendevin/action/base.py index 0d5d0765e1..bd47ff1562 100644 --- a/opendevin/action/base.py +++ b/opendevin/action/base.py @@ -26,8 +26,6 @@ class Action: def message(self) -> str: raise NotImplementedError - - @dataclass class ExecutableAction(Action): @property diff --git a/opendevin/observation.py b/opendevin/observation.py deleted file mode 100644 index 7cee626048..0000000000 --- a/opendevin/observation.py +++ /dev/null @@ -1,162 +0,0 @@ -import copy -from typing import List -from dataclasses import dataclass - - -@dataclass -class Observation: - """ - This data class represents an observation of the environment. - """ - - content: str - - def __str__(self) -> str: - return self.content - - def to_dict(self) -> dict: - """Converts the observation to a dictionary.""" - extras = copy.deepcopy(self.__dict__) - extras.pop("content", None) - observation = "observation" - if hasattr(self, "observation"): - observation = self.observation - return { - "observation": observation, - "content": self.content, - "extras": extras, - "message": self.message, - } - - @property - def message(self) -> str: - """Returns a message describing the observation.""" - return "" - - -@dataclass -class CmdOutputObservation(Observation): - """ - This data class represents the output of a command. - """ - - command_id: int - command: str - exit_code: int = 0 - observation : str = "run" - - @property - def error(self) -> bool: - return self.exit_code != 0 - - @property - def message(self) -> str: - return f'Command `{self.command}` executed with exit code {self.exit_code}.' - -@dataclass -class FileReadObservation(Observation): - """ - This data class represents the content of a file. - """ - - path: str - observation : str = "read" - - @property - def message(self) -> str: - return f"I read the file {self.path}." - -@dataclass -class FileWriteObservation(Observation): - """ - This data class represents a file write operation - """ - - path: str - observation : str = "write" - - @property - def message(self) -> str: - return f"I wrote to the file {self.path}." - -@dataclass -class BrowserOutputObservation(Observation): - """ - This data class represents the output of a browser. - """ - - url: str - status_code: int = 200 - error: bool = False - observation : str = "browse" - - @property - def message(self) -> str: - return "Visited " + self.url - - -@dataclass -class UserMessageObservation(Observation): - """ - This data class represents a message sent by the user. - """ - - role: str = "user" - observation : str = "message" - - @property - def message(self) -> str: - return "" - - -@dataclass -class AgentMessageObservation(Observation): - """ - This data class represents a message sent by the agent. - """ - - role: str = "assistant" - observation : str = "message" - - @property - def message(self) -> str: - return "" - - -@dataclass -class AgentRecallObservation(Observation): - """ - This data class represents a list of memories recalled by the agent. - """ - - memories: List[str] - role: str = "assistant" - observation : str = "recall" - - @property - def message(self) -> str: - return "The agent recalled memories." - - -@dataclass -class AgentErrorObservation(Observation): - """ - This data class represents an error encountered by the agent. - """ - observation : str = "error" - - @property - def message(self) -> str: - return "Oops. Something went wrong: " + self.content - -@dataclass -class NullObservation(Observation): - """ - This data class represents a null observation. - This is used when the produced action is NOT executable. - """ - observation : str = "null" - - @property - def message(self) -> str: - return "" diff --git a/opendevin/observation/__init__.py b/opendevin/observation/__init__.py new file mode 100644 index 0000000000..b8d965530e --- /dev/null +++ b/opendevin/observation/__init__.py @@ -0,0 +1,46 @@ +from .base import Observation, NullObservation +from .run import CmdOutputObservation +from .browse import BrowserOutputObservation +from .files import FileReadObservation, FileWriteObservation +from .message import UserMessageObservation, AgentMessageObservation +from .recall import AgentRecallObservation +from .error import AgentErrorObservation + +observations = ( + CmdOutputObservation, + BrowserOutputObservation, + FileReadObservation, + FileWriteObservation, + UserMessageObservation, + AgentMessageObservation, + AgentRecallObservation, + AgentErrorObservation, +) + +OBSERVATION_TYPE_TO_CLASS = {observation_class.observation:observation_class for observation_class in observations} # type: ignore[attr-defined] + +def observation_from_dict(observation: dict) -> Observation: + observation = observation.copy() + if "observation" not in observation: + raise KeyError(f"'observation' key is not found in {observation=}") + observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation["observation"]) + if observation_class is None: + raise KeyError(f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}") + observation.pop("observation") + observation.pop("message", None) + content = observation.pop("content", "") + extras = observation.pop("extras", {}) + return observation_class(content=content, **extras) + +__all__ = [ + "Observation", + "NullObservation", + "CmdOutputObservation", + "BrowserOutputObservation", + "FileReadObservation", + "FileWriteObservation", + "UserMessageObservation", + "AgentMessageObservation", + "AgentRecallObservation", + "AgentErrorObservation", +] diff --git a/opendevin/observation/base.py b/opendevin/observation/base.py new file mode 100644 index 0000000000..c44790fcdd --- /dev/null +++ b/opendevin/observation/base.py @@ -0,0 +1,43 @@ +import copy +from dataclasses import dataclass + +@dataclass +class Observation: + """ + This data class represents an observation of the environment. + """ + + content: str + + def __str__(self) -> str: + return self.content + + def to_dict(self) -> dict: + """Converts the observation to a dictionary.""" + extras = copy.deepcopy(self.__dict__) + content = extras.pop("content", "") + observation = extras.pop("observation", "") + return { + "observation": observation, + "content": content, + "extras": extras, + "message": self.message, + } + + @property + def message(self) -> str: + """Returns a message describing the observation.""" + return "" + + +@dataclass +class NullObservation(Observation): + """ + This data class represents a null observation. + This is used when the produced action is NOT executable. + """ + observation : str = "null" + + @property + def message(self) -> str: + return "" diff --git a/opendevin/observation/browse.py b/opendevin/observation/browse.py new file mode 100644 index 0000000000..37faed8352 --- /dev/null +++ b/opendevin/observation/browse.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from .base import Observation + +@dataclass +class BrowserOutputObservation(Observation): + """ + This data class represents the output of a browser. + """ + + url: str + status_code: int = 200 + error: bool = False + observation : str = "browse" + + @property + def message(self) -> str: + return "Visited " + self.url + + + diff --git a/opendevin/observation/error.py b/opendevin/observation/error.py new file mode 100644 index 0000000000..184758f8ce --- /dev/null +++ b/opendevin/observation/error.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +from .base import Observation + +@dataclass +class AgentErrorObservation(Observation): + """ + This data class represents an error encountered by the agent. + """ + observation : str = "error" + + @property + def message(self) -> str: + return "Oops. Something went wrong: " + self.content + + diff --git a/opendevin/observation/files.py b/opendevin/observation/files.py new file mode 100644 index 0000000000..1eb33e6c15 --- /dev/null +++ b/opendevin/observation/files.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +from .base import Observation + +@dataclass +class FileReadObservation(Observation): + """ + This data class represents the content of a file. + """ + + path: str + observation : str = "read" + + @property + def message(self) -> str: + return f"I read the file {self.path}." + +@dataclass +class FileWriteObservation(Observation): + """ + This data class represents a file write operation + """ + + path: str + observation : str = "write" + + @property + def message(self) -> str: + return f"I wrote to the file {self.path}." + + diff --git a/opendevin/observation/message.py b/opendevin/observation/message.py new file mode 100644 index 0000000000..bee6a36029 --- /dev/null +++ b/opendevin/observation/message.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +from .base import Observation + +@dataclass +class UserMessageObservation(Observation): + """ + This data class represents a message sent by the user. + """ + + role: str = "user" + observation : str = "message" + + @property + def message(self) -> str: + return "" + + +@dataclass +class AgentMessageObservation(Observation): + """ + This data class represents a message sent by the agent. + """ + + role: str = "assistant" + observation : str = "message" + + @property + def message(self) -> str: + return "" + + + diff --git a/opendevin/observation/recall.py b/opendevin/observation/recall.py new file mode 100644 index 0000000000..bab1d6def4 --- /dev/null +++ b/opendevin/observation/recall.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List + +from .base import Observation + +@dataclass +class AgentRecallObservation(Observation): + """ + This data class represents a list of memories recalled by the agent. + """ + + memories: List[str] + role: str = "assistant" + observation : str = "recall" + + @property + def message(self) -> str: + return "The agent recalled memories." + + + diff --git a/opendevin/observation/run.py b/opendevin/observation/run.py new file mode 100644 index 0000000000..9bfc489d4e --- /dev/null +++ b/opendevin/observation/run.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass + +from .base import Observation + +@dataclass +class CmdOutputObservation(Observation): + """ + This data class represents the output of a command. + """ + + command_id: int + command: str + exit_code: int = 0 + observation : str = "run" + + @property + def error(self) -> bool: + return self.exit_code != 0 + + @property + def message(self) -> str: + return f'Command `{self.command}` executed with exit code {self.exit_code}.' + + diff --git a/tests/test_action_serialization.py b/tests/test_action_serialization.py new file mode 100644 index 0000000000..32f92f67a0 --- /dev/null +++ b/tests/test_action_serialization.py @@ -0,0 +1,16 @@ +import pytest +from opendevin.action import action_from_dict, Action, AgentThinkAction + +def test_action_serialization_deserialization(): + original_action_dict = { + 'action': 'think', + 'args': {'thought': 'This is a test.'} + } + action_instance = action_from_dict(original_action_dict) + assert isinstance(action_instance, Action), 'The action instance should be an instance of Action.' + assert isinstance(action_instance, AgentThinkAction), 'The action instance should be an instance of AgentThinkAction.' + serialized_action_dict = action_instance.to_dict() + serialized_action_dict.pop('message') + assert serialized_action_dict == original_action_dict, 'The serialized action should match the original action dict.' + +# Additional tests for various action subclasses can be included here diff --git a/tests/test_observation_serialization.py b/tests/test_observation_serialization.py new file mode 100644 index 0000000000..d09bb047a7 --- /dev/null +++ b/tests/test_observation_serialization.py @@ -0,0 +1,17 @@ +import pytest +from opendevin.observation import observation_from_dict, Observation, CmdOutputObservation + +def test_observation_serialization_deserialization(): + original_observation_dict = { + 'observation': 'run', + 'extras': {'exit_code': 0, 'command': 'ls -l', 'command_id': 3}, + 'message': 'Command `ls -l` executed with exit code 0.', + 'content': 'foo.txt', + } + observation_instance = observation_from_dict(original_observation_dict) + assert isinstance(observation_instance, Observation), 'The observation instance should be an instance of Action.' + assert isinstance(observation_instance, CmdOutputObservation), 'The observation instance should be an instance of AgentThinkAction.' + serialized_observation_dict = observation_instance.to_dict() + assert serialized_observation_dict == original_observation_dict, 'The serialized observation should match the original observation dict.' + +# Additional tests for various observation subclasses can be included here