mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix: Context window truncation using CondensationAction (#7578)
Co-authored-by: Calvin Smith <calvin@all-hands.dev> Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
parent
648c8ffb21
commit
abaf0da9fe
@ -150,13 +150,13 @@ class BrowsingAgent(Agent):
|
||||
last_obs = None
|
||||
last_action = None
|
||||
|
||||
if EVAL_MODE and len(state.history) == 1:
|
||||
if EVAL_MODE and len(state.view) == 1:
|
||||
# for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
# initialize and retrieve the first observation by issuing an noop OP
|
||||
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
|
||||
return BrowseInteractiveAction(browser_actions='noop()')
|
||||
|
||||
for event in state.history:
|
||||
for event in state.view:
|
||||
if isinstance(event, BrowseInteractiveAction):
|
||||
prev_actions.append(event.browser_actions)
|
||||
last_action = event
|
||||
|
||||
@ -130,7 +130,7 @@ class DummyAgent(Agent):
|
||||
|
||||
if 'observations' in prev_step and prev_step['observations']:
|
||||
expected_observations = prev_step['observations']
|
||||
hist_events = state.history[-len(expected_observations) :]
|
||||
hist_events = state.view[-len(expected_observations) :]
|
||||
|
||||
if len(hist_events) < len(expected_observations):
|
||||
print(
|
||||
|
||||
@ -204,13 +204,13 @@ Note:
|
||||
last_action = None
|
||||
set_of_marks = None # Initialize set_of_marks to None
|
||||
|
||||
if len(state.history) == 1:
|
||||
if len(state.view) == 1:
|
||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
# initialize and retrieve the first observation by issuing an noop OP
|
||||
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
|
||||
return BrowseInteractiveAction(browser_actions='noop(1000)')
|
||||
|
||||
for event in state.history:
|
||||
for event in state.view:
|
||||
if isinstance(event, BrowseInteractiveAction):
|
||||
prev_actions.append(event)
|
||||
last_action = event
|
||||
|
||||
@ -57,7 +57,6 @@ from openhands.events.action import (
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
ErrorObservation,
|
||||
@ -928,12 +927,6 @@ class AgentController:
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
|
||||
The history is loaded in two parts if truncation_id is set:
|
||||
1. First user message from start_id onwards
|
||||
2. Rest of history from truncation_id to the end
|
||||
|
||||
Otherwise loads normally from start_id.
|
||||
"""
|
||||
# define range of events to fetch
|
||||
# delegates start with a start_id and initially won't find any events
|
||||
@ -956,29 +949,6 @@ class AgentController:
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
# If we have a truncation point, get first user message and then rest of history
|
||||
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
|
||||
# Find first user message from stream
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter_out_type=self.filter_out,
|
||||
filter_hidden=True,
|
||||
)
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
if first_user_msg:
|
||||
events.append(first_user_msg)
|
||||
|
||||
# the rest of the events are from the truncation point
|
||||
start_id = self.state.truncation_id
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
self.event_stream.get_events(
|
||||
@ -1046,7 +1016,10 @@ class AgentController:
|
||||
|
||||
def _handle_long_context_error(self) -> None:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(self.state.history)
|
||||
kept_event_ids = {
|
||||
e.id for e in self._apply_conversation_window(self.state.history)
|
||||
}
|
||||
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
@ -1054,8 +1027,9 @@ class AgentController:
|
||||
|
||||
# Add an error event to trigger another step by the agent
|
||||
self.event_stream.add_event(
|
||||
AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
CondensationAction(
|
||||
forgotten_events_start_id=min(forgotten_event_ids),
|
||||
forgotten_events_end_id=max(forgotten_event_ids),
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
@ -1133,10 +1107,6 @@ class AgentController:
|
||||
# if it's an action with source == EventSource.AGENT, we're good
|
||||
break
|
||||
|
||||
# Save where to continue from in next reload
|
||||
if kept_events:
|
||||
self.state.truncation_id = kept_events[0].id
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
kept_events = [first_user_msg] + kept_events
|
||||
|
||||
@ -15,6 +15,7 @@ from openhands.events.action import (
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.view import View
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_agent_state_filename
|
||||
|
||||
@ -96,8 +97,6 @@ class State:
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
# truncation_id tracks where to load history after context window truncation
|
||||
truncation_id: int = -1
|
||||
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
@ -170,6 +169,12 @@ class State:
|
||||
# don't pickle history, it will be restored from the event stream
|
||||
state = self.__dict__.copy()
|
||||
state['history'] = []
|
||||
|
||||
# Remove any view caching attributes. They'll be rebuilt frmo the
|
||||
# history after that gets reloaded.
|
||||
state.pop('_history_checksum', None)
|
||||
state.pop('_view', None)
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
@ -183,7 +188,7 @@ class State:
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in reversed(self.history):
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.image_urls
|
||||
@ -194,13 +199,13 @@ class State:
|
||||
return last_user_message, last_user_message_image_urls
|
||||
|
||||
def get_last_agent_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.history):
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
return event
|
||||
return None
|
||||
|
||||
def get_last_user_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.history):
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
return event
|
||||
return None
|
||||
@ -211,7 +216,22 @@ class State:
|
||||
'trace_version': openhands.__version__,
|
||||
'tags': [
|
||||
f'agent:{agent_name}',
|
||||
f'web_host:{os.environ.get("WEB_HOST", "unspecified")}',
|
||||
f"web_host:{os.environ.get('WEB_HOST', 'unspecified')}",
|
||||
f'openhands_version:{openhands.__version__}',
|
||||
],
|
||||
}
|
||||
|
||||
@property
|
||||
def view(self) -> View:
|
||||
# Compute a simple checksum from the history to see if we can re-use any
|
||||
# cached view.
|
||||
history_checksum = len(self.history)
|
||||
old_history_checksum = getattr(self, '_history_checksum', -1)
|
||||
|
||||
# If the history has changed, we need to re-create the view and update
|
||||
# the caching.
|
||||
if history_checksum != old_history_checksum:
|
||||
self._history_checksum = history_checksum
|
||||
self._view = View.from_events(self.history)
|
||||
|
||||
return self._view
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from openhands.memory.condenser import Condenser
|
||||
|
||||
__all__ = ['Condenser']
|
||||
@ -2,15 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, overload
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.condenser_config import CondenserConfig
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.memory.view import View
|
||||
|
||||
CONDENSER_METADATA_KEY = 'condenser_meta'
|
||||
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
|
||||
@ -34,69 +33,6 @@ CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {}
|
||||
"""Registry of condenser configurations to their corresponding condenser classes."""
|
||||
|
||||
|
||||
class View(BaseModel):
|
||||
"""Linearly ordered view of events.
|
||||
|
||||
Produced by a condenser to indicate the included events are ready to process as LLM input.
|
||||
"""
|
||||
|
||||
events: list[Event]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.events)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.events)
|
||||
|
||||
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
|
||||
# The only challenge with that is switching the return type based on the input type -- we
|
||||
# can mark the different signatures for MyPy with `@overload` decorators.
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> list[Event]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Event: ...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Event | list[Event]:
|
||||
if isinstance(key, slice):
|
||||
start, stop, step = key.indices(len(self))
|
||||
return [self[i] for i in range(start, stop, step)]
|
||||
elif isinstance(key, int):
|
||||
return self.events[key]
|
||||
else:
|
||||
raise ValueError(f'Invalid key type: {type(key)}')
|
||||
|
||||
@staticmethod
|
||||
def from_events(events: list[Event]) -> View:
|
||||
"""Create a view from a list of events, respecting the semantics of any condensation events."""
|
||||
forgotten_event_ids: set[int] = set()
|
||||
for event in events:
|
||||
if isinstance(event, CondensationAction):
|
||||
forgotten_event_ids.update(event.forgotten)
|
||||
|
||||
kept_events = [event for event in events if event.id not in forgotten_event_ids]
|
||||
|
||||
# If we have a summary, insert it at the specified offset.
|
||||
summary: str | None = None
|
||||
summary_offset: int | None = None
|
||||
|
||||
# The relevant summary is always in the last condensation event (i.e., the most recent one).
|
||||
for event in reversed(events):
|
||||
if isinstance(event, CondensationAction):
|
||||
if event.summary is not None and event.summary_offset is not None:
|
||||
summary = event.summary
|
||||
summary_offset = event.summary_offset
|
||||
break
|
||||
|
||||
if summary is not None and summary_offset is not None:
|
||||
kept_events.insert(
|
||||
summary_offset, AgentCondensationObservation(content=summary)
|
||||
)
|
||||
|
||||
return View(events=kept_events)
|
||||
|
||||
|
||||
class Condensation(BaseModel):
|
||||
"""Produced by a condenser to indicate the history has been condensed."""
|
||||
|
||||
@ -150,13 +86,13 @@ class Condenser(ABC):
|
||||
self.write_metadata(state)
|
||||
|
||||
@abstractmethod
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
def condense(self, View) -> View | Condensation:
|
||||
"""Condense a sequence of events into a potentially smaller list.
|
||||
|
||||
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
|
||||
|
||||
Args:
|
||||
events: A list of events representing the entire history of the agent.
|
||||
View: A view of the history containing all events that should be condensed.
|
||||
|
||||
Returns:
|
||||
View | Condensation: A condensed view of the events or an event indicating the history has been condensed.
|
||||
@ -165,7 +101,7 @@ class Condenser(ABC):
|
||||
def condensed_history(self, state: State) -> View | Condensation:
|
||||
"""Condense the state's history."""
|
||||
with self.metadata_batch(state):
|
||||
return self.condense(state.history)
|
||||
return self.condense(state.view)
|
||||
|
||||
@classmethod
|
||||
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
|
||||
@ -221,10 +157,7 @@ class RollingCondenser(Condenser, ABC):
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
"""Get the condensation from a view."""
|
||||
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
# Convert the state to a view. This might require some condenser-specific logic.
|
||||
view = View.from_events(events)
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
# If we trigger the condenser-specific condensation threshold, compute and return
|
||||
# the condensation.
|
||||
if self.should_condense(view):
|
||||
|
||||
@ -17,11 +17,11 @@ class BrowserOutputCondenser(Condenser):
|
||||
self.attention_window = attention_window
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Replace the content of browser observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
cnt: int = 0
|
||||
for event in reversed(events):
|
||||
for event in reversed(view):
|
||||
if (
|
||||
isinstance(event, BrowserOutputObservation)
|
||||
and cnt >= self.attention_window
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
class NoOpCondenser(Condenser):
|
||||
"""A condenser that does nothing to the event sequence."""
|
||||
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Returns the list of events unchanged."""
|
||||
return View(events=events)
|
||||
return view
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
|
||||
|
||||
@ -15,14 +15,11 @@ class ObservationMaskingCondenser(Condenser):
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Replace the content of observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
for i, event in enumerate(events):
|
||||
if (
|
||||
isinstance(event, Observation)
|
||||
and i < len(events) - self.attention_window
|
||||
):
|
||||
for i, event in enumerate(view):
|
||||
if isinstance(event, Observation) and i < len(view) - self.attention_window:
|
||||
results.append(AgentCondensationObservation('<MASKED>'))
|
||||
else:
|
||||
results.append(event)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
@ -14,11 +13,11 @@ class RecentEventsCondenser(Condenser):
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, events: list[Event]) -> View | Condensation:
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Keep only the most recent events (up to `max_events`)."""
|
||||
head = events[: self.keep_first]
|
||||
head = view[: self.keep_first]
|
||||
tail_length = max(0, self.max_events - len(head))
|
||||
tail = events[-tail_length:]
|
||||
tail = view[-tail_length:]
|
||||
return View(events=head + tail)
|
||||
|
||||
@classmethod
|
||||
|
||||
72
openhands/memory/view.py
Normal file
72
openhands/memory/view.py
Normal file
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
|
||||
|
||||
class View(BaseModel):
|
||||
"""Linearly ordered view of events.
|
||||
|
||||
Produced by a condenser to indicate the included events are ready to process as LLM input.
|
||||
"""
|
||||
|
||||
events: list[Event]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.events)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.events)
|
||||
|
||||
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
|
||||
# The only challenge with that is switching the return type based on the input type -- we
|
||||
# can mark the different signatures for MyPy with `@overload` decorators.
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> list[Event]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Event: ...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Event | list[Event]:
|
||||
if isinstance(key, slice):
|
||||
start, stop, step = key.indices(len(self))
|
||||
return [self[i] for i in range(start, stop, step)]
|
||||
elif isinstance(key, int):
|
||||
return self.events[key]
|
||||
else:
|
||||
raise ValueError(f'Invalid key type: {type(key)}')
|
||||
|
||||
@staticmethod
|
||||
def from_events(events: list[Event]) -> View:
|
||||
"""Create a view from a list of events, respecting the semantics of any condensation events."""
|
||||
forgotten_event_ids: set[int] = set()
|
||||
for event in events:
|
||||
if isinstance(event, CondensationAction):
|
||||
forgotten_event_ids.update(event.forgotten)
|
||||
|
||||
kept_events = [event for event in events if event.id not in forgotten_event_ids]
|
||||
|
||||
# If we have a summary, insert it at the specified offset.
|
||||
summary: str | None = None
|
||||
summary_offset: int | None = None
|
||||
|
||||
# The relevant summary is always in the last condensation event (i.e., the most recent one).
|
||||
for event in reversed(events):
|
||||
if isinstance(event, CondensationAction):
|
||||
if event.summary is not None and event.summary_offset is not None:
|
||||
summary = event.summary
|
||||
summary_offset = event.summary_offset
|
||||
break
|
||||
|
||||
if summary is not None and summary_offset is not None:
|
||||
kept_events.insert(
|
||||
summary_offset, AgentCondensationObservation(content=summary)
|
||||
)
|
||||
|
||||
return View(events=kept_events)
|
||||
@ -20,6 +20,7 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
@ -643,19 +644,27 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
|
||||
"""Test that context window exceeded errors are handled correctly by truncating history."""
|
||||
async def test_context_window_exceeded_error_handling(
|
||||
mock_agent, mock_runtime, test_event_stream
|
||||
):
|
||||
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
|
||||
max_iterations = 5
|
||||
error_after = 2
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
self.has_errored = False
|
||||
self.index = 0
|
||||
self.views = []
|
||||
|
||||
def step(self, state: State):
|
||||
# Append a few messages to the history -- these will be truncated when we throw the error
|
||||
state.history = [
|
||||
MessageAction(content='Test message 0'),
|
||||
MessageAction(content='Test message 1'),
|
||||
]
|
||||
self.views.append(state.view)
|
||||
|
||||
# Wait until the right step to throw the error, and make sure we
|
||||
# only throw it once.
|
||||
if self.index < error_after or self.has_errored:
|
||||
self.index += 1
|
||||
return MessageAction(content=f'Test message {self.index}')
|
||||
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
@ -665,28 +674,78 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
self.has_errored = True
|
||||
raise error
|
||||
|
||||
state = StepState()
|
||||
mock_agent.step = state.step
|
||||
step_state = StepState()
|
||||
mock_agent.step = step_state.step
|
||||
mock_agent.config = AgentConfig()
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
# Because we're sending message actions, we need to respond to the recall
|
||||
# actions that get generated as a response.
|
||||
|
||||
# We do that by playing the role of the recall module -- subscribe to the
|
||||
# event stream and respond to recall actions by inserting fake recall
|
||||
# obesrvations.
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
|
||||
# Now we can run the controller for a fixed number of steps. Since the step
|
||||
# state is set to error out before then, if this terminates and we have a
|
||||
# record of the error being thrown we can be confident that the controller
|
||||
# handles the truncation correctly.
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=AppConfig(max_iterations=max_iterations),
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Set the agent running and take a step in the controller -- this is similar
|
||||
# to taking a single step using `run_controller`, but much easier to control
|
||||
# termination for testing purposes
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
await controller._step()
|
||||
# Check that the context window exception was thrown and the controller
|
||||
# called the agent's `step` function the right number of times.
|
||||
assert step_state.has_errored
|
||||
assert len(step_state.views) == max_iterations
|
||||
|
||||
# Check that the error was thrown and the history has been truncated
|
||||
assert state.has_errored
|
||||
assert controller.state.history == [MessageAction(content='Test message 1')]
|
||||
# Look at pre/post-step views. Normally, these should always increase in
|
||||
# size (because we return a message action, which triggers a recall, which
|
||||
# triggers a recall response). But if the pre/post-views are on the turn
|
||||
# when we throw the context window exceeded error, we should see the
|
||||
# post-step view compressed.
|
||||
for index, (first_view, second_view) in enumerate(
|
||||
zip(step_state.views[:-1], step_state.views[1:])
|
||||
):
|
||||
if index == error_after:
|
||||
assert len(first_view) > len(second_view)
|
||||
else:
|
||||
assert len(first_view) < len(second_view)
|
||||
|
||||
# The final state's history should contain:
|
||||
# - max_iterations number of message actions,
|
||||
# - max_iterations number of recall actions,
|
||||
# - max_iterations number of recall observations,
|
||||
# - and exactly one condensation action.
|
||||
assert len(final_state.history) == max_iterations * 3 + 1
|
||||
|
||||
# ...but the final state's view should be identical to the last view (plus
|
||||
# the final message action and associated recall action/observation).
|
||||
assert len(final_state.view) == len(step_state.views[-1]) + 3
|
||||
|
||||
# And these two representations of the state are _not_ the same.
|
||||
assert len(final_state.history) != len(final_state.view)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1168,3 +1227,123 @@ def test_agent_controller_should_step_with_null_observation_cause_zero():
|
||||
assert (
|
||||
result is False
|
||||
), 'should_step should return False for NullObservation with cause = 0'
|
||||
|
||||
|
||||
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
|
||||
"""Test that the _apply_conversation_window method correctly prunes a list of events."""
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_apply_conversation_window_basic',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 6
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6)
|
||||
obs2._id = 7
|
||||
obs2._cause = 6
|
||||
|
||||
events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Verify truncation occured
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
3 <= len(truncated) < len(events)
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
|
||||
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
|
||||
@ -127,7 +127,6 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
assert session.controller.agent.name == 'test-agent'
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
assert session.controller.state.truncation_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -164,7 +163,6 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
mock_restored_state = MagicMock(spec=State)
|
||||
mock_restored_state.start_id = -1
|
||||
mock_restored_state.end_id = -1
|
||||
mock_restored_state.truncation_id = -1
|
||||
mock_restored_state.max_iterations = 5
|
||||
|
||||
# Create a spy on set_initial_state by subclassing AgentController
|
||||
@ -211,4 +209,3 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
assert session.controller.state.max_iterations == 5
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
assert session.controller.state.truncation_id == -1
|
||||
|
||||
@ -88,16 +88,6 @@ def mock_llm() -> LLM:
|
||||
return mock_llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state() -> State:
|
||||
"""Mocks a State object with the only parameters needed for testing condensers: history and extra_data."""
|
||||
mock_state = MagicMock(spec=State)
|
||||
mock_state.history = []
|
||||
mock_state.extra_data = {}
|
||||
|
||||
return mock_state
|
||||
|
||||
|
||||
class RollingCondenserTestHarness:
|
||||
"""Test harness for rolling condensers.
|
||||
|
||||
@ -120,21 +110,19 @@ class RollingCondenserTestHarness:
|
||||
|
||||
This generator assumes we're starting from an empty history.
|
||||
"""
|
||||
mock_state = MagicMock()
|
||||
mock_state.extra_data = {}
|
||||
mock_state.history = []
|
||||
state = State()
|
||||
|
||||
for event in events:
|
||||
mock_state.history.append(event)
|
||||
state.history.append(event)
|
||||
for callback in self.callbacks:
|
||||
callback(mock_state.history)
|
||||
callback(state.history)
|
||||
|
||||
match self.condenser.condensed_history(mock_state):
|
||||
match self.condenser.condensed_history(state):
|
||||
case View() as view:
|
||||
yield view
|
||||
|
||||
case Condensation(event=condensation_event):
|
||||
mock_state.history.append(condensation_event)
|
||||
state.history.append(condensation_event)
|
||||
|
||||
def expected_size(self, index: int, max_size: int) -> int:
|
||||
"""Calculate the expected size of the view at the given index.
|
||||
@ -180,12 +168,11 @@ def test_noop_condenser():
|
||||
create_test_event('Event 2'),
|
||||
create_test_event('Event 3'),
|
||||
]
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state.history = events
|
||||
state = State()
|
||||
state.history = events
|
||||
|
||||
condenser = NoOpCondenser()
|
||||
result = condenser.condensed_history(mock_state)
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert result == View(events=events)
|
||||
|
||||
@ -200,7 +187,7 @@ def test_observation_masking_condenser_from_config():
|
||||
assert condenser.attention_window == attention_window
|
||||
|
||||
|
||||
def test_observation_masking_condenser_respects_attention_window(mock_state):
|
||||
def test_observation_masking_condenser_respects_attention_window():
|
||||
"""Test that ObservationMaskingCondenser only masks events outside the attention window."""
|
||||
attention_window = 3
|
||||
condenser = ObservationMaskingCondenser(attention_window=attention_window)
|
||||
@ -213,8 +200,9 @@ def test_observation_masking_condenser_respects_attention_window(mock_state):
|
||||
Observation('Observation 2'),
|
||||
]
|
||||
|
||||
mock_state.history = events
|
||||
result = condenser.condensed_history(mock_state)
|
||||
state = State()
|
||||
state.history = events
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert len(result) == len(events)
|
||||
|
||||
@ -239,7 +227,7 @@ def test_browser_output_condenser_from_config():
|
||||
assert condenser.attention_window == attention_window
|
||||
|
||||
|
||||
def test_browser_output_condenser_respects_attention_window(mock_state):
|
||||
def test_browser_output_condenser_respects_attention_window():
|
||||
"""Test that BrowserOutputCondenser only masks events outside the attention window."""
|
||||
attention_window = 3
|
||||
condenser = BrowserOutputCondenser(attention_window=attention_window)
|
||||
@ -253,8 +241,10 @@ def test_browser_output_condenser_respects_attention_window(mock_state):
|
||||
BrowserOutputObservation('Observation 4', url='', trigger_by_action=''),
|
||||
]
|
||||
|
||||
mock_state.history = events
|
||||
result = condenser.condensed_history(mock_state)
|
||||
state = State()
|
||||
state.history = events
|
||||
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert len(result) == len(events)
|
||||
cnt = 4
|
||||
@ -291,19 +281,19 @@ def test_recent_events_condenser():
|
||||
create_test_event('Event 5'),
|
||||
]
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state.history = events
|
||||
state = State()
|
||||
state.history = events
|
||||
|
||||
# If the max_events are larger than the number of events, equivalent to a NoOpCondenser.
|
||||
condenser = RecentEventsCondenser(max_events=len(events))
|
||||
result = condenser.condensed_history(mock_state)
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert result == View(events=events)
|
||||
|
||||
# If the max_events are smaller than the number of events, only keep the last few.
|
||||
max_events = 3
|
||||
condenser = RecentEventsCondenser(max_events=max_events)
|
||||
result = condenser.condensed_history(mock_state)
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert len(result) == max_events
|
||||
assert result[0]._message == 'Event 1' # kept from keep_first
|
||||
@ -314,7 +304,7 @@ def test_recent_events_condenser():
|
||||
keep_first = 1
|
||||
max_events = 2
|
||||
condenser = RecentEventsCondenser(keep_first=keep_first, max_events=max_events)
|
||||
result = condenser.condensed_history(mock_state)
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert len(result) == max_events
|
||||
assert result[0]._message == 'Event 1'
|
||||
@ -324,7 +314,7 @@ def test_recent_events_condenser():
|
||||
keep_first = 2
|
||||
max_events = 3
|
||||
condenser = RecentEventsCondenser(keep_first=keep_first, max_events=max_events)
|
||||
result = condenser.condensed_history(mock_state)
|
||||
result = condenser.condensed_history(state)
|
||||
|
||||
assert len(result) == max_events
|
||||
assert result[0]._message == 'Event 1' # kept from keep_first
|
||||
@ -380,7 +370,7 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm, mock_state):
|
||||
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
|
||||
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
|
||||
max_size = 10
|
||||
keep_first = 3
|
||||
@ -547,7 +537,7 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
|
||||
assert len(view) == harness.expected_size(i, max_size)
|
||||
|
||||
|
||||
def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state):
|
||||
def test_llm_attention_condenser_handles_too_many_events(mock_llm):
|
||||
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
|
||||
max_size = 2
|
||||
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
|
||||
|
||||
58
tests/unit/test_state.py
Normal file
58
tests/unit/test_state.py
Normal file
@ -0,0 +1,58 @@
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.events.event import Event
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
def example_event(index: int) -> Event:
|
||||
event = Event()
|
||||
event._message = f'Test message {index}'
|
||||
event._id = index
|
||||
return event
|
||||
|
||||
|
||||
def test_state_view_caching_avoids_unnecessary_rebuilding():
|
||||
"""Test that the state view caching avoids unnecessarily rebuilding the view when the history hasn't changed."""
|
||||
state = State()
|
||||
state.history = [example_event(i) for i in range(5)]
|
||||
|
||||
# Build the view once.
|
||||
view = state.view
|
||||
|
||||
# Easy way to check that the cache works -- `view` and future calls of
|
||||
# `state.view` should be the same object. We'll check that by using the `id`
|
||||
# of the view.
|
||||
assert id(view) == id(state.view)
|
||||
|
||||
# Add an event to the history. This should produce a different view.
|
||||
state.history.append(example_event(100))
|
||||
|
||||
new_view = state.view
|
||||
assert id(new_view) != id(view)
|
||||
|
||||
# But once we have the new view once, it should be cached.
|
||||
assert id(new_view) == id(state.view)
|
||||
|
||||
|
||||
def test_state_view_cache_not_serialized():
|
||||
"""Test that the fields used to cache view construction are not serialized when state is saved."""
|
||||
state = State()
|
||||
state.history = [example_event(i) for i in range(5)]
|
||||
|
||||
# Build the view once.
|
||||
view = state.view
|
||||
|
||||
# Serialize the state.
|
||||
store = InMemoryFileStore()
|
||||
state.save_to_session('test_sid', store, None)
|
||||
restored_state = State.restore_from_session('test_sid', store, None)
|
||||
|
||||
# The state usually has the history rebuilt from the event stream -- we'll
|
||||
# simulate this by manually setting the state history to the same events.
|
||||
restored_state.history = state.history
|
||||
|
||||
restored_view = restored_state.view
|
||||
|
||||
# Since serialization doesn't include the view cache, the restored view will
|
||||
# be structurally identical but _not_ the same object.
|
||||
assert id(restored_view) != id(view)
|
||||
assert restored_view.events == view.events
|
||||
@ -1,244 +0,0 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
stream = MagicMock()
|
||||
# Mock get_events to return an empty list by default
|
||||
stream.get_events.return_value = []
|
||||
# Mock get_latest_event_id to return a valid integer
|
||||
stream.get_latest_event_id.return_value = 0
|
||||
return stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
agent = MagicMock()
|
||||
agent.llm = MagicMock()
|
||||
|
||||
# Create a step function that returns an action without an ID
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 2
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
|
||||
obs1._id = 3
|
||||
obs1._cause = 2
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 4
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
|
||||
obs2._id = 5
|
||||
obs2._cause = 4
|
||||
|
||||
events = [first_msg, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
len(truncated) >= 3
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
pairs = 10
|
||||
history_len = 1 + 2 * pairs
|
||||
events = [first_msg]
|
||||
for i in range(pairs):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# patch events to history for testing purpose
|
||||
controller.state.history = events
|
||||
|
||||
# Update mock event stream
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
assert len(controller.state.history) == history_len
|
||||
|
||||
# Force apply truncation
|
||||
controller._handle_long_context_error()
|
||||
|
||||
# Check that the history has been truncated before closing the controller
|
||||
assert len(controller.state.history) == 13 < history_len
|
||||
|
||||
# Check that after properly closing the controller, history is recovered
|
||||
asyncio.run(controller.close())
|
||||
assert len(controller.event_stream.get_events()) == history_len
|
||||
assert len(controller.state.history) == history_len
|
||||
assert len(controller.get_trajectory()) == history_len
|
||||
|
||||
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Setup initial history with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
# Update mock event stream to include new messages
|
||||
mock_event_stream.get_events.return_value = [
|
||||
first_msg,
|
||||
agent_msg,
|
||||
user_response,
|
||||
cmd1,
|
||||
obs1,
|
||||
]
|
||||
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
|
||||
original_history_len = len(controller.state.history)
|
||||
|
||||
# Simulate ContextWindowExceededError and truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Verify truncation occurred
|
||||
assert len(controller.state.history) < original_history_len
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
assert controller.state.truncation_id > controller.state.start_id
|
||||
|
||||
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_truncation_id = controller.state.truncation_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.truncation_id = saved_truncation_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
Loading…
x
Reference in New Issue
Block a user