refactor state management (#258)

* refactor state management

* rm import

* move task into state

* revert change

* revert a few files
This commit is contained in:
Robert Brennan
2024-03-28 15:23:47 -04:00
committed by GitHub
parent d993162801
commit 94120f2b5d
5 changed files with 32 additions and 35 deletions

View File

@@ -67,16 +67,15 @@ class CodeActAgent(Agent):
"""
super().__init__(llm)
self.messages: List[Mapping[str, str]] = []
self.instruction: str = ""
def step(self, state: State) -> Action:
if len(self.messages) == 0:
assert self.instruction, "Expecting instruction to be set"
assert state.task, "Expecting instruction to be set"
self.messages = [
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "user", "content": self.instruction},
{"role": "user", "content": state.task},
]
print(colored("===USER:===\n" + self.instruction, "green"))
print(colored("===USER:===\n" + state.task, "green"))
updated_info = state.updated_info
if updated_info:
for prev_action, obs in updated_info:

View File

@@ -1,9 +1,8 @@
from typing import List
from opendevin.llm.llm import LLM
from opendevin.agent import Agent
from opendevin.state import State
from opendevin.action import Action
from opendevin.llm.llm import LLM
import agenthub.langchains_agent.utils.prompts as prompts
from agenthub.langchains_agent.utils.monologue import Monologue
from agenthub.langchains_agent.utils.memory import LongTermMemory
@@ -83,18 +82,18 @@ class LangchainsAgent(Agent):
if self.monologue.get_total_length() > MAX_MONOLOGUE_LENGTH:
self.monologue.condense(self.llm)
def _initialize(self):
def _initialize(self, task):
if self._initialized:
return
if self.instruction is None or self.instruction == "":
if task is None or task == "":
raise ValueError("Instruction must be provided")
self.monologue = Monologue()
self.memory = LongTermMemory()
next_is_output = False
for thought in INITIAL_THOUGHTS:
thought = thought.replace("$TASK", self.instruction)
thought = thought.replace("$TASK", task)
if next_is_output:
d = {"action": "output", "args": {"output": thought}}
next_is_output = False
@@ -120,7 +119,7 @@ class LangchainsAgent(Agent):
self._initialized = True
def step(self, state: State) -> Action:
self._initialize()
self._initialize(state.task)
# TODO: make langchains agent use Action & Observation
    # completely from the ground up
@@ -164,7 +163,7 @@ class LangchainsAgent(Agent):
state.updated_info = []
prompt = prompts.get_request_action_prompt(
self.instruction,
state.task,
self.monologue.get_thoughts(),
state.background_commands_obs,
)

View File

@@ -12,9 +12,6 @@ class Agent(ABC):
executing a specific instruction and allowing human interaction with the
agent during execution.
It tracks the execution status and maintains a history of interactions.
:param instruction: The instruction for the agent to execute.
:param model_name: The litellm name of the model to use for the agent.
"""
_registry: Dict[str, Type["Agent"]] = {}
@@ -23,7 +20,6 @@ class Agent(ABC):
self,
llm: LLM,
):
self.instruction = ""
self.llm = llm
self._complete = False
@@ -64,7 +60,6 @@ class Agent(ABC):
to prepare the agent for restarting the instruction or cleaning up before destruction.
"""
self.instruction = ""
self._complete = False
@classmethod

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import List, Callable, Tuple
from typing import List, Callable
import traceback
from opendevin.state import State
@@ -36,27 +36,26 @@ class AgentController:
self.workdir = workdir
self.command_manager = CommandManager(workdir)
self.callbacks = callbacks
self.state_updated_info: List[Tuple[Action, Observation]] = []
def get_current_state(self) -> State:
# update observations & actions
state = State(
background_commands_obs=self.command_manager.get_background_obs(),
updated_info=self.state_updated_info,
)
self.state_updated_info = []
return state
def update_state_for_step(self, i):
self.state.iteration = i
self.state.background_commands_obs = self.command_manager.get_background_obs()
def update_state_after_step(self):
self.state.updated_info = []
def add_history(self, action: Action, observation: Observation):
if not isinstance(action, Action):
raise ValueError("action must be an instance of Action")
if not isinstance(observation, Observation):
raise ValueError("observation must be an instance of Observation")
self.state_updated_info.append((action, observation))
self.state.history.append((action, observation))
self.state.updated_info.append((action, observation))
async def start_loop(self, task_instruction: str):
async def start_loop(self, task: str):
finished = False
self.agent.instruction = task_instruction
self.state = State(task)
for i in range(self.max_iterations):
try:
finished = await self.step(i)
@@ -78,16 +77,19 @@ class AgentController:
await self._run_callbacks(obs)
print_with_indent("\nBACKGROUND LOG:\n%s" % obs)
state: State = self.get_current_state()
self.update_state_for_step(i)
action: Action = NullAction()
observation: Observation = NullObservation("")
try:
action = self.agent.step(state)
action = self.agent.step(self.state)
if action is None:
raise ValueError("Agent must return an action")
print_with_indent("\nACTION:\n%s" % action)
except Exception as e:
observation = AgentErrorObservation(str(e))
print_with_indent("\nAGENT ERROR:\n%s" % observation)
traceback.print_exc()
self.update_state_after_step()
await self._run_callbacks(action)

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Tuple
from opendevin.action import (
@@ -9,8 +9,10 @@ from opendevin.observation import (
CmdOutputObservation,
)
@dataclass
class State:
background_commands_obs: List[CmdOutputObservation]
updated_info: List[Tuple[Action, Observation]]
task: str
iteration: int = 0
background_commands_obs: List[CmdOutputObservation] = field(default_factory=list)
history: List[Tuple[Action, Observation]] = field(default_factory=list)
updated_info: List[Tuple[Action, Observation]] = field(default_factory=list)