mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
(feat) Add trajectory replay for headless mode (#6215)
This commit is contained in:
parent
b4d20e3e18
commit
4383be1ab4
@ -39,6 +39,11 @@ workspace_base = "./workspace"
|
||||
# If it's a folder, the session id will be used as the file name
|
||||
#save_trajectory_path="./trajectories"
|
||||
|
||||
# Path to replay a trajectory, must be a file path
|
||||
# If provided, trajectory will be loaded and replayed before the
|
||||
# agent responds to any user instruction
|
||||
#replay_trajectory_path = ""
|
||||
|
||||
# File store path
|
||||
#file_store_path = "/tmp/file_store"
|
||||
|
||||
|
||||
@ -55,6 +55,11 @@ The core configuration options are defined in the `[core]` section of the `confi
|
||||
- Default: `"./trajectories"`
|
||||
- Description: Path to store trajectories (can be a folder or a file). If it's a folder, the trajectories will be saved in a file named with the session id name and .json extension, in that folder.
|
||||
|
||||
- `replay_trajectory_path`
|
||||
- Type: `str`
|
||||
- Default: `""`
|
||||
- Description: Path to a trajectory file to load and replay. If given, it must point to a trajectory file in JSON format. The actions in the trajectory file are replayed before any user instruction is executed.
|
||||
|
||||
### File Store
|
||||
- `file_store_path`
|
||||
- Type: `str`
|
||||
|
||||
@ -12,6 +12,7 @@ from litellm.exceptions import (
|
||||
)
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.replay import ReplayManager
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
from openhands.controller.stuck import StuckDetector
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
@ -90,6 +91,7 @@ class AgentController:
|
||||
is_delegate: bool = False,
|
||||
headless_mode: bool = True,
|
||||
status_callback: Callable | None = None,
|
||||
replay_events: list[Event] | None = None,
|
||||
):
|
||||
"""Initializes a new instance of the AgentController class.
|
||||
|
||||
@ -108,6 +110,7 @@ class AgentController:
|
||||
is_delegate: Whether this controller is a delegate.
|
||||
headless_mode: Whether the agent is run in headless mode.
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
replay_events: A list of logs to replay.
|
||||
"""
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
@ -139,6 +142,9 @@ class AgentController:
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
self.status_callback = status_callback
|
||||
|
||||
# replay-related
|
||||
self._replay_manager = ReplayManager(replay_events)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
@ -234,6 +240,11 @@ class AgentController:
|
||||
await self._react_to_exception(reported)
|
||||
|
||||
def should_step(self, event: Event) -> bool:
|
||||
"""
|
||||
Whether the agent should take a step based on an event. In general,
|
||||
the agent should take a step if it receives a message from the user,
|
||||
or observes something in the environment (after acting).
|
||||
"""
|
||||
# it might be the delegate's day in the sun
|
||||
if self.delegate is not None:
|
||||
return False
|
||||
@ -641,42 +652,50 @@ class AgentController:
|
||||
|
||||
self.update_state_before_step()
|
||||
action: Action = NullAction()
|
||||
try:
|
||||
action = self.agent.step(self.state)
|
||||
if action is None:
|
||||
raise LLMNoActionError('No action was returned')
|
||||
except (
|
||||
LLMMalformedActionError,
|
||||
LLMNoActionError,
|
||||
LLMResponseError,
|
||||
FunctionCallValidationError,
|
||||
FunctionCallNotExistsError,
|
||||
) as e:
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation(
|
||||
content=str(e),
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
except (ContextWindowExceededError, BadRequestError) as e:
|
||||
# FIXME: this is a hack until a litellm fix is confirmed
|
||||
# Check if this is a nested context window error
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
'contextwindowexceedederror' in error_str
|
||||
or 'prompt is too long' in error_str
|
||||
or isinstance(e, ContextWindowExceededError)
|
||||
):
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(self.state.history)
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
# Don't add error event - let the agent retry with reduced context
|
||||
if self._replay_manager.should_replay():
|
||||
# in replay mode, we don't let the agent to proceed
|
||||
# instead, we replay the action from the replay trajectory
|
||||
action = self._replay_manager.step()
|
||||
else:
|
||||
try:
|
||||
action = self.agent.step(self.state)
|
||||
if action is None:
|
||||
raise LLMNoActionError('No action was returned')
|
||||
except (
|
||||
LLMMalformedActionError,
|
||||
LLMNoActionError,
|
||||
LLMResponseError,
|
||||
FunctionCallValidationError,
|
||||
FunctionCallNotExistsError,
|
||||
) as e:
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation(
|
||||
content=str(e),
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
raise
|
||||
except (ContextWindowExceededError, BadRequestError) as e:
|
||||
# FIXME: this is a hack until a litellm fix is confirmed
|
||||
# Check if this is a nested context window error
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
'contextwindowexceedederror' in error_str
|
||||
or 'prompt is too long' in error_str
|
||||
or isinstance(e, ContextWindowExceededError)
|
||||
):
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(
|
||||
self.state.history
|
||||
)
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
# Don't add error event - let the agent retry with reduced context
|
||||
return
|
||||
raise
|
||||
|
||||
if action.runnable:
|
||||
if self.state.confirmation_mode and (
|
||||
|
||||
52
openhands/controller/replay.py
Normal file
52
openhands/controller/replay.py
Normal file
@ -0,0 +1,52 @@
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event, EventSource
|
||||
|
||||
|
||||
class ReplayManager:
    """Drives a replay session over a previously recorded trajectory.

    The manager holds a list of events and a cursor into it. It hands back
    the actions found in that list one at a time and skips everything else:
    events that are not ``Action`` instances, and actions whose source is
    ``EventSource.USER``.

    Replaying could lead to unexpected or even erroneous results if any
    action is non-deterministic, or if the initial state before the replay
    session differs from the initial state of the recorded trajectory.
    """

    def __init__(self, replay_events: list[Event] | None):
        self.replay_events = replay_events
        # Replay mode is on only when there is at least one event to replay.
        self.replay_mode = bool(replay_events)
        # Cursor into ``replay_events``; advanced by should_replay() and step().
        self.replay_index = 0
        if self.replay_mode:
            logger.info(f'Replay logs loaded, events length = {len(replay_events)}')

    def _replayable(self) -> bool:
        """Return True if the event under the cursor is a non-user action."""
        events = self.replay_events
        if events is None:
            return False
        if self.replay_index >= len(events):
            return False
        candidate = events[self.replay_index]
        return isinstance(candidate, Action) and candidate.source != EventSource.USER

    def should_replay(self) -> bool:
        """Tell whether there is still an action left to replay.

        Returns False when the manager is not in replay mode or the replay
        has finished. Note: after the replay is finished, the user and the
        agent could continue to message/act.

        Side effect: the cursor ("replay_index") is moved forward past any
        events that cannot be replayed, so that it rests on the next
        replayable action if there is one.
        """
        if not self.replay_mode:
            return False

        assert self.replay_events is not None
        total = len(self.replay_events)
        while self.replay_index < total and not self._replayable():
            self.replay_index += 1

        return self._replayable()

    def step(self) -> Action:
        """Return the action under the cursor and advance the cursor by one.

        Intended to be called after should_replay() has returned True; the
        assertions below guard against the cursor resting on a non-action.
        """
        assert self.replay_events is not None
        current = self.replay_events[self.replay_index]
        assert isinstance(current, Action)
        self.replay_index += 1
        return current
|
||||
@ -28,6 +28,7 @@ class AppConfig(BaseModel):
|
||||
file_store: Type of file store to use.
|
||||
file_store_path: Path to the file store.
|
||||
save_trajectory_path: Either a folder path to store trajectories with auto-generated filenames, or a designated trajectory file path.
|
||||
replay_trajectory_path: Path to load trajectory and replay. If provided, trajectory would be replayed first before user's instruction.
|
||||
workspace_base: Base path for the workspace. Defaults to `./workspace` as absolute path.
|
||||
workspace_mount_path: Path to mount the workspace. Defaults to `workspace_base`.
|
||||
workspace_mount_path_in_sandbox: Path to mount the workspace in sandbox. Defaults to `/workspace`.
|
||||
@ -55,6 +56,7 @@ class AppConfig(BaseModel):
|
||||
file_store: str = Field(default='local')
|
||||
file_store_path: str = Field(default='/tmp/openhands_file_store')
|
||||
save_trajectory_path: str | None = Field(default=None)
|
||||
replay_trajectory_path: str | None = Field(default=None)
|
||||
workspace_base: str | None = Field(default=None)
|
||||
workspace_mount_path: str | None = Field(default=None)
|
||||
workspace_mount_path_in_sandbox: str = Field(default='/workspace')
|
||||
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Callable, Protocol
|
||||
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
@ -22,10 +23,11 @@ from openhands.core.setup import (
|
||||
generate_sid,
|
||||
)
|
||||
from openhands.events import EventSource, EventStreamSubscriber
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.action import MessageAction, NullAction
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
@ -101,7 +103,17 @@ async def run_controller(
|
||||
if agent is None:
|
||||
agent = create_agent(runtime, config)
|
||||
|
||||
controller, initial_state = create_controller(agent, runtime, config)
|
||||
replay_events: list[Event] | None = None
|
||||
if config.replay_trajectory_path:
|
||||
logger.info('Trajectory replay is enabled')
|
||||
assert isinstance(initial_user_action, NullAction)
|
||||
replay_events, initial_user_action = load_replay_log(
|
||||
config.replay_trajectory_path
|
||||
)
|
||||
|
||||
controller, initial_state = create_controller(
|
||||
agent, runtime, config, replay_events=replay_events
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
initial_user_action, Action
|
||||
@ -194,21 +206,64 @@ def auto_continue_response(
|
||||
return message
|
||||
|
||||
|
||||
def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
    """Load a recorded trajectory and split it into the initial task and the rest.

    The file at ``trajectory_path`` is parsed as JSON and each item is
    deserialized into an event with ``event_from_dict``.

    Args:
        trajectory_path: Path to a trajectory file containing a JSON list.

    Returns:
        A tuple of two things:
        1) A list of events except the first action
        2) First action (user message, a.k.a. initial task)

    Raises:
        ValueError: If the path does not exist, is not a file, does not
            contain valid JSON, does not contain a JSON list, contains no
            events, or does not start with a ``MessageAction``.
    """
    path = Path(trajectory_path).resolve()

    if not path.exists():
        raise ValueError(f'Trajectory file not found: {path}')

    if not path.is_file():
        raise ValueError(f'Trajectory path is a directory, not a file: {path}')

    # Keep the try body limited to the only step that can raise JSONDecodeError.
    try:
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except json.JSONDecodeError as e:
        raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}') from e

    if not isinstance(data, list):
        raise ValueError(f'Expected a list in {path}, got {type(data).__name__}')

    events = []
    for item in data:
        event = event_from_dict(item)
        # cannot add an event with _id to event stream
        event._id = None  # type: ignore[attr-defined]
        events.append(event)

    # Validate explicitly rather than with assert: an empty trajectory used to
    # surface as an IndexError, and asserts are removed when running with -O.
    if not events:
        raise ValueError(f'Trajectory in {path} contains no events')

    first_event = events[0]
    if not isinstance(first_event, MessageAction):
        raise ValueError(
            f'Expected the first event in {path} to be a MessageAction, '
            f'got {type(first_event).__name__}'
        )

    return events[1:], first_event
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
config = setup_config_from_args(args)
|
||||
|
||||
# Determine the task
|
||||
task_str = ''
|
||||
if args.file:
|
||||
task_str = read_task_from_file(args.file)
|
||||
elif args.task:
|
||||
task_str = args.task
|
||||
elif not sys.stdin.isatty():
|
||||
task_str = read_task_from_stdin()
|
||||
|
||||
initial_user_action: Action = NullAction()
|
||||
if config.replay_trajectory_path:
|
||||
if task_str:
|
||||
raise ValueError(
|
||||
'User-specified task is not supported under trajectory replay mode'
|
||||
)
|
||||
elif task_str:
|
||||
initial_user_action = MessageAction(content=task_str)
|
||||
else:
|
||||
raise ValueError('No task provided. Please specify a task through -t, -f.')
|
||||
initial_user_action: MessageAction = MessageAction(content=task_str)
|
||||
|
||||
config = setup_config_from_args(args)
|
||||
|
||||
# Set session name
|
||||
session_name = args.name
|
||||
|
||||
@ -11,6 +11,7 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
@ -78,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
|
||||
|
||||
|
||||
def create_controller(
|
||||
agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True
|
||||
agent: Agent,
|
||||
runtime: Runtime,
|
||||
config: AppConfig,
|
||||
headless_mode: bool = True,
|
||||
replay_events: list[Event] | None = None,
|
||||
) -> Tuple[AgentController, State | None]:
|
||||
event_stream = runtime.event_stream
|
||||
initial_state = None
|
||||
@ -101,6 +106,7 @@ def create_controller(
|
||||
initial_state=initial_state,
|
||||
headless_mode=headless_mode,
|
||||
confirmation_mode=config.security.confirmation_mode,
|
||||
replay_events=replay_events,
|
||||
)
|
||||
return (controller, initial_state)
|
||||
|
||||
|
||||
@ -24,6 +24,8 @@ class FileReadSource(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
@ -34,7 +36,7 @@ class Event:
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
return self._id # type: ignore[attr-defined]
|
||||
return -1
|
||||
return Event.INVALID_ID
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
|
||||
@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation):
|
||||
|
||||
url: str
|
||||
trigger_by_action: str
|
||||
screenshot: str = field(repr=False) # don't show in repr
|
||||
screenshot: str = field(repr=False, default='') # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
# do not include in the memory
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user