fix: Context window truncation makes progress (#9052)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Calvin Smith
2025-06-11 12:47:34 -06:00
committed by GitHub
parent 7dede37fd8
commit a356f56237
3 changed files with 87 additions and 69 deletions

View File

@@ -72,6 +72,7 @@ from openhands.events.observation import (
from openhands.events.serialization.event import event_to_trajectory, truncate_content
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.view import View
# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
@@ -1161,7 +1162,8 @@ class AgentController:
def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
kept_events = self._apply_conversation_window()
current_view = View.from_events(self.state.history)
kept_events = self._apply_conversation_window(current_view.events)
kept_event_ids = {e.id for e in kept_events}
self.log(
@@ -1198,7 +1200,7 @@ class AgentController:
EventSource.AGENT,
)
def _apply_conversation_window(self) -> list[Event]:
def _apply_conversation_window(self, history: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded.
It preserves action-observation pairs and ensures that the system message,
@@ -1217,11 +1219,9 @@ class AgentController:
Returns:
Filtered list of events keeping newest half while preserving pairs and essential initial events.
"""
if not self.state.history:
# Handle empty history
if not history:
return []
history = self.state.history
# 1. Identify essential initial events
system_message: SystemMessageAction | None = None
first_user_msg: MessageAction | None = None
@@ -1238,50 +1238,59 @@ class AgentController:
and system_message.id == history[0].id
)
# Find First User Message, which MUST exist
first_user_msg = self._first_user_message()
# Find First User Message in the history, which MUST exist
first_user_msg = self._first_user_message(history)
if first_user_msg is None:
raise RuntimeError('No first user message found in the event stream.')
# If not found in history, try the event stream
first_user_msg = self._first_user_message()
if first_user_msg is None:
raise RuntimeError('No first user message found in the event stream.')
self.log(
'warning',
'First user message not found in history. Using cached version from event stream.',
)
# Find the first user message index in the history
first_user_msg_index = -1
for i, event in enumerate(history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
first_user_msg = event
first_user_msg_index = i
break
# Find Recall Action and Observation related to the First User Message
if first_user_msg is not None and first_user_msg_index != -1:
# Look for RecallAction after the first user message
for i in range(first_user_msg_index + 1, len(history)):
event = history[i]
if (
isinstance(event, RecallAction)
and event.query == first_user_msg.content
):
# Found RecallAction, now look for its Observation
recall_action = event
for j in range(i + 1, len(history)):
obs_event = history[j]
# Check for Observation caused by this RecallAction
if (
isinstance(obs_event, Observation)
and obs_event.cause == recall_action.id
):
recall_observation = obs_event
break # Found the observation, stop inner loop
break # Found the recall action (and maybe obs), stop outer loop
# Look for RecallAction after the first user message
for i in range(first_user_msg_index + 1, len(history)):
event = history[i]
if (
isinstance(event, RecallAction)
and event.query == first_user_msg.content
):
# Found RecallAction, now look for its Observation
recall_action = event
for j in range(i + 1, len(history)):
obs_event = history[j]
# Check for Observation caused by this RecallAction
if (
isinstance(obs_event, Observation)
and obs_event.cause == recall_action.id
):
recall_observation = obs_event
break # Found the observation, stop inner loop
break # Found the recall action (and maybe obs), stop outer loop
essential_events: list[Event] = []
if system_message:
essential_events.append(system_message)
if first_user_msg:
# Only include first user message if history is not empty
if history:
essential_events.append(first_user_msg)
# Also keep the RecallAction that triggered the essential RecallObservation
if recall_action:
essential_events.append(recall_action)
if recall_observation:
essential_events.append(recall_observation)
# Include recall action and observation if both exist
if recall_action and recall_observation:
essential_events.append(recall_action)
essential_events.append(recall_observation)
# Include recall action without observation for backward compatibility
elif recall_action:
essential_events.append(recall_action)
# 2. Determine the slice of recent events to potentially keep
num_non_essential_events = len(history) - len(essential_events)
@@ -1430,15 +1439,32 @@ class AgentController:
return result
return False
def _first_user_message(self) -> MessageAction | None:
def _first_user_message(
self, events: list[Event] | None = None
) -> MessageAction | None:
"""Get the first user message for this agent.
For regular agents, this is the first user message from the beginning (start_id=0).
For delegate agents, this is the first user message after the delegate's start_id.
Args:
events: Optional list of events to search through. If None, uses the event stream.
Returns:
MessageAction | None: The first user message, or None if no user message found
"""
# If events list is provided, search through it
if events is not None:
return next(
(
e
for e in events
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
# Otherwise, use the original event stream logic with caching
# Return cached message if any
if self._cached_first_user_message is not None:
return self._cached_first_user_message