Property 'message' sent to llm (#1198)

* Add action with content, no message, to history

* fix to_memory(), add it to serialization tests

* Actions without 'message' in completion too
This commit is contained in:
Engel Nyst 2024-04-18 12:51:07 +02:00 committed by GitHub
parent 51149780ac
commit caabfab7e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 5 deletions

View File

@ -188,7 +188,7 @@ class MonologueAgent(Agent):
output_type = ActionType.BROWSE
else:
action = AgentThinkAction(thought=thought)
self._add_event(action.to_dict())
self._add_event(action.to_memory())
self._initialized = True
def step(self, state: State) -> Action:
@ -203,7 +203,7 @@ class MonologueAgent(Agent):
"""
self._initialize(state.plan.main_goal)
for prev_action, obs in state.updated_info:
self._add_event(prev_action.to_dict())
self._add_event(prev_action.to_memory())
self._add_event(obs.to_dict())
state.updated_info = []

View File

@ -149,7 +149,7 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]) -> str:
latest_action: Action = NullAction()
for action, observation in sub_history:
if not isinstance(action, NullAction):
history_dicts.append(action.to_dict())
history_dicts.append(action.to_memory())
latest_action = action
if not isinstance(observation, NullObservation):
observation_dict = observation.to_dict()

View File

@ -12,13 +12,18 @@ class Action:
async def run(self, controller: 'AgentController') -> 'Observation':
raise NotImplementedError
def to_dict(self):
def to_memory(self):
d = asdict(self)
try:
v = d.pop('action')
except KeyError:
raise NotImplementedError(f'{self=} does not have action attribute set')
return {'action': v, 'args': d, 'message': self.message}
return {'action': v, 'args': d}
def to_dict(self):
d = self.to_memory()
d['message'] = self.message
return d
@property
def executable(self) -> bool:

View File

@ -21,8 +21,10 @@ def serialization_deserialization(original_action_dict, cls):
assert isinstance(
action_instance, cls), f'The action instance should be an instance of {cls.__name__}.'
serialized_action_dict = action_instance.to_dict()
serialized_action_memory = action_instance.to_memory()
serialized_action_dict.pop('message')
assert serialized_action_dict == original_action_dict, 'The serialized action should match the original action dict.'
assert serialized_action_memory == original_action_dict, 'The serialized action in memory should match the original action dict.'
def test_agent_think_action_serialization_deserialization():