Config to save screenshots in trajectory (#7284)

This commit is contained in:
Boxuan Li 2025-03-21 22:43:01 -07:00 committed by GitHub
parent 0fec237ead
commit d343e4ed9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 25 additions and 8 deletions

View File

@ -42,6 +42,10 @@ workspace_base = "./workspace"
# If it's a folder, the session id will be used as the file name
#save_trajectory_path="./trajectories"
# Whether to save screenshots in the trajectory
# Screenshots are stored in encoded form and can make trajectory JSON files very large
#save_screenshots_in_trajectory = false
# Path to replay a trajectory, must be a file path
# If provided, trajectory will be loaded and replayed before the
# agent responds to any user instruction

View File

@ -897,10 +897,13 @@ class AgentController:
# Always load from the event stream to avoid losing history
self._init_history()
def get_trajectory(self, include_screenshots: bool = False) -> list[dict]:
    """Return the controller's history as a list of trajectory dicts.

    Args:
        include_screenshots: Forwarded unchanged to ``event_to_trajectory``
            for every event. Defaults to ``False``.

    Returns:
        One dict per event in ``self.state.history``, in history order.

    Raises:
        AssertionError: If the controller has not been closed yet.
    """
    # state history could be partially hidden/truncated before controller is closed
    assert self._closed
    return [
        event_to_trajectory(event, include_screenshots)
        for event in self.state.history
    ]
def _init_history(self) -> None:
"""Initializes the agent's history from the event stream.

View File

@ -29,6 +29,7 @@ class AppConfig(BaseModel):
file_store: Type of file store to use.
file_store_path: Path to the file store.
save_trajectory_path: Either a folder path to store trajectories with auto-generated filenames, or a designated trajectory file path.
save_screenshots_in_trajectory: Whether to save screenshots in trajectory (in encoded image format).
replay_trajectory_path: Path to load trajectory and replay. If provided, trajectory would be replayed first before user's instruction.
workspace_base: Base path for the workspace. Defaults to `./workspace` as absolute path.
workspace_mount_path: Path to mount the workspace. Defaults to `workspace_base`.
@ -58,6 +59,7 @@ class AppConfig(BaseModel):
file_store: str = Field(default='local')
file_store_path: str = Field(default='/tmp/openhands_file_store')
save_trajectory_path: str | None = Field(default=None)
save_screenshots_in_trajectory: bool = Field(default=False)
replay_trajectory_path: str | None = Field(default=None)
workspace_base: str | None = Field(default=None)
workspace_mount_path: str | None = Field(default=None)

View File

@ -210,9 +210,9 @@ async def run_controller(
else:
file_path = config.save_trajectory_path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = controller.get_trajectory()
histories = controller.get_trajectory(config.save_screenshots_in_trajectory)
with open(file_path, 'w') as f:
json.dump(histories, f)
json.dump(histories, f, indent=4)
return state

View File

@ -5,7 +5,6 @@ from enum import Enum
from pydantic import BaseModel
from openhands.events import Event, EventSource
from openhands.events.observation.observation import Observation
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.serialization.utils import remove_fields
@ -34,7 +33,6 @@ UNDERSCORE_KEYS = [
]
DELETE_FROM_TRAJECTORY_EXTRAS = {
'screenshot',
'dom_object',
'axtree_object',
'active_page_index',
@ -44,6 +42,11 @@ DELETE_FROM_TRAJECTORY_EXTRAS = {
'extra_element_properties',
}
# Key set used by event_to_trajectory when screenshots are NOT to be kept:
# everything in DELETE_FROM_TRAJECTORY_EXTRAS plus 'screenshot' and 'set_of_marks'.
DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS = DELETE_FROM_TRAJECTORY_EXTRAS.union(
    ('screenshot', 'set_of_marks')
)
def event_from_dict(data) -> 'Event':
evt: Event
@ -133,10 +136,15 @@ def event_to_dict(event: 'Event') -> dict:
return d
def event_to_trajectory(event: 'Event', include_screenshots: bool = False) -> dict:
    """Convert an event to its dict form for storage in a trajectory.

    The event is first serialized with ``event_to_dict``. If the result has an
    ``'extras'`` entry, that entry is passed to ``remove_fields`` together with
    one of two key sets:

    * ``DELETE_FROM_TRAJECTORY_EXTRAS`` when ``include_screenshots`` is true;
    * ``DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS`` otherwise (the default).

    Args:
        event: The event to convert.
        include_screenshots: Selects which key set is handed to
            ``remove_fields``. Defaults to ``False``.

    Returns:
        The dict produced by ``event_to_dict``, after its ``'extras'`` entry
        (if any) has been processed by ``remove_fields``.
    """
    d = event_to_dict(event)
    if 'extras' in d:
        remove_fields(
            d['extras'],
            DELETE_FROM_TRAJECTORY_EXTRAS
            if include_screenshots
            else DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS,
        )
    return d