Dump trajectories with delegate history if configured (#4336)

This commit is contained in:
Boxuan Li 2024-10-13 17:30:04 -07:00 committed by GitHub
parent 343cc8710f
commit 7186224899
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 36 additions and 2 deletions

View File

@ -28,6 +28,9 @@ workspace_base = "./workspace"
# Enable saving and restoring the session when run from CLI
#enable_cli_session = false
# Path to store trajectories
#trajectories_path="./trajectories"
# File store path
#file_store_path = "/tmp/file_store"

View File

@ -28,6 +28,7 @@ class AppConfig:
runtime: The runtime environment.
file_store: The file store to use.
file_store_path: The path to the file store.
trajectories_path: The folder path to store trajectories.
workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path.
workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default.
workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace.
@ -53,6 +54,7 @@ class AppConfig:
runtime: str = 'eventstream'
file_store: str = 'memory'
file_store_path: str = '/tmp/file_store'
trajectories_path: str | None = None
# TODO: clean up workspace path after the removal of ServerRuntime
workspace_base: str = os.path.join(os.getcwd(), 'workspace')
workspace_mount_path: str | None = (

View File

@ -1,5 +1,7 @@
import asyncio
import hashlib
import json
import os
import sys
import uuid
from typing import Callable, Protocol, Type
@ -21,6 +23,7 @@ from openhands.events.action import MessageAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization.event import event_to_trajectory
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime import Runtime
@ -202,6 +205,17 @@ async def run_controller(
await controller.close()
state = controller.get_state()
# save trajectories if applicable
if config.trajectories_path is not None:
file_path = os.path.join(config.trajectories_path, sid + '.json')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = [
event_to_trajectory(event)
for event in state.history.get_events(include_delegates=True)
]
with open(file_path, 'w') as f:
json.dump(histories, f)
return state

View File

@ -5,6 +5,7 @@ from openhands.events.serialization.event import (
event_from_dict,
event_to_dict,
event_to_memory,
event_to_trajectory,
)
from openhands.events.serialization.observation import (
observation_from_dict,
@ -15,5 +16,6 @@ __all__ = [
'event_from_dict',
'event_to_dict',
'event_to_memory',
'event_to_trajectory',
'observation_from_dict',
]

View File

@ -11,11 +11,10 @@ from openhands.events.serialization.utils import remove_fields
TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
DELETE_FROM_MEMORY_EXTRAS = {
DELETE_FROM_TRAJECTORY_EXTRAS = {
'screenshot',
'dom_object',
'axtree_object',
'open_pages_urls',
'active_page_index',
'last_browser_action',
'last_browser_action_error',
@ -23,6 +22,8 @@ DELETE_FROM_MEMORY_EXTRAS = {
'extra_element_properties',
}
DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'}
def event_from_dict(data) -> 'Event':
evt: Event
@ -73,6 +74,13 @@ def event_to_dict(event: 'Event') -> dict:
return d
def event_to_trajectory(event: 'Event') -> dict:
d = event_to_dict(event)
if 'extras' in d:
remove_fields(d['extras'], DELETE_FROM_TRAJECTORY_EXTRAS)
return d
def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
d = event_to_dict(event)
d.pop('id', None)

View File

@ -6,6 +6,7 @@ from openhands.events.serialization import (
event_from_dict,
event_to_dict,
event_to_memory,
event_to_trajectory,
)
@ -20,12 +21,16 @@ def serialization_deserialization(
observation_instance, cls
), 'The observation instance should be an instance of CmdOutputObservation.'
serialized_observation_dict = event_to_dict(observation_instance)
serialized_observation_trajectory = event_to_trajectory(observation_instance)
serialized_observation_memory = event_to_memory(
observation_instance, max_message_chars
)
assert (
serialized_observation_dict == original_observation_dict
), 'The serialized observation should match the original observation dict.'
assert (
serialized_observation_trajectory == original_observation_dict
), 'The serialized observation trajectory should match the original observation dict.'
original_observation_dict.pop('message', None)
original_observation_dict.pop('id', None)
original_observation_dict.pop('timestamp', None)