Save complete trajectory in presence of history truncation (#6751)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Boxuan Li 2025-02-21 00:14:30 -08:00 committed by GitHub
parent d33913e036
commit fab4532f6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 21 deletions

View File

@ -51,7 +51,7 @@ from openhands.events.observation import (
NullObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
from openhands.events.serialization.event import event_to_trajectory, truncate_content
from openhands.llm.llm import LLM
# note: RESUME is only available on web GUI
@ -149,12 +149,13 @@ class AgentController:
# replay-related
self._replay_manager = ReplayManager(replay_events)
async def close(self) -> None:
async def close(self, set_stop_state=True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
Note that it's fairly important that this closes properly, otherwise the state is incomplete.
"""
await self.set_agent_state_to(AgentState.STOPPED)
if set_stop_state:
await self.set_agent_state_to(AgentState.STOPPED)
# we made history, now is the time to rewrite it!
# the final state.history will be used by external scripts like evals, tests, etc.
@ -701,22 +702,7 @@ class AgentController:
or isinstance(e, ContextWindowExceededError)
):
if self.agent.config.enable_history_truncation:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(
self.state.history
)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)
self._handle_long_context_error()
return
else:
raise LLMContextWindowExceedError()
@ -848,6 +834,11 @@ class AgentController:
# Always load from the event stream to avoid losing history
self._init_history()
def get_trajectory(self) -> list[dict]:
# state history could be partially hidden/truncated before controller is closed
assert self._closed
return [event_to_trajectory(event) for event in self.state.history]
def _init_history(self) -> None:
"""Initializes the agent's history from the event stream.
@ -973,6 +964,22 @@ class AgentController:
# make sure history is in sync
self.state.start_id = start_id
def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
and ensuring the first user message is always included.

View File

@ -27,7 +27,6 @@ from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.io import read_input, read_task
from openhands.runtime.base import Runtime
@ -167,6 +166,8 @@ async def run_controller(
# NOTE: the saved state does not include delegates events
end_state.save_to_session(event_stream.sid, event_stream.file_store)
await controller.close(set_stop_state=False)
state = controller.get_state()
# save trajectories if applicable
@ -177,7 +178,7 @@ async def run_controller(
else:
file_path = config.save_trajectory_path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = [event_to_trajectory(event) for event in state.history]
histories = controller.get_trajectory()
with open(file_path, 'w') as f:
json.dump(histories, f)

View File

@ -1,3 +1,4 @@
import asyncio
from unittest.mock import MagicMock
import pytest
@ -72,6 +73,53 @@ class TestTruncation:
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])
def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
pairs = 10
history_len = 1 + 2 * pairs
events = [first_msg]
for i in range(pairs):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])
# patch events to history for testing purpose
controller.state.history = events
# Update mock event stream
mock_event_stream.get_events.return_value = controller.state.history
assert len(controller.state.history) == history_len
# Force apply truncation
controller._handle_long_context_error()
# Check that the history has been truncated before closing the controller
assert len(controller.state.history) == 13 < history_len
# Check that after properly closing the controller, history is recovered
asyncio.run(controller.close())
assert len(controller.event_stream.get_events()) == history_len
assert len(controller.state.history) == history_len
assert len(controller.get_trajectory()) == history_len
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,