fix: Context window truncation using CondensationAction (#7578)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
Calvin Smith 2025-03-31 13:47:00 -06:00 committed by GitHub
parent 648c8ffb21
commit abaf0da9fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 412 additions and 445 deletions

View File

@ -150,13 +150,13 @@ class BrowsingAgent(Agent):
last_obs = None
last_action = None
if EVAL_MODE and len(state.history) == 1:
if EVAL_MODE and len(state.view) == 1:
# for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
# initialize and retrieve the first observation by issuing a noop OP
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
return BrowseInteractiveAction(browser_actions='noop()')
for event in state.history:
for event in state.view:
if isinstance(event, BrowseInteractiveAction):
prev_actions.append(event.browser_actions)
last_action = event

View File

@ -130,7 +130,7 @@ class DummyAgent(Agent):
if 'observations' in prev_step and prev_step['observations']:
expected_observations = prev_step['observations']
hist_events = state.history[-len(expected_observations) :]
hist_events = state.view[-len(expected_observations) :]
if len(hist_events) < len(expected_observations):
print(

View File

@ -204,13 +204,13 @@ Note:
last_action = None
set_of_marks = None # Initialize set_of_marks to None
if len(state.history) == 1:
if len(state.view) == 1:
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
# initialize and retrieve the first observation by issuing a noop OP
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
return BrowseInteractiveAction(browser_actions='noop(1000)')
for event in state.history:
for event in state.view:
if isinstance(event, BrowseInteractiveAction):
prev_actions.append(event)
last_action = event

View File

@ -57,7 +57,6 @@ from openhands.events.action import (
from openhands.events.action.agent import CondensationAction, RecallAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
AgentStateChangedObservation,
ErrorObservation,
@ -928,12 +927,6 @@ class AgentController:
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
- Excludes all events between the action and observation
- Includes the delegate action and observation themselves
The history is loaded in two parts if truncation_id is set:
1. First user message from start_id onwards
2. Rest of history from truncation_id to the end
Otherwise loads normally from start_id.
"""
# define range of events to fetch
# delegates start with a start_id and initially won't find any events
@ -956,29 +949,6 @@ class AgentController:
events: list[Event] = []
# If we have a truncation point, get first user message and then rest of history
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
# Find first user message from stream
first_user_msg = next(
(
e
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
if first_user_msg:
events.append(first_user_msg)
# the rest of the events are from the truncation point
start_id = self.state.truncation_id
# Get rest of history
events_to_add = list(
self.event_stream.get_events(
@ -1046,7 +1016,10 @@ class AgentController:
def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)
kept_event_ids = {
e.id for e in self._apply_conversation_window(self.state.history)
}
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
@ -1054,8 +1027,9 @@ class AgentController:
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
CondensationAction(
forgotten_events_start_id=min(forgotten_event_ids),
forgotten_events_end_id=max(forgotten_event_ids),
),
EventSource.AGENT,
)
@ -1133,10 +1107,6 @@ class AgentController:
# if it's an action with source == EventSource.AGENT, we're good
break
# Save where to continue from in next reload
if kept_events:
self.state.truncation_id = kept_events[0].id
# Ensure first user message is included
if first_user_msg and first_user_msg not in kept_events:
kept_events = [first_user_msg] + kept_events

View File

@ -15,6 +15,7 @@ from openhands.events.action import (
from openhands.events.action.agent import AgentFinishAction
from openhands.events.event import Event, EventSource
from openhands.llm.metrics import Metrics
from openhands.memory.view import View
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_agent_state_filename
@ -96,8 +97,6 @@ class State:
# start_id and end_id track the range of events in history
start_id: int = -1
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
@ -170,6 +169,12 @@ class State:
# don't pickle history, it will be restored from the event stream
state = self.__dict__.copy()
state['history'] = []
# Remove any view caching attributes. They'll be rebuilt from the
# history after that gets reloaded.
state.pop('_history_checksum', None)
state.pop('_view', None)
return state
def __setstate__(self, state):
@ -183,7 +188,7 @@ class State:
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
last_user_message = None
last_user_message_image_urls: list[str] | None = []
for event in reversed(self.history):
for event in reversed(self.view):
if isinstance(event, MessageAction) and event.source == 'user':
last_user_message = event.content
last_user_message_image_urls = event.image_urls
@ -194,13 +199,13 @@ class State:
return last_user_message, last_user_message_image_urls
def get_last_agent_message(self) -> MessageAction | None:
for event in reversed(self.history):
for event in reversed(self.view):
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
return event
return None
def get_last_user_message(self) -> MessageAction | None:
for event in reversed(self.history):
for event in reversed(self.view):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return event
return None
@ -211,7 +216,22 @@ class State:
'trace_version': openhands.__version__,
'tags': [
f'agent:{agent_name}',
f'web_host:{os.environ.get("WEB_HOST", "unspecified")}',
f"web_host:{os.environ.get('WEB_HOST', 'unspecified')}",
f'openhands_version:{openhands.__version__}',
],
}
@property
def view(self) -> View:
    """Return the view built from `history`, re-using a cached copy when possible.

    The view is produced by `View.from_events(self.history)`. Rebuilding it on
    every access would be wasteful, so the result is cached together with a
    cheap fingerprint of the history it was built from.

    The fingerprint is the history list object itself plus its length:
    - the length catches events appended to (or removed from) the same list;
    - the identity check catches `history` being reassigned to a different
      list, even one that happens to have the same length. A reference to the
      list is kept (rather than its `id()`) so the identity cannot be recycled
      by a new list after the old one is garbage collected.

    In-place replacement of an element (same list, same length) is not
    detected.

    Both cache attributes (`_history_checksum` and `_view`) are stored directly
    on the instance.

    Returns:
        View: The view corresponding to the current history.
    """
    history = self.history
    fingerprint = getattr(self, '_history_checksum', None)

    # Any of these conditions means the cached view (if any) can't be trusted.
    stale = (
        not isinstance(fingerprint, tuple)
        or fingerprint[0] is not history
        or fingerprint[1] != len(history)
    )

    if stale:
        self._history_checksum = (history, len(history))
        self._view = View.from_events(history)

    return self._view

View File

@ -1,3 +0,0 @@
from openhands.memory.condenser import Condenser
__all__ = ['Condenser']

View File

@ -2,15 +2,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, overload
from typing import Any
from pydantic import BaseModel
from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.memory.view import View
CONDENSER_METADATA_KEY = 'condenser_meta'
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
@ -34,69 +33,6 @@ CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {}
"""Registry of condenser configurations to their corresponding condenser classes."""
class View(BaseModel):
"""Linearly ordered view of events.
Produced by a condenser to indicate the included events are ready to process as LLM input.
"""
events: list[Event]
def __len__(self) -> int:
return len(self.events)
def __iter__(self):
return iter(self.events)
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
# The only challenge with that is switching the return type based on the input type -- we
# can mark the different signatures for MyPy with `@overload` decorators.
@overload
def __getitem__(self, key: slice) -> list[Event]: ...
@overload
def __getitem__(self, key: int) -> Event: ...
def __getitem__(self, key: int | slice) -> Event | list[Event]:
if isinstance(key, slice):
start, stop, step = key.indices(len(self))
return [self[i] for i in range(start, stop, step)]
elif isinstance(key, int):
return self.events[key]
else:
raise ValueError(f'Invalid key type: {type(key)}')
@staticmethod
def from_events(events: list[Event]) -> View:
"""Create a view from a list of events, respecting the semantics of any condensation events."""
forgotten_event_ids: set[int] = set()
for event in events:
if isinstance(event, CondensationAction):
forgotten_event_ids.update(event.forgotten)
kept_events = [event for event in events if event.id not in forgotten_event_ids]
# If we have a summary, insert it at the specified offset.
summary: str | None = None
summary_offset: int | None = None
# The relevant summary is always in the last condensation event (i.e., the most recent one).
for event in reversed(events):
if isinstance(event, CondensationAction):
if event.summary is not None and event.summary_offset is not None:
summary = event.summary
summary_offset = event.summary_offset
break
if summary is not None and summary_offset is not None:
kept_events.insert(
summary_offset, AgentCondensationObservation(content=summary)
)
return View(events=kept_events)
class Condensation(BaseModel):
"""Produced by a condenser to indicate the history has been condensed."""
@ -150,13 +86,13 @@ class Condenser(ABC):
self.write_metadata(state)
@abstractmethod
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, View) -> View | Condensation:
"""Condense a sequence of events into a potentially smaller list.
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
Args:
events: A list of events representing the entire history of the agent.
View: A view of the history containing all events that should be condensed.
Returns:
View | Condensation: A condensed view of the events or an event indicating the history has been condensed.
@ -165,7 +101,7 @@ class Condenser(ABC):
def condensed_history(self, state: State) -> View | Condensation:
"""Condense the state's history."""
with self.metadata_batch(state):
return self.condense(state.history)
return self.condense(state.view)
@classmethod
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
@ -221,10 +157,7 @@ class RollingCondenser(Condenser, ABC):
def get_condensation(self, view: View) -> Condensation:
"""Get the condensation from a view."""
def condense(self, events: list[Event]) -> View | Condensation:
# Convert the state to a view. This might require some condenser-specific logic.
view = View.from_events(events)
def condense(self, view: View) -> View | Condensation:
# If we trigger the condenser-specific condensation threshold, compute and return
# the condensation.
if self.should_condense(view):

View File

@ -17,11 +17,11 @@ class BrowserOutputCondenser(Condenser):
self.attention_window = attention_window
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Replace the content of browser observations outside of the attention window with a placeholder."""
results: list[Event] = []
cnt: int = 0
for event in reversed(events):
for event in reversed(view):
if (
isinstance(event, BrowserOutputObservation)
and cnt >= self.attention_window

View File

@ -1,16 +1,15 @@
from __future__ import annotations
from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.events.event import Event
from openhands.memory.condenser.condenser import Condensation, Condenser, View
class NoOpCondenser(Condenser):
"""A condenser that does nothing to the event sequence."""
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Returns the list of events unchanged."""
return View(events=events)
return view
@classmethod
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:

View File

@ -15,14 +15,11 @@ class ObservationMaskingCondenser(Condenser):
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Replace the content of observations outside of the attention window with a placeholder."""
results: list[Event] = []
for i, event in enumerate(events):
if (
isinstance(event, Observation)
and i < len(events) - self.attention_window
):
for i, event in enumerate(view):
if isinstance(event, Observation) and i < len(view) - self.attention_window:
results.append(AgentCondensationObservation('<MASKED>'))
else:
results.append(event)

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
from openhands.events.event import Event
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@ -14,11 +13,11 @@ class RecentEventsCondenser(Condenser):
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Keep only the most recent events (up to `max_events`)."""
head = events[: self.keep_first]
head = view[: self.keep_first]
tail_length = max(0, self.max_events - len(head))
tail = events[-tail_length:]
tail = view[-tail_length:]
return View(events=head + tail)
@classmethod

72
openhands/memory/view.py Normal file
View File

@ -0,0 +1,72 @@
from __future__ import annotations
from typing import overload
from pydantic import BaseModel
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
class View(BaseModel):
    """Ordered sequence of events, ready to be processed as LLM input.

    A view behaves like a read-only list: it supports `len()`, iteration, and
    indexing by position or slice. Use `View.from_events` to derive a view from
    a raw event list that may contain `CondensationAction` events.
    """

    # The events making up the view, in order.
    events: list[Event]

    def __len__(self) -> int:
        """Return the number of events in the view."""
        return len(self.events)

    def __iter__(self):
        """Iterate over the events in order."""
        return iter(self.events)

    # List-like indexing: the return type depends on the key type, which the
    # `@overload` signatures below spell out for the type checker.

    @overload
    def __getitem__(self, key: slice) -> list[Event]: ...

    @overload
    def __getitem__(self, key: int) -> Event: ...

    def __getitem__(self, key: int | slice) -> Event | list[Event]:
        """Return the event at an integer position, or a list of events for a slice.

        Raises:
            ValueError: If `key` is neither an `int` nor a `slice`.
        """
        if isinstance(key, int):
            return self.events[key]

        if isinstance(key, slice):
            first, last, stride = key.indices(len(self))
            return [self.events[position] for position in range(first, last, stride)]

        raise ValueError(f'Invalid key type: {type(key)}')

    @staticmethod
    def from_events(events: list[Event]) -> View:
        """Create a view from a list of events, respecting the semantics of any condensation events.

        Every event whose id is listed in the `forgotten` ids of any
        `CondensationAction` in `events` is dropped. If a condensation carries
        both a summary and a summary offset, the most recent such condensation
        wins: its summary is wrapped in an `AgentCondensationObservation` and
        inserted into the kept events at that offset.

        Args:
            events: The full event list, possibly containing condensation actions.

        Returns:
            View: The events that remain, plus the summary observation if any.
        """
        condensations = [
            candidate
            for candidate in events
            if isinstance(candidate, CondensationAction)
        ]

        # Gather the ids every condensation asked to forget.
        dropped_ids: set[int] = set()
        for condensation in condensations:
            dropped_ids.update(condensation.forgotten)

        retained = [item for item in events if item.id not in dropped_ids]

        # Walk backwards so the most recent usable summary is the one applied.
        latest_summary = next(
            (
                condensation
                for condensation in reversed(condensations)
                if condensation.summary is not None
                and condensation.summary_offset is not None
            ),
            None,
        )

        if latest_summary is not None:
            retained.insert(
                latest_summary.summary_offset,
                AgentCondensationObservation(content=latest_summary.summary),
            )

        return View(events=retained)

View File

@ -20,6 +20,7 @@ from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
@ -643,19 +644,27 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
@pytest.mark.asyncio
async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
"""Test that context window exceeded errors are handled correctly by truncating history."""
async def test_context_window_exceeded_error_handling(
mock_agent, mock_runtime, test_event_stream
):
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
max_iterations = 5
error_after = 2
class StepState:
def __init__(self):
self.has_errored = False
self.index = 0
self.views = []
def step(self, state: State):
# Append a few messages to the history -- these will be truncated when we throw the error
state.history = [
MessageAction(content='Test message 0'),
MessageAction(content='Test message 1'),
]
self.views.append(state.view)
# Wait until the right step to throw the error, and make sure we
# only throw it once.
if self.index < error_after or self.has_errored:
self.index += 1
return MessageAction(content=f'Test message {self.index}')
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
@ -665,28 +674,78 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
self.has_errored = True
raise error
state = StepState()
mock_agent.step = state.step
step_state = StepState()
mock_agent.step = step_state.step
mock_agent.config = AgentConfig()
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
# Because we're sending message actions, we need to respond to the recall
# actions that get generated as a response.
# We do that by playing the role of the recall module -- subscribe to the
# event stream and respond to recall actions by inserting fake recall
# observations.
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
microagent_obs._cause = event.id
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
test_event_stream.subscribe(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
)
mock_runtime.event_stream = test_event_stream
# Now we can run the controller for a fixed number of steps. Since the step
# state is set to error out before then, if this terminates and we have a
# record of the error being thrown we can be confident that the controller
# handles the truncation correctly.
final_state = await asyncio.wait_for(
run_controller(
config=AppConfig(max_iterations=max_iterations),
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory,
),
timeout=10,
)
# Set the agent running and take a step in the controller -- this is similar
# to taking a single step using `run_controller`, but much easier to control
# termination for testing purposes
controller.state.agent_state = AgentState.RUNNING
await controller._step()
# Check that the context window exception was thrown and the controller
# called the agent's `step` function the right number of times.
assert step_state.has_errored
assert len(step_state.views) == max_iterations
# Check that the error was thrown and the history has been truncated
assert state.has_errored
assert controller.state.history == [MessageAction(content='Test message 1')]
# Look at pre/post-step views. Normally, these should always increase in
# size (because we return a message action, which triggers a recall, which
# triggers a recall response). But if the pre/post-views are on the turn
# when we throw the context window exceeded error, we should see the
# post-step view compressed.
for index, (first_view, second_view) in enumerate(
zip(step_state.views[:-1], step_state.views[1:])
):
if index == error_after:
assert len(first_view) > len(second_view)
else:
assert len(first_view) < len(second_view)
# The final state's history should contain:
# - max_iterations number of message actions,
# - max_iterations number of recall actions,
# - max_iterations number of recall observations,
# - and exactly one condensation action.
assert len(final_state.history) == max_iterations * 3 + 1
# ...but the final state's view should be identical to the last view (plus
# the final message action and associated recall action/observation).
assert len(final_state.view) == len(step_state.views[-1]) + 3
# And these two representations of the state are _not_ the same.
assert len(final_state.history) != len(final_state.view)
@pytest.mark.asyncio
@ -1168,3 +1227,123 @@ def test_agent_controller_should_step_with_null_observation_cause_zero():
assert (
result is False
), 'should_step should return False for NullObservation with cause = 0'
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
    """Test that the _apply_conversation_window method correctly prunes a list of events.

    The history is a short conversation (user message, agent question, user
    reply) followed by two command/observation pairs. After truncation the
    test expects a shorter list that still starts with the first user message
    and in which every kept command observation is preceded by its command.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_apply_conversation_window_basic',
        confirmation_mode=False,
        headless_mode=True,
    )

    def command_pair(command, output, action_id):
        # A command action and the observation it causes, with consecutive ids.
        action = CmdRunAction(command=command)
        action._id = action_id
        observation = CmdOutputObservation(
            command=command, content=output, command_id=action_id
        )
        observation._id = action_id + 1
        observation._cause = action_id
        return action, observation

    # Opening user message
    opening_msg = MessageAction(content='Hello, start task', wait_for_response=False)
    opening_msg._source = EventSource.USER
    opening_msg._id = 1

    # Agent asks what to do
    agent_question = MessageAction(
        content='What task would you like me to perform?', wait_for_response=True
    )
    agent_question._source = EventSource.AGENT
    agent_question._id = 2

    # User answers
    user_reply = MessageAction(
        content='Please list all files and show me current directory',
        wait_for_response=False,
    )
    user_reply._source = EventSource.USER
    user_reply._id = 3

    ls_action, ls_output = command_pair('ls', 'file1.txt', 4)
    pwd_action, pwd_output = command_pair('pwd', '/home', 6)

    events = [
        opening_msg,
        agent_question,
        user_reply,
        ls_action,
        ls_output,
        pwd_action,
        pwd_output,
    ]

    # Apply truncation
    truncated = controller._apply_conversation_window(events)

    # Verify truncation occurred
    # Should keep first user message and roughly half of other events
    assert (
        3 <= len(truncated) < len(events)
    )  # First message + at least one action-observation pair
    assert truncated[0] == opening_msg  # First message always preserved
    assert controller.state.start_id == opening_msg._id

    # Verify pairs aren't split: each kept observation must have its causing
    # action somewhere earlier in the truncated list.
    for offset, kept in enumerate(truncated[1:]):
        if not isinstance(kept, CmdOutputObservation):
            continue
        preceding = truncated[: offset + 1]
        assert any(candidate._id == kept._cause for candidate in preceding)
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
    """Test that a truncated history can be handed to a fresh controller unchanged.

    A history of one user message plus five command/observation pairs is
    truncated with `_apply_conversation_window`. The truncated history is then
    served by the mocked event stream to a second controller, which must end
    up with the same length, the same first message and the same `start_id`.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller = AgentController(**controller_kwargs)

    # Create events with IDs
    first_msg = MessageAction(content='Start task', wait_for_response=False)
    first_msg._source = EventSource.USER
    first_msg._id = 1

    events = [first_msg]
    for step in range(5):
        action = CmdRunAction(command=f'cmd{step}')
        action._id = step + 2
        # NOTE(review): the observation gets a `_cause` but no `_id` of its own.
        observation = CmdOutputObservation(
            command=f'cmd{step}', content=f'output{step}', command_id=action._id
        )
        observation._cause = action._id
        events += [action, observation]

    # Set up initial history, then force truncation
    controller.state.history = list(events)
    controller.state.history = controller._apply_conversation_window(
        controller.state.history
    )

    # Save state
    truncated_history = controller.state.history
    saved_start_id = controller.state.start_id
    saved_history_len = len(truncated_history)

    # Set up mock event stream for new controller
    mock_event_stream.get_events.return_value = truncated_history

    # Create new controller with saved state
    new_controller = AgentController(**controller_kwargs)
    new_controller.state.start_id = saved_start_id
    new_controller.state.history = mock_event_stream.get_events()

    # Verify restoration
    restored = new_controller.state
    assert len(restored.history) == saved_history_len
    assert restored.history[0] == first_msg
    assert restored.start_id == saved_start_id

View File

@ -127,7 +127,6 @@ async def test_agent_session_start_with_no_state(mock_agent):
assert session.controller.agent.name == 'test-agent'
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1
@pytest.mark.asyncio
@ -164,7 +163,6 @@ async def test_agent_session_start_with_restored_state(mock_agent):
mock_restored_state = MagicMock(spec=State)
mock_restored_state.start_id = -1
mock_restored_state.end_id = -1
mock_restored_state.truncation_id = -1
mock_restored_state.max_iterations = 5
# Create a spy on set_initial_state by subclassing AgentController
@ -211,4 +209,3 @@ async def test_agent_session_start_with_restored_state(mock_agent):
assert session.controller.state.max_iterations == 5
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1

View File

@ -88,16 +88,6 @@ def mock_llm() -> LLM:
return mock_llm
@pytest.fixture
def mock_state() -> State:
"""Mocks a State object with the only parameters needed for testing condensers: history and extra_data."""
mock_state = MagicMock(spec=State)
mock_state.history = []
mock_state.extra_data = {}
return mock_state
class RollingCondenserTestHarness:
"""Test harness for rolling condensers.
@ -120,21 +110,19 @@ class RollingCondenserTestHarness:
This generator assumes we're starting from an empty history.
"""
mock_state = MagicMock()
mock_state.extra_data = {}
mock_state.history = []
state = State()
for event in events:
mock_state.history.append(event)
state.history.append(event)
for callback in self.callbacks:
callback(mock_state.history)
callback(state.history)
match self.condenser.condensed_history(mock_state):
match self.condenser.condensed_history(state):
case View() as view:
yield view
case Condensation(event=condensation_event):
mock_state.history.append(condensation_event)
state.history.append(condensation_event)
def expected_size(self, index: int, max_size: int) -> int:
"""Calculate the expected size of the view at the given index.
@ -180,12 +168,11 @@ def test_noop_condenser():
create_test_event('Event 2'),
create_test_event('Event 3'),
]
mock_state = MagicMock()
mock_state.history = events
state = State()
state.history = events
condenser = NoOpCondenser()
result = condenser.condensed_history(mock_state)
result = condenser.condensed_history(state)
assert result == View(events=events)
@ -200,7 +187,7 @@ def test_observation_masking_condenser_from_config():
assert condenser.attention_window == attention_window
def test_observation_masking_condenser_respects_attention_window(mock_state):
def test_observation_masking_condenser_respects_attention_window():
"""Test that ObservationMaskingCondenser only masks events outside the attention window."""
attention_window = 3
condenser = ObservationMaskingCondenser(attention_window=attention_window)
@ -213,8 +200,9 @@ def test_observation_masking_condenser_respects_attention_window(mock_state):
Observation('Observation 2'),
]
mock_state.history = events
result = condenser.condensed_history(mock_state)
state = State()
state.history = events
result = condenser.condensed_history(state)
assert len(result) == len(events)
@ -239,7 +227,7 @@ def test_browser_output_condenser_from_config():
assert condenser.attention_window == attention_window
def test_browser_output_condenser_respects_attention_window(mock_state):
def test_browser_output_condenser_respects_attention_window():
"""Test that BrowserOutputCondenser only masks events outside the attention window."""
attention_window = 3
condenser = BrowserOutputCondenser(attention_window=attention_window)
@ -253,8 +241,10 @@ def test_browser_output_condenser_respects_attention_window(mock_state):
BrowserOutputObservation('Observation 4', url='', trigger_by_action=''),
]
mock_state.history = events
result = condenser.condensed_history(mock_state)
state = State()
state.history = events
result = condenser.condensed_history(state)
assert len(result) == len(events)
cnt = 4
@ -291,19 +281,19 @@ def test_recent_events_condenser():
create_test_event('Event 5'),
]
mock_state = MagicMock()
mock_state.history = events
state = State()
state.history = events
# If the max_events are larger than the number of events, equivalent to a NoOpCondenser.
condenser = RecentEventsCondenser(max_events=len(events))
result = condenser.condensed_history(mock_state)
result = condenser.condensed_history(state)
assert result == View(events=events)
# If the max_events are smaller than the number of events, only keep the last few.
max_events = 3
condenser = RecentEventsCondenser(max_events=max_events)
result = condenser.condensed_history(mock_state)
result = condenser.condensed_history(state)
assert len(result) == max_events
assert result[0]._message == 'Event 1' # kept from keep_first
@ -314,7 +304,7 @@ def test_recent_events_condenser():
keep_first = 1
max_events = 2
condenser = RecentEventsCondenser(keep_first=keep_first, max_events=max_events)
result = condenser.condensed_history(mock_state)
result = condenser.condensed_history(state)
assert len(result) == max_events
assert result[0]._message == 'Event 1'
@ -324,7 +314,7 @@ def test_recent_events_condenser():
keep_first = 2
max_events = 3
condenser = RecentEventsCondenser(keep_first=keep_first, max_events=max_events)
result = condenser.condensed_history(mock_state)
result = condenser.condensed_history(state)
assert len(result) == max_events
assert result[0]._message == 'Event 1' # kept from keep_first
@ -380,7 +370,7 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm, mock_state):
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
max_size = 10
keep_first = 3
@ -547,7 +537,7 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state):
def test_llm_attention_condenser_handles_too_many_events(mock_llm):
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)

58
tests/unit/test_state.py Normal file
View File

@ -0,0 +1,58 @@
from openhands.controller.state.state import State
from openhands.events.event import Event
from openhands.storage.memory import InMemoryFileStore
def example_event(index: int) -> Event:
    """Create a bare ``Event`` whose id and message are derived from ``index``.

    The private ``_id`` and ``_message`` attributes are set directly, so no
    event stream is involved in producing the event.
    """
    evt = Event()
    evt._id = index
    evt._message = f'Test message {index}'
    return evt
def test_state_view_caching_avoids_unnecessary_rebuilding():
    """Check that ``State.view`` is reused until the history changes."""
    state = State()
    state.history = list(map(example_event, range(5)))

    # First access builds the view.
    first_view = state.view

    # With an unchanged history, a second access must hand back the very
    # same object -- object identity is the simplest observable sign that
    # the cached view was reused.
    assert first_view is state.view

    # Appending to the history must yield a different view object.
    state.history.append(example_event(100))
    second_view = state.view
    assert second_view is not first_view

    # The new view is then reused on the following access.
    assert second_view is state.view
def test_state_view_cache_not_serialized():
    """Check that a state restored from a session does not reuse the saved state's view object."""
    original_state = State()
    original_state.history = list(map(example_event, range(5)))

    # Build the view *before* saving, so any cache it fills exists at save time.
    original_view = original_state.view

    # Round-trip the state through an in-memory file store.
    file_store = InMemoryFileStore()
    original_state.save_to_session('test_sid', file_store, None)
    restored_state = State.restore_from_session('test_sid', file_store, None)

    # Normally the history is rebuilt from the event stream; here that is
    # simulated by handing the restored state the same events directly.
    restored_state.history = original_state.history
    restored_view = restored_state.view

    # The restored view must be a distinct object with equal contents, which
    # is what we expect if the cache was left out of the serialized state.
    assert restored_view is not original_view
    assert restored_view.events == original_view.events

View File

@ -1,244 +0,0 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from openhands.controller.agent_controller import AgentController
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
@pytest.fixture
def mock_event_stream():
    """Provide a ``MagicMock`` event stream with no events and a latest event id of 0."""
    event_stream = MagicMock()
    event_stream.configure_mock(
        **{
            # A valid integer id, so code reading the latest id gets an int.
            'get_latest_event_id.return_value': 0,
            # By default the stream reports no events.
            'get_events.return_value': [],
        }
    )
    return event_stream
@pytest.fixture
def mock_agent():
    """Provide a ``MagicMock`` agent whose ``step`` always returns a fixed ``MessageAction``."""

    def _fixed_step(state):
        # The returned action deliberately has no id assigned.
        return MessageAction(content='Agent returned a message')

    fake_agent = MagicMock(llm=MagicMock())
    fake_agent.step = _fixed_step
    return fake_agent
class TestTruncation:
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
    """Truncate a short history and check the first message and action/observation pairing survive."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )

    # The opening user message carries id 1.
    first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
    first_msg._id = 1
    first_msg._source = EventSource.USER

    # Two command/observation pairs follow with ids (2, 3) and (4, 5);
    # each observation names its command as its cause.
    events = [first_msg]
    for command, output, action_id in (('ls', 'file1.txt', 2), ('pwd', '/home', 4)):
        action = CmdRunAction(command=command)
        action._id = action_id
        observation = CmdOutputObservation(
            command=command, content=output, command_id=action_id
        )
        observation._id = action_id + 1
        observation._cause = action_id
        events += [action, observation]

    truncated = controller._apply_conversation_window(events)

    # Expect the first message plus at least one action/observation pair.
    assert len(truncated) >= 3
    # The first message is always preserved.
    assert truncated[0] == first_msg
    assert controller.state.start_id == first_msg._id
    assert controller.state.truncation_id is not None

    # Every kept observation must be preceded by the action that caused it.
    for position, event in enumerate(truncated[1:], start=1):
        if not isinstance(event, CmdOutputObservation):
            continue
        preceding = truncated[:position]
        assert any(earlier._id == event._cause for earlier in preceding)
def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
    """Check that truncation shrinks the history only until the controller is closed."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )

    # The opening user message carries id 1.
    first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
    first_msg._id = 1
    first_msg._source = EventSource.USER

    # One user message followed by `pairs` command/observation pairs.
    pairs = 10
    history_len = 2 * pairs + 1
    events = [first_msg]
    for i in range(pairs):
        action = CmdRunAction(command=f'cmd{i}')
        action._id = i + 2
        observation = CmdOutputObservation(
            command=f'cmd{i}', content=f'output{i}', command_id=action._id
        )
        observation._cause = action._id
        events += [action, observation]

    # Install the events as the controller history, and let the mocked
    # event stream report the same history.
    controller.state.history = events
    mock_event_stream.get_events.return_value = controller.state.history
    assert len(controller.state.history) == history_len

    # Force a truncation.
    controller._handle_long_context_error()

    # While the controller is still open, the history is the shorter one.
    truncated_len = len(controller.state.history)
    assert truncated_len == 13 and truncated_len < history_len

    # After a proper close, the full history is visible everywhere again.
    asyncio.run(controller.close())
    assert len(controller.event_stream.get_events()) == history_len
    assert len(controller.state.history) == history_len
    assert len(controller.get_trajectory()) == history_len
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
    """Apply the conversation window to a five-event history and check the resulting ids."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Initial user message.
    first_msg = MessageAction(content='Start task', wait_for_response=False)
    first_msg._id = 1
    first_msg._source = EventSource.USER

    # The agent asks a question ...
    agent_msg = MessageAction(
        content='What task would you like me to perform?', wait_for_response=True
    )
    agent_msg._id = 2
    agent_msg._source = EventSource.AGENT

    # ... and the user answers it.
    user_response = MessageAction(
        content='Please list all files and show me current directory',
        wait_for_response=False,
    )
    user_response._id = 3
    user_response._source = EventSource.USER

    # One command together with its observation.
    cmd1 = CmdRunAction(command='ls')
    cmd1._id = 4
    obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
    obs1._id = 5
    obs1._cause = 4

    # The mocked stream and the controller state each get their own list
    # holding the same five events.
    conversation = [first_msg, agent_msg, user_response, cmd1, obs1]
    mock_event_stream.get_events.return_value = list(conversation)
    controller.state.history = conversation
    original_history_len = len(controller.state.history)

    # Stand-in for what happens after a ContextWindowExceededError:
    # run the truncation and install the result as the new history.
    truncated = controller._apply_conversation_window(controller.state.history)
    controller.state.history = truncated

    # The history shrank and both window ids were recorded.
    assert len(controller.state.history) < original_history_len
    assert controller.state.start_id == first_msg._id
    assert controller.state.truncation_id is not None
    assert controller.state.truncation_id > controller.state.start_id
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
    """Check that a second controller given the saved ids and history matches the truncated one."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )

    # One user message followed by five command/observation pairs.
    first_msg = MessageAction(content='Start task', wait_for_response=False)
    first_msg._id = 1
    first_msg._source = EventSource.USER
    events = [first_msg]
    for i in range(5):
        action = CmdRunAction(command=f'cmd{i}')
        action._id = i + 2
        observation = CmdOutputObservation(
            command=f'cmd{i}', content=f'output{i}', command_id=action._id
        )
        observation._cause = action._id
        events += [action, observation]

    # The controller works on its own copy of the event list.
    controller.state.history = list(events)

    # Force a truncation and install the result as the new history.
    truncated = controller._apply_conversation_window(controller.state.history)
    controller.state.history = truncated

    # Remember what the first controller ended up with.
    saved_history_len = len(controller.state.history)
    saved_truncation_id = controller.state.truncation_id
    saved_start_id = controller.state.start_id

    # The mocked stream now serves the truncated history to the next controller.
    mock_event_stream.get_events.return_value = controller.state.history

    # Second controller, populated from the saved values.
    new_controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )
    new_controller.state.start_id = saved_start_id
    new_controller.state.truncation_id = saved_truncation_id
    new_controller.state.history = mock_event_stream.get_events()

    # The second controller sees the same truncated history.
    assert len(new_controller.state.history) == saved_history_len
    assert new_controller.state.history[0] == first_msg
    assert new_controller.state.start_id == saved_start_id