(feat) Add trajectory replay for headless mode (#6215)

This commit is contained in:
Boxuan Li 2025-01-17 21:48:22 -08:00 committed by GitHub
parent b4d20e3e18
commit 4383be1ab4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 188 additions and 42 deletions

View File

@ -39,6 +39,11 @@ workspace_base = "./workspace"
# If it's a folder, the session id will be used as the file name
#save_trajectory_path="./trajectories"
# Path to replay a trajectory, must be a file path
# If provided, trajectory will be loaded and replayed before the
# agent responds to any user instruction
#replay_trajectory_path = ""
# File store path
#file_store_path = "/tmp/file_store"

View File

@ -55,6 +55,11 @@ The core configuration options are defined in the `[core]` section of the `confi
- Default: `"./trajectories"`
- Description: Path to store trajectories (can be a folder or a file). If it's a folder, the trajectories will be saved in a file named with the session id name and .json extension, in that folder.
- `replay_trajectory_path`
- Type: `str`
- Default: `""`
- Description: Path to load a trajectory and replay. If given, must be a path to the trajectory file in JSON format. The actions in the trajectory file would be replayed first before any user instruction is executed.
### File Store
- `file_store_path`
- Type: `str`

View File

@ -12,6 +12,7 @@ from litellm.exceptions import (
)
from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State, TrafficControlState
from openhands.controller.stuck import StuckDetector
from openhands.core.config import AgentConfig, LLMConfig
@ -90,6 +91,7 @@ class AgentController:
is_delegate: bool = False,
headless_mode: bool = True,
status_callback: Callable | None = None,
replay_events: list[Event] | None = None,
):
"""Initializes a new instance of the AgentController class.
@ -108,6 +110,7 @@ class AgentController:
is_delegate: Whether this controller is a delegate.
headless_mode: Whether the agent is run in headless mode.
status_callback: Optional callback function to handle status updates.
replay_events: A list of logs to replay.
"""
self.id = sid
self.agent = agent
@ -139,6 +142,9 @@ class AgentController:
self._stuck_detector = StuckDetector(self.state)
self.status_callback = status_callback
# replay-related
self._replay_manager = ReplayManager(replay_events)
async def close(self) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
@ -234,6 +240,11 @@ class AgentController:
await self._react_to_exception(reported)
def should_step(self, event: Event) -> bool:
"""
Whether the agent should take a step based on an event. In general,
the agent should take a step if it receives a message from the user,
or observes something in the environment (after acting).
"""
# it might be the delegate's day in the sun
if self.delegate is not None:
return False
@ -641,42 +652,50 @@ class AgentController:
self.update_state_before_step()
action: Action = NullAction()
try:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
except (
LLMMalformedActionError,
LLMNoActionError,
LLMResponseError,
FunctionCallValidationError,
FunctionCallNotExistsError,
) as e:
self.event_stream.add_event(
ErrorObservation(
content=str(e),
),
EventSource.AGENT,
)
return
except (ContextWindowExceededError, BadRequestError) as e:
# FIXME: this is a hack until a litellm fix is confirmed
# Check if this is a nested context window error
error_str = str(e).lower()
if (
'contextwindowexceedederror' in error_str
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
if self._replay_manager.should_replay():
# in replay mode, we don't let the agent to proceed
# instead, we replay the action from the replay trajectory
action = self._replay_manager.step()
else:
try:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
except (
LLMMalformedActionError,
LLMNoActionError,
LLMResponseError,
FunctionCallValidationError,
FunctionCallNotExistsError,
) as e:
self.event_stream.add_event(
ErrorObservation(
content=str(e),
),
EventSource.AGENT,
)
return
raise
except (ContextWindowExceededError, BadRequestError) as e:
# FIXME: this is a hack until a litellm fix is confirmed
# Check if this is a nested context window error
error_str = str(e).lower()
if (
'contextwindowexceedederror' in error_str
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(
self.state.history
)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
return
raise
if action.runnable:
if self.state.confirmation_mode and (

View File

@ -0,0 +1,52 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.event import Event, EventSource
class ReplayManager:
"""ReplayManager manages the lifecycle of a replay session of a given trajectory.
Replay manager keeps track of a list of events, replays actions, and ignore
messages and observations. It could lead to unexpected or even errorneous
results if any action is non-deterministic, or if the initial state before
the replay session is different from the initial state of the trajectory.
"""
def __init__(self, replay_events: list[Event] | None):
if replay_events:
logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
self.replay_events = replay_events
self.replay_mode = bool(replay_events)
self.replay_index = 0
def _replayable(self) -> bool:
return (
self.replay_events is not None
and self.replay_index < len(self.replay_events)
and isinstance(self.replay_events[self.replay_index], Action)
and self.replay_events[self.replay_index].source != EventSource.USER
)
def should_replay(self) -> bool:
"""
Whether the controller is in trajectory replay mode, and the replay
hasn't finished. Note: after the replay is finished, the user and
the agent could continue to message/act.
This method also moves "replay_index" to the next action, if applicable.
"""
if not self.replay_mode:
return False
assert self.replay_events is not None
while self.replay_index < len(self.replay_events) and not self._replayable():
self.replay_index += 1
return self._replayable()
def step(self) -> Action:
assert self.replay_events is not None
event = self.replay_events[self.replay_index]
assert isinstance(event, Action)
self.replay_index += 1
return event

View File

@ -28,6 +28,7 @@ class AppConfig(BaseModel):
file_store: Type of file store to use.
file_store_path: Path to the file store.
save_trajectory_path: Either a folder path to store trajectories with auto-generated filenames, or a designated trajectory file path.
replay_trajectory_path: Path to load trajectory and replay. If provided, trajectory would be replayed first before user's instruction.
workspace_base: Base path for the workspace. Defaults to `./workspace` as absolute path.
workspace_mount_path: Path to mount the workspace. Defaults to `workspace_base`.
workspace_mount_path_in_sandbox: Path to mount the workspace in sandbox. Defaults to `/workspace`.
@ -55,6 +56,7 @@ class AppConfig(BaseModel):
file_store: str = Field(default='local')
file_store_path: str = Field(default='/tmp/openhands_file_store')
save_trajectory_path: str | None = Field(default=None)
replay_trajectory_path: str | None = Field(default=None)
workspace_base: str | None = Field(default=None)
workspace_mount_path: str | None = Field(default=None)
workspace_mount_path_in_sandbox: str = Field(default='/workspace')

View File

@ -2,6 +2,7 @@ import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
@ -22,10 +23,11 @@ from openhands.core.setup import (
generate_sid,
)
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import MessageAction
from openhands.events.action import MessageAction, NullAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime
@ -101,7 +103,17 @@ async def run_controller(
if agent is None:
agent = create_agent(runtime, config)
controller, initial_state = create_controller(agent, runtime, config)
replay_events: list[Event] | None = None
if config.replay_trajectory_path:
logger.info('Trajectory replay is enabled')
assert isinstance(initial_user_action, NullAction)
replay_events, initial_user_action = load_replay_log(
config.replay_trajectory_path
)
controller, initial_state = create_controller(
agent, runtime, config, replay_events=replay_events
)
assert isinstance(
initial_user_action, Action
@ -194,21 +206,64 @@ def auto_continue_response(
return message
def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
"""
Load trajectory from given path, serialize it to a list of events, and return
two things:
1) A list of events except the first action
2) First action (user message, a.k.a. initial task)
"""
try:
path = Path(trajectory_path).resolve()
if not path.exists():
raise ValueError(f'Trajectory file not found: {path}')
if not path.is_file():
raise ValueError(f'Trajectory path is a directory, not a file: {path}')
with open(path, 'r', encoding='utf-8') as file:
data = json.load(file)
if not isinstance(data, list):
raise ValueError(
f'Expected a list in {path}, got {type(data).__name__}'
)
events = []
for item in data:
event = event_from_dict(item)
# cannot add an event with _id to event stream
event._id = None # type: ignore[attr-defined]
events.append(event)
assert isinstance(events[0], MessageAction)
return events[1:], events[0]
except json.JSONDecodeError as e:
raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}')
if __name__ == '__main__':
args = parse_arguments()
config = setup_config_from_args(args)
# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')
initial_user_action: MessageAction = MessageAction(content=task_str)
config = setup_config_from_args(args)
# Set session name
session_name = args.name

View File

@ -11,6 +11,7 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
@ -78,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
def create_controller(
agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True
agent: Agent,
runtime: Runtime,
config: AppConfig,
headless_mode: bool = True,
replay_events: list[Event] | None = None,
) -> Tuple[AgentController, State | None]:
event_stream = runtime.event_stream
initial_state = None
@ -101,6 +106,7 @@ def create_controller(
initial_state=initial_state,
headless_mode=headless_mode,
confirmation_mode=config.security.confirmation_mode,
replay_events=replay_events,
)
return (controller, initial_state)

View File

@ -24,6 +24,8 @@ class FileReadSource(str, Enum):
@dataclass
class Event:
INVALID_ID = -1
@property
def message(self) -> str | None:
if hasattr(self, '_message'):
@ -34,7 +36,7 @@ class Event:
def id(self) -> int:
if hasattr(self, '_id'):
return self._id # type: ignore[attr-defined]
return -1
return Event.INVALID_ID
@property
def timestamp(self):

View File

@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation):
url: str
trigger_by_action: str
screenshot: str = field(repr=False) # don't show in repr
screenshot: str = field(repr=False, default='') # don't show in repr
error: bool = False
observation: str = ObservationType.BROWSE
# do not include in the memory