diff --git a/config.template.toml b/config.template.toml index 7673744bba..41f14eb349 100644 --- a/config.template.toml +++ b/config.template.toml @@ -28,6 +28,9 @@ workspace_base = "./workspace" # Enable saving and restoring the session when run from CLI #enable_cli_session = false +# Path to store trajectories +#trajectories_path="./trajectories" + # File store path #file_store_path = "/tmp/file_store" diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index 043607c535..0ca622f7d8 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -28,6 +28,7 @@ class AppConfig: runtime: The runtime environment. file_store: The file store to use. file_store_path: The path to the file store. + trajectories_path: The folder path to store trajectories. workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path. workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default. workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace. @@ -53,6 +54,7 @@ class AppConfig: runtime: str = 'eventstream' file_store: str = 'memory' file_store_path: str = '/tmp/file_store' + trajectories_path: str | None = None # TODO: clean up workspace path after the removal of ServerRuntime workspace_base: str = os.path.join(os.getcwd(), 'workspace') workspace_mount_path: str | None = ( diff --git a/openhands/core/main.py b/openhands/core/main.py index b0702c943b..0ebcde8527 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -1,5 +1,7 @@ import asyncio import hashlib +import json +import os import sys import uuid from typing import Callable, Protocol, Type @@ -21,6 +23,7 @@ from openhands.events.action import MessageAction from openhands.events.action.action import Action from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation +from openhands.events.serialization.event import event_to_trajectory from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.runtime import Runtime @@ -202,6 +205,17 @@ async def run_controller( await controller.close() state = controller.get_state() + # save trajectories if applicable + if config.trajectories_path is not None: + file_path = os.path.join(config.trajectories_path, sid + '.json') + os.makedirs(os.path.dirname(file_path), exist_ok=True) + histories = [ + event_to_trajectory(event) + for event in state.history.get_events(include_delegates=True) + ] + with open(file_path, 'w') as f: + json.dump(histories, f) + return state diff --git a/openhands/events/serialization/__init__.py b/openhands/events/serialization/__init__.py index 67b9d30d08..f36d08d86c 100644 --- a/openhands/events/serialization/__init__.py +++ b/openhands/events/serialization/__init__.py @@ -5,6 +5,7 @@ from openhands.events.serialization.event import ( event_from_dict, event_to_dict, event_to_memory, + event_to_trajectory, ) from openhands.events.serialization.observation import ( observation_from_dict, @@ -15,5 +16,6 @@ __all__ = [ 'event_from_dict', 'event_to_dict', 'event_to_memory', + 'event_to_trajectory', 'observation_from_dict', ] diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index bf5fb72cee..36883ae436 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -11,11 +11,10 @@ from openhands.events.serialization.utils import remove_fields TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation'] UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause'] -DELETE_FROM_MEMORY_EXTRAS = { +DELETE_FROM_TRAJECTORY_EXTRAS = { 'screenshot', 'dom_object', 'axtree_object', - 'open_pages_urls', 'active_page_index', 'last_browser_action', 'last_browser_action_error', @@ -23,6 +22,8 @@ DELETE_FROM_MEMORY_EXTRAS = { 'extra_element_properties', } +DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'} + def event_from_dict(data) -> 'Event': evt: Event @@ -73,6 +74,13 @@ def event_to_dict(event: 'Event') -> dict: return d +def event_to_trajectory(event: 'Event') -> dict: + d = event_to_dict(event) + if 'extras' in d: + remove_fields(d['extras'], DELETE_FROM_TRAJECTORY_EXTRAS) + return d + + def event_to_memory(event: 'Event', max_message_chars: int) -> dict: d = event_to_dict(event) d.pop('id', None) diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py index 212b494667..252989517a 100644 --- a/tests/unit/test_observation_serialization.py +++ b/tests/unit/test_observation_serialization.py @@ -6,6 +6,7 @@ from openhands.events.serialization import ( event_from_dict, event_to_dict, event_to_memory, + event_to_trajectory, ) @@ -20,12 +21,16 @@ def serialization_deserialization( observation_instance, cls ), 'The observation instance should be an instance of CmdOutputObservation.' serialized_observation_dict = event_to_dict(observation_instance) + serialized_observation_trajectory = event_to_trajectory(observation_instance) serialized_observation_memory = event_to_memory( observation_instance, max_message_chars ) assert ( serialized_observation_dict == original_observation_dict ), 'The serialized observation should match the original observation dict.' + assert ( + serialized_observation_trajectory == original_observation_dict + ), 'The serialized observation trajectory should match the original observation dict.' original_observation_dict.pop('message', None) original_observation_dict.pop('id', None) original_observation_dict.pop('timestamp', None)