diff --git a/config.template.toml b/config.template.toml index ccb7b11597..7f256898b3 100644 --- a/config.template.toml +++ b/config.template.toml @@ -39,6 +39,11 @@ workspace_base = "./workspace" # If it's a folder, the session id will be used as the file name #save_trajectory_path="./trajectories" +# Path to replay a trajectory, must be a file path +# If provided, trajectory will be loaded and replayed before the +# agent responds to any user instruction +#replay_trajectory_path = "" + # File store path #file_store_path = "/tmp/file_store" diff --git a/docs/modules/usage/configuration-options.md b/docs/modules/usage/configuration-options.md index a3c11de52e..ff0aa5674c 100644 --- a/docs/modules/usage/configuration-options.md +++ b/docs/modules/usage/configuration-options.md @@ -55,6 +55,11 @@ The core configuration options are defined in the `[core]` section of the `confi - Default: `"./trajectories"` - Description: Path to store trajectories (can be a folder or a file). If it's a folder, the trajectories will be saved in a file named with the session id name and .json extension, in that folder. +- `replay_trajectory_path` + - Type: `str` + - Default: `""` + - Description: Path to load a trajectory and replay. If given, must be a path to the trajectory file in JSON format. The actions in the trajectory file would be replayed first before any user instruction is executed. 
+ ### File Store - `file_store_path` - Type: `str` diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 521968c88d..87bdd08173 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -12,6 +12,7 @@ from litellm.exceptions import ( ) from openhands.controller.agent import Agent +from openhands.controller.replay import ReplayManager from openhands.controller.state.state import State, TrafficControlState from openhands.controller.stuck import StuckDetector from openhands.core.config import AgentConfig, LLMConfig @@ -90,6 +91,7 @@ class AgentController: is_delegate: bool = False, headless_mode: bool = True, status_callback: Callable | None = None, + replay_events: list[Event] | None = None, ): """Initializes a new instance of the AgentController class. @@ -108,6 +110,7 @@ class AgentController: is_delegate: Whether this controller is a delegate. headless_mode: Whether the agent is run in headless mode. status_callback: Optional callback function to handle status updates. + replay_events: A list of logs to replay. """ self.id = sid self.agent = agent @@ -139,6 +142,9 @@ class AgentController: self._stuck_detector = StuckDetector(self.state) self.status_callback = status_callback + # replay-related + self._replay_manager = ReplayManager(replay_events) + async def close(self) -> None: """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. @@ -234,6 +240,11 @@ class AgentController: await self._react_to_exception(reported) def should_step(self, event: Event) -> bool: + """ + Whether the agent should take a step based on an event. In general, + the agent should take a step if it receives a message from the user, + or observes something in the environment (after acting). 
+ """ # it might be the delegate's day in the sun if self.delegate is not None: return False @@ -641,42 +652,50 @@ class AgentController: self.update_state_before_step() action: Action = NullAction() - try: - action = self.agent.step(self.state) - if action is None: - raise LLMNoActionError('No action was returned') - except ( - LLMMalformedActionError, - LLMNoActionError, - LLMResponseError, - FunctionCallValidationError, - FunctionCallNotExistsError, - ) as e: - self.event_stream.add_event( - ErrorObservation( - content=str(e), - ), - EventSource.AGENT, - ) - return - except (ContextWindowExceededError, BadRequestError) as e: - # FIXME: this is a hack until a litellm fix is confirmed - # Check if this is a nested context window error - error_str = str(e).lower() - if ( - 'contextwindowexceedederror' in error_str - or 'prompt is too long' in error_str - or isinstance(e, ContextWindowExceededError) - ): - # When context window is exceeded, keep roughly half of agent interactions - self.state.history = self._apply_conversation_window(self.state.history) - # Save the ID of the first event in our truncated history for future reloading - if self.state.history: - self.state.start_id = self.state.history[0].id - # Don't add error event - let the agent retry with reduced context + if self._replay_manager.should_replay(): + # in replay mode, we don't let the agent to proceed + # instead, we replay the action from the replay trajectory + action = self._replay_manager.step() + else: + try: + action = self.agent.step(self.state) + if action is None: + raise LLMNoActionError('No action was returned') + except ( + LLMMalformedActionError, + LLMNoActionError, + LLMResponseError, + FunctionCallValidationError, + FunctionCallNotExistsError, + ) as e: + self.event_stream.add_event( + ErrorObservation( + content=str(e), + ), + EventSource.AGENT, + ) return - raise + except (ContextWindowExceededError, BadRequestError) as e: + # FIXME: this is a hack until a litellm fix is confirmed 
class ReplayManager:
    """Drive the replay of a previously recorded trajectory.

    The manager walks through a list of recorded events and hands back, one at
    a time, the actions that should be re-executed. Everything that is not an
    action (messages, observations) and every action whose source is the user
    is skipped.

    Replaying may lead to unexpected or even erroneous results if any action
    is non-deterministic, or if the state before the replay session differs
    from the initial state of the recorded trajectory.
    """

    def __init__(self, replay_events: list[Event] | None):
        """Set up a replay session.

        Args:
            replay_events: Recorded events to replay. ``None`` or an empty
                list disables replay mode entirely.
        """
        if replay_events:
            logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
        self.replay_events = replay_events
        self.replay_mode = bool(replay_events)
        self.replay_index = 0

    def _replayable(self) -> bool:
        """Return True if the event at ``replay_index`` is a non-user action."""
        events = self.replay_events
        if events is None or self.replay_index >= len(events):
            return False
        event = events[self.replay_index]
        return isinstance(event, Action) and event.source != EventSource.USER

    def should_replay(self) -> bool:
        """Tell whether there is still a recorded action left to replay.

        As a side effect, ``replay_index`` is advanced past every event that
        cannot be replayed, so that it points at the next replayable action
        (or at the end of the list when the replay is finished).

        Note: after the replay is finished, the user and the agent can
        continue to message/act as usual.

        Returns:
            True if replay mode is on and a replayable action is available.
        """
        events = self.replay_events
        if not self.replay_mode or events is None:
            return False

        total = len(events)
        while self.replay_index < total and not self._replayable():
            self.replay_index += 1

        # The loop only stops early on a replayable event, so a position
        # inside the list is exactly the "something left to replay" case.
        return self.replay_index < total

    def step(self) -> Action:
        """Return the next recorded action and move past it.

        The method skips non-replayable events itself, so it is safe to call
        even if ``should_replay()`` was not called immediately before.

        Returns:
            The next action of the trajectory that is not user-sourced.

        Raises:
            RuntimeError: If there is no replayable action left.
        """
        events = self.replay_events
        if events is None or not self.should_replay():
            raise RuntimeError('No replayable action left in the trajectory')

        event = events[self.replay_index]
        # should_replay() guarantees this; the explicit check survives `-O`
        # (unlike an assert) and narrows the type for static checkers.
        if not isinstance(event, Action):
            raise RuntimeError(
                f'Expected an action to replay, got {type(event).__name__}'
            )
        self.replay_index += 1
        return event
Defaults to `workspace_base`. workspace_mount_path_in_sandbox: Path to mount the workspace in sandbox. Defaults to `/workspace`. @@ -55,6 +56,7 @@ class AppConfig(BaseModel): file_store: str = Field(default='local') file_store_path: str = Field(default='/tmp/openhands_file_store') save_trajectory_path: str | None = Field(default=None) + replay_trajectory_path: str | None = Field(default=None) workspace_base: str | None = Field(default=None) workspace_mount_path: str | None = Field(default=None) workspace_mount_path_in_sandbox: str = Field(default='/workspace') diff --git a/openhands/core/main.py b/openhands/core/main.py index b27cac1e58..5c3b38a21b 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -2,6 +2,7 @@ import asyncio import json import os import sys +from pathlib import Path from typing import Callable, Protocol import openhands.agenthub # noqa F401 (we import this to get the agents registered) @@ -22,10 +23,11 @@ from openhands.core.setup import ( generate_sid, ) from openhands.events import EventSource, EventStreamSubscriber -from openhands.events.action import MessageAction +from openhands.events.action import MessageAction, NullAction from openhands.events.action.action import Action from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation +from openhands.events.serialization import event_from_dict from openhands.events.serialization.event import event_to_trajectory from openhands.runtime.base import Runtime @@ -101,7 +103,17 @@ async def run_controller( if agent is None: agent = create_agent(runtime, config) - controller, initial_state = create_controller(agent, runtime, config) + replay_events: list[Event] | None = None + if config.replay_trajectory_path: + logger.info('Trajectory replay is enabled') + assert isinstance(initial_user_action, NullAction) + replay_events, initial_user_action = load_replay_log( + config.replay_trajectory_path + ) + + controller, initial_state = 
def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
    """Load a recorded trajectory and split off its initial user message.

    The file at ``trajectory_path`` must contain a JSON list of serialized
    events. Every entry is deserialized into an event, and its ``_id`` is
    cleared because an event that already carries an id cannot be added to
    the event stream.

    Args:
        trajectory_path: Path to the trajectory file (JSON).

    Returns:
        A tuple of:
          1) the list of events except the first one, and
          2) the first event, i.e. the user message holding the initial task.

    Raises:
        ValueError: If the path does not exist or is not a file, the content
            is not valid JSON, the JSON document is not a non-empty list, or
            the first event is not a ``MessageAction``.
    """
    path = Path(trajectory_path).resolve()

    if not path.exists():
        raise ValueError(f'Trajectory file not found: {path}')

    if not path.is_file():
        raise ValueError(f'Trajectory path is a directory, not a file: {path}')

    # Keep the try body down to the only call that can raise JSONDecodeError.
    try:
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except json.JSONDecodeError as e:
        raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}') from e

    if not isinstance(data, list):
        raise ValueError(f'Expected a list in {path}, got {type(data).__name__}')

    # Without this guard an empty trajectory would surface as a bare IndexError.
    if not data:
        raise ValueError(f'Trajectory file contains no events: {path}')

    events: list[Event] = []
    for item in data:
        event = event_from_dict(item)
        # cannot add an event with _id to event stream
        event._id = None  # type: ignore[attr-defined]
        events.append(event)

    first_event = events[0]
    # Explicit check instead of `assert`: this validates user-provided input
    # and must not disappear when Python runs with -O.
    if not isinstance(first_event, MessageAction):
        raise ValueError(
            f'Expected the first event in {path} to be a user message, '
            f'got {type(first_event).__name__}'
        )
    return events[1:], first_event
Please specify a task through -t, -f.') - initial_user_action: MessageAction = MessageAction(content=task_str) - - config = setup_config_from_args(args) # Set session name session_name = args.name diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 2888847801..82bdaf0c20 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -11,6 +11,7 @@ from openhands.core.config import ( ) from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream +from openhands.events.event import Event from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime @@ -78,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent: def create_controller( - agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True + agent: Agent, + runtime: Runtime, + config: AppConfig, + headless_mode: bool = True, + replay_events: list[Event] | None = None, ) -> Tuple[AgentController, State | None]: event_stream = runtime.event_stream initial_state = None @@ -101,6 +106,7 @@ def create_controller( initial_state=initial_state, headless_mode=headless_mode, confirmation_mode=config.security.confirmation_mode, + replay_events=replay_events, ) return (controller, initial_state) diff --git a/openhands/events/event.py b/openhands/events/event.py index 1bdece59eb..9d7af19160 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -24,6 +24,8 @@ class FileReadSource(str, Enum): @dataclass class Event: + INVALID_ID = -1 + @property def message(self) -> str | None: if hasattr(self, '_message'): @@ -34,7 +36,7 @@ class Event: def id(self) -> int: if hasattr(self, '_id'): return self._id # type: ignore[attr-defined] - return -1 + return Event.INVALID_ID @property def timestamp(self): diff --git a/openhands/events/observation/browse.py b/openhands/events/observation/browse.py index 2ab73f047d..bc347d6574 100644 --- 
a/openhands/events/observation/browse.py +++ b/openhands/events/observation/browse.py @@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation): url: str trigger_by_action: str - screenshot: str = field(repr=False) # don't show in repr + screenshot: str = field(repr=False, default='') # don't show in repr error: bool = False observation: str = ObservationType.BROWSE # do not include in the memory