mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Dump trajectories with delegate history if configured (#4336)
This commit is contained in:
parent
343cc8710f
commit
7186224899
@ -28,6 +28,9 @@ workspace_base = "./workspace"
|
||||
# Enable saving and restoring the session when run from CLI
|
||||
#enable_cli_session = false
|
||||
|
||||
# Path to store trajectories
|
||||
#trajectories_path="./trajectories"
|
||||
|
||||
# File store path
|
||||
#file_store_path = "/tmp/file_store"
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ class AppConfig:
|
||||
runtime: The runtime environment.
|
||||
file_store: The file store to use.
|
||||
file_store_path: The path to the file store.
|
||||
trajectories_path: The folder path to store trajectories.
|
||||
workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path.
|
||||
workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default.
|
||||
workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace.
|
||||
@ -53,6 +54,7 @@ class AppConfig:
|
||||
runtime: str = 'eventstream'
|
||||
file_store: str = 'memory'
|
||||
file_store_path: str = '/tmp/file_store'
|
||||
trajectories_path: str | None = None
|
||||
# TODO: clean up workspace path after the removal of ServerRuntime
|
||||
workspace_base: str = os.path.join(os.getcwd(), 'workspace')
|
||||
workspace_mount_path: str | None = (
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Callable, Protocol, Type
|
||||
@ -21,6 +23,7 @@ from openhands.events.action import MessageAction
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime import Runtime
|
||||
@ -202,6 +205,17 @@ async def run_controller(
|
||||
await controller.close()
|
||||
state = controller.get_state()
|
||||
|
||||
# save trajectories if applicable
|
||||
if config.trajectories_path is not None:
|
||||
file_path = os.path.join(config.trajectories_path, sid + '.json')
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
histories = [
|
||||
event_to_trajectory(event)
|
||||
for event in state.history.get_events(include_delegates=True)
|
||||
]
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(histories, f)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from openhands.events.serialization.event import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_memory,
|
||||
event_to_trajectory,
|
||||
)
|
||||
from openhands.events.serialization.observation import (
|
||||
observation_from_dict,
|
||||
@ -15,5 +16,6 @@ __all__ = [
|
||||
'event_from_dict',
|
||||
'event_to_dict',
|
||||
'event_to_memory',
|
||||
'event_to_trajectory',
|
||||
'observation_from_dict',
|
||||
]
|
||||
|
||||
@ -11,11 +11,10 @@ from openhands.events.serialization.utils import remove_fields
|
||||
TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
|
||||
UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
|
||||
|
||||
DELETE_FROM_MEMORY_EXTRAS = {
|
||||
DELETE_FROM_TRAJECTORY_EXTRAS = {
|
||||
'screenshot',
|
||||
'dom_object',
|
||||
'axtree_object',
|
||||
'open_pages_urls',
|
||||
'active_page_index',
|
||||
'last_browser_action',
|
||||
'last_browser_action_error',
|
||||
@ -23,6 +22,8 @@ DELETE_FROM_MEMORY_EXTRAS = {
|
||||
'extra_element_properties',
|
||||
}
|
||||
|
||||
DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'}
|
||||
|
||||
|
||||
def event_from_dict(data) -> 'Event':
|
||||
evt: Event
|
||||
@ -73,6 +74,13 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
return d
|
||||
|
||||
|
||||
def event_to_trajectory(event: 'Event') -> dict:
|
||||
d = event_to_dict(event)
|
||||
if 'extras' in d:
|
||||
remove_fields(d['extras'], DELETE_FROM_TRAJECTORY_EXTRAS)
|
||||
return d
|
||||
|
||||
|
||||
def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
|
||||
d = event_to_dict(event)
|
||||
d.pop('id', None)
|
||||
|
||||
@ -6,6 +6,7 @@ from openhands.events.serialization import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_memory,
|
||||
event_to_trajectory,
|
||||
)
|
||||
|
||||
|
||||
@ -20,12 +21,16 @@ def serialization_deserialization(
|
||||
observation_instance, cls
|
||||
), 'The observation instance should be an instance of CmdOutputObservation.'
|
||||
serialized_observation_dict = event_to_dict(observation_instance)
|
||||
serialized_observation_trajectory = event_to_trajectory(observation_instance)
|
||||
serialized_observation_memory = event_to_memory(
|
||||
observation_instance, max_message_chars
|
||||
)
|
||||
assert (
|
||||
serialized_observation_dict == original_observation_dict
|
||||
), 'The serialized observation should match the original observation dict.'
|
||||
assert (
|
||||
serialized_observation_trajectory == original_observation_dict
|
||||
), 'The serialized observation trajectory should match the original observation dict.'
|
||||
original_observation_dict.pop('message', None)
|
||||
original_observation_dict.pop('id', None)
|
||||
original_observation_dict.pop('timestamp', None)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user