mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Save complete trajectory in presence of history truncation (#6751)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent
d33913e036
commit
fab4532f6b
@ -51,7 +51,7 @@ from openhands.events.observation import (
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.events.serialization.event import event_to_trajectory, truncate_content
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
# note: RESUME is only available on web GUI
|
||||
@ -149,12 +149,13 @@ class AgentController:
|
||||
# replay-related
|
||||
self._replay_manager = ReplayManager(replay_events)
|
||||
|
||||
async def close(self) -> None:
|
||||
async def close(self, set_stop_state=True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
Note that it's fairly important that this closes properly, otherwise the state is incomplete.
|
||||
"""
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
if set_stop_state:
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
|
||||
# we made history, now is the time to rewrite it!
|
||||
# the final state.history will be used by external scripts like evals, tests, etc.
|
||||
@ -701,22 +702,7 @@ class AgentController:
|
||||
or isinstance(e, ContextWindowExceededError)
|
||||
):
|
||||
if self.agent.config.enable_history_truncation:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(
|
||||
self.state.history
|
||||
)
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
|
||||
# Add an error event to trigger another step by the agent
|
||||
self.event_stream.add_event(
|
||||
AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
self._handle_long_context_error()
|
||||
return
|
||||
else:
|
||||
raise LLMContextWindowExceedError()
|
||||
@ -848,6 +834,11 @@ class AgentController:
|
||||
# Always load from the event stream to avoid losing history
|
||||
self._init_history()
|
||||
|
||||
def get_trajectory(self) -> list[dict]:
    """Return the state history serialized as a list of trajectory dicts.

    Each event in ``self.state.history`` is converted with
    ``event_to_trajectory``. The controller must already be closed:
    before that point the state history could be partially
    hidden/truncated, so the result would not be the full trajectory.
    """
    # history may still be truncated/hidden until the controller is closed
    assert self._closed

    trajectory: list[dict] = []
    for event in self.state.history:
        trajectory.append(event_to_trajectory(event))
    return trajectory
|
||||
|
||||
def _init_history(self) -> None:
|
||||
"""Initializes the agent's history from the event stream.
|
||||
|
||||
@ -973,6 +964,22 @@ class AgentController:
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
def _handle_long_context_error(self) -> None:
    """React to an exceeded context window by shrinking the history.

    The state history is replaced by the result of
    ``_apply_conversation_window`` (which keeps roughly half of the agent
    interactions), ``state.start_id`` is moved to the first surviving
    event, and a condensation observation is pushed onto the event stream.
    """
    state = self.state

    # Context window exceeded: keep roughly half of the agent interactions
    state.history = self._apply_conversation_window(state.history)

    # Remember where the truncated history starts, for future reloading
    if state.history:
        state.start_id = state.history[0].id

    # Emit an event so that the agent is triggered to take another step
    notice = AgentCondensationObservation(
        content='Trimming prompt to meet context window limitations'
    )
    self.event_stream.add_event(notice, EventSource.AGENT)
|
||||
|
||||
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
|
||||
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
|
||||
and ensuring the first user message is always included.
|
||||
|
||||
@ -27,7 +27,6 @@ from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
@ -167,6 +166,8 @@ async def run_controller(
|
||||
# NOTE: the saved state does not include delegates events
|
||||
end_state.save_to_session(event_stream.sid, event_stream.file_store)
|
||||
|
||||
await controller.close(set_stop_state=False)
|
||||
|
||||
state = controller.get_state()
|
||||
|
||||
# save trajectories if applicable
|
||||
@ -177,7 +178,7 @@ async def run_controller(
|
||||
else:
|
||||
file_path = config.save_trajectory_path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
histories = [event_to_trajectory(event) for event in state.history]
|
||||
histories = controller.get_trajectory()
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(histories, f)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -72,6 +73,53 @@ class TestTruncation:
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
    """Truncating history mid-run must not shorten the final trajectory.

    Builds one user message followed by ten command/observation pairs,
    forces a truncation, then checks that closing the controller brings
    the full history (and therefore the full trajectory) back.
    """
    num_pairs = 10
    expected_len = 1 + 2 * num_pairs

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Opening user message, with an explicit ID
    opening = MessageAction(content='Hello, start task', wait_for_response=False)
    opening._source = EventSource.USER
    opening._id = 1

    # Followed by action/observation pairs whose IDs continue from 2
    history = [opening]
    for idx in range(num_pairs):
        action = CmdRunAction(command=f'cmd{idx}')
        action._id = idx + 2
        result = CmdOutputObservation(
            command=f'cmd{idx}', content=f'output{idx}', command_id=action._id
        )
        result._cause = action._id
        history.append(action)
        history.append(result)

    # Inject the events directly into the state for testing purposes
    controller.state.history = history

    # The mocked event stream serves the same, untruncated events
    mock_event_stream.get_events.return_value = controller.state.history

    assert len(controller.state.history) == expected_len

    # Force a truncation
    controller._handle_long_context_error()

    # Before the controller is closed the history is shortened
    truncated_len = len(controller.state.history)
    assert truncated_len == 13
    assert truncated_len < expected_len

    # After properly closing the controller the full history is recovered
    asyncio.run(controller.close())
    assert len(controller.event_stream.get_events()) == expected_len
    assert len(controller.state.history) == expected_len
    assert len(controller.get_trajectory()) == expected_len
|
||||
|
||||
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user