fix: Context window truncation using CondensationAction (#7578)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
This commit is contained in:
Calvin Smith
2025-03-31 13:47:00 -06:00
committed by GitHub
parent 648c8ffb21
commit abaf0da9fe
17 changed files with 412 additions and 445 deletions

View File

@@ -1,3 +0,0 @@
from openhands.memory.condenser import Condenser
__all__ = ['Condenser']

View File

@@ -2,15 +2,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, overload
from typing import Any
from pydantic import BaseModel
from openhands.controller.state.state import State
from openhands.core.config.condenser_config import CondenserConfig
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.memory.view import View
CONDENSER_METADATA_KEY = 'condenser_meta'
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
@@ -34,69 +33,6 @@ CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {}
"""Registry of condenser configurations to their corresponding condenser classes."""
class View(BaseModel):
"""Linearly ordered view of events.
Produced by a condenser to indicate the included events are ready to process as LLM input.
"""
events: list[Event]
def __len__(self) -> int:
return len(self.events)
def __iter__(self):
return iter(self.events)
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
# The only challenge with that is switching the return type based on the input type -- we
# can mark the different signatures for MyPy with `@overload` decorators.
@overload
def __getitem__(self, key: slice) -> list[Event]: ...
@overload
def __getitem__(self, key: int) -> Event: ...
def __getitem__(self, key: int | slice) -> Event | list[Event]:
if isinstance(key, slice):
start, stop, step = key.indices(len(self))
return [self[i] for i in range(start, stop, step)]
elif isinstance(key, int):
return self.events[key]
else:
raise ValueError(f'Invalid key type: {type(key)}')
@staticmethod
def from_events(events: list[Event]) -> View:
"""Create a view from a list of events, respecting the semantics of any condensation events."""
forgotten_event_ids: set[int] = set()
for event in events:
if isinstance(event, CondensationAction):
forgotten_event_ids.update(event.forgotten)
kept_events = [event for event in events if event.id not in forgotten_event_ids]
# If we have a summary, insert it at the specified offset.
summary: str | None = None
summary_offset: int | None = None
# The relevant summary is always in the last condensation event (i.e., the most recent one).
for event in reversed(events):
if isinstance(event, CondensationAction):
if event.summary is not None and event.summary_offset is not None:
summary = event.summary
summary_offset = event.summary_offset
break
if summary is not None and summary_offset is not None:
kept_events.insert(
summary_offset, AgentCondensationObservation(content=summary)
)
return View(events=kept_events)
class Condensation(BaseModel):
"""Produced by a condenser to indicate the history has been condensed."""
@@ -150,13 +86,13 @@ class Condenser(ABC):
self.write_metadata(state)
@abstractmethod
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, View) -> View | Condensation:
"""Condense a sequence of events into a potentially smaller list.
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
Args:
events: A list of events representing the entire history of the agent.
View: A view of the history containing all events that should be condensed.
Returns:
View | Condensation: A condensed view of the events or an event indicating the history has been condensed.
@@ -165,7 +101,7 @@ class Condenser(ABC):
def condensed_history(self, state: State) -> View | Condensation:
"""Condense the state's history."""
with self.metadata_batch(state):
return self.condense(state.history)
return self.condense(state.view)
@classmethod
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
@@ -221,10 +157,7 @@ class RollingCondenser(Condenser, ABC):
def get_condensation(self, view: View) -> Condensation:
"""Get the condensation from a view."""
def condense(self, events: list[Event]) -> View | Condensation:
# Convert the state to a view. This might require some condenser-specific logic.
view = View.from_events(events)
def condense(self, view: View) -> View | Condensation:
# If we trigger the condenser-specific condensation threshold, compute and return
# the condensation.
if self.should_condense(view):

View File

@@ -17,11 +17,11 @@ class BrowserOutputCondenser(Condenser):
self.attention_window = attention_window
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Replace the content of browser observations outside of the attention window with a placeholder."""
results: list[Event] = []
cnt: int = 0
for event in reversed(events):
for event in reversed(view):
if (
isinstance(event, BrowserOutputObservation)
and cnt >= self.attention_window

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
from openhands.core.config.condenser_config import NoOpCondenserConfig
from openhands.events.event import Event
from openhands.memory.condenser.condenser import Condensation, Condenser, View
class NoOpCondenser(Condenser):
"""A condenser that does nothing to the event sequence."""
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Returns the list of events unchanged."""
return View(events=events)
return view
@classmethod
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:

View File

@@ -15,14 +15,11 @@ class ObservationMaskingCondenser(Condenser):
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Replace the content of observations outside of the attention window with a placeholder."""
results: list[Event] = []
for i, event in enumerate(events):
if (
isinstance(event, Observation)
and i < len(events) - self.attention_window
):
for i, event in enumerate(view):
if isinstance(event, Observation) and i < len(view) - self.attention_window:
results.append(AgentCondensationObservation('<MASKED>'))
else:
results.append(event)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
from openhands.events.event import Event
from openhands.memory.condenser.condenser import Condensation, Condenser, View
@@ -14,11 +13,11 @@ class RecentEventsCondenser(Condenser):
super().__init__()
def condense(self, events: list[Event]) -> View | Condensation:
def condense(self, view: View) -> View | Condensation:
"""Keep only the most recent events (up to `max_events`)."""
head = events[: self.keep_first]
head = view[: self.keep_first]
tail_length = max(0, self.max_events - len(head))
tail = events[-tail_length:]
tail = view[-tail_length:]
return View(events=head + tail)
@classmethod

72
openhands/memory/view.py Normal file
View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import overload
from pydantic import BaseModel
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
class View(BaseModel):
"""Linearly ordered view of events.
Produced by a condenser to indicate the included events are ready to process as LLM input.
"""
events: list[Event]
def __len__(self) -> int:
return len(self.events)
def __iter__(self):
return iter(self.events)
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
# The only challenge with that is switching the return type based on the input type -- we
# can mark the different signatures for MyPy with `@overload` decorators.
@overload
def __getitem__(self, key: slice) -> list[Event]: ...
@overload
def __getitem__(self, key: int) -> Event: ...
def __getitem__(self, key: int | slice) -> Event | list[Event]:
if isinstance(key, slice):
start, stop, step = key.indices(len(self))
return [self[i] for i in range(start, stop, step)]
elif isinstance(key, int):
return self.events[key]
else:
raise ValueError(f'Invalid key type: {type(key)}')
@staticmethod
def from_events(events: list[Event]) -> View:
"""Create a view from a list of events, respecting the semantics of any condensation events."""
forgotten_event_ids: set[int] = set()
for event in events:
if isinstance(event, CondensationAction):
forgotten_event_ids.update(event.forgotten)
kept_events = [event for event in events if event.id not in forgotten_event_ids]
# If we have a summary, insert it at the specified offset.
summary: str | None = None
summary_offset: int | None = None
# The relevant summary is always in the last condensation event (i.e., the most recent one).
for event in reversed(events):
if isinstance(event, CondensationAction):
if event.summary is not None and event.summary_offset is not None:
summary = event.summary
summary_offset = event.summary_offset
break
if summary is not None and summary_offset is not None:
kept_events.insert(
summary_offset, AgentCondensationObservation(content=summary)
)
return View(events=kept_events)