From 42712a44d801f3bcf07747a8d16cff945b94f7f1 Mon Sep 17 00:00:00 2001 From: Calvin Smith Date: Thu, 27 Mar 2025 13:16:31 -0600 Subject: [PATCH] (fix): Condensation events to reconstruct contexts added to event stream (#7353) Co-authored-by: Calvin Smith --- .../agenthub/codeact_agent/codeact_agent.py | 33 +- openhands/controller/agent_controller.py | 4 +- openhands/core/schema/action.py | 3 + openhands/events/action/agent.py | 84 ++++ openhands/events/serialization/action.py | 2 + openhands/memory/condenser/__init__.py | 15 +- openhands/memory/condenser/condenser.py | 133 +++-- .../impl/amortized_forgetting_condenser.py | 31 +- .../impl/browser_output_condenser.py | 6 +- .../condenser/impl/llm_attention_condenser.py | 44 +- .../impl/llm_summarizing_condenser.py | 40 +- .../memory/condenser/impl/no_op_condenser.py | 6 +- .../impl/observation_masking_condenser.py | 6 +- .../condenser/impl/recent_events_condenser.py | 6 +- tests/unit/test_codeact_agent.py | 8 +- tests/unit/test_condenser.py | 476 +++++++----------- tests/unit/test_prompt_caching.py | 10 +- 17 files changed, 485 insertions(+), 422 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 2ebc00aafb..03aa113877 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -11,8 +11,10 @@ from openhands.events.action import ( Action, AgentFinishAction, ) +from openhands.events.event import Event from openhands.llm.llm import LLM from openhands.memory.condenser import Condenser +from openhands.memory.condenser.condenser import Condensation, View from openhands.memory.conversation_memory import ConversationMemory from openhands.runtime.plugins import ( AgentSkillsRequirement, @@ -92,6 +94,7 @@ class CodeActAgent(Agent): def step(self, state: State) -> Action: """Performs one step using the CodeAct Agent. + This includes gathering info on previous steps and prompting the model to make a command to execute. Parameters: @@ -113,8 +116,23 @@ class CodeActAgent(Agent): if latest_user_message and latest_user_message.content.strip() == '/exit': return AgentFinishAction() - # prepare what we want to send to the LLM - messages = self._get_messages(state) + # Condense the events from the state. If we get a view we'll pass those + # to the conversation manager for processing, but if we get a condensation + # event we'll just return that instead of an action. The controller will + # immediately ask the agent to step again with the new view. + condensed_history: list[Event] = [] + match self.condenser.condensed_history(state): + case View(events=events): + condensed_history = events + + case Condensation(action=condensation_action): + return condensation_action + + logger.debug( + f'Processing {len(condensed_history)} events from a total of {len(state.history)} events' + ) + + messages = self._get_messages(condensed_history) params: dict = { 'messages': self.llm.format_messages_for_llm(messages), } @@ -127,7 +145,7 @@ class CodeActAgent(Agent): self.pending_actions.append(action) return self.pending_actions.popleft() - def _get_messages(self, state: State) -> list[Message]: + def _get_messages(self, events: list[Event]) -> list[Message]: """Constructs the message history for the LLM conversation. This method builds a structured conversation history by processing events from the state @@ -143,7 +161,7 @@ class CodeActAgent(Agent): 6. 
Adds environment reminders for non-function-calling mode Args: - state (State): The current state object containing conversation history and other metadata + events: The list of events to convert to messages Returns: list[Message]: A list of formatted messages ready for LLM consumption, including: @@ -167,13 +185,6 @@ class CodeActAgent(Agent): with_caching=self.llm.is_caching_prompt_active() ) - # Condense the events from the state. - events = self.condenser.condensed_history(state) - - logger.debug( - f'Processing {len(events)} events from a total of {len(state.history)} events' - ) - # Use ConversationMemory to process events messages = self.conversation_memory.process_events( condensed_history=events, diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 43c6250bdb..07ad189661 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -54,7 +54,7 @@ from openhands.events.action import ( MessageAction, NullAction, ) -from openhands.events.action.agent import RecallAction +from openhands.events.action.agent import CondensationAction, RecallAction from openhands.events.event import Event from openhands.events.observation import ( AgentCondensationObservation, @@ -305,6 +305,8 @@ class AgentController: return True if isinstance(event, AgentDelegateAction): return True + if isinstance(event, CondensationAction): + return True return False if isinstance(event, Observation): if ( diff --git a/openhands/core/schema/action.py b/openhands/core/schema/action.py index ccd1c14ae1..9e24bea542 100644 --- a/openhands/core/schema/action.py +++ b/openhands/core/schema/action.py @@ -80,3 +80,6 @@ class ActionType(str, Enum): RECALL = 'recall' """Retrieves content from a user workspace, microagent, or other source.""" + + CONDENSATION = 'condensation' + """Condenses a list of events into a summary.""" diff --git a/openhands/events/action/agent.py b/openhands/events/action/agent.py index 958ad7d7d7..e7ba216628 100644 --- a/openhands/events/action/agent.py +++ b/openhands/events/action/agent.py @@ -111,3 +111,87 @@ class RecallAction(Action): ret = '**RecallAction**\n' ret += f'QUERY: {self.query[:50]}' return ret + + +@dataclass +class CondensationAction(Action): + """This action indicates a condensation of the conversation history is happening. + + There are two ways to specify the events to be forgotten: + 1. By providing a list of event IDs. + 2. By providing the start and end IDs of a range of events. + + In the second case, we assume that event IDs are monotonically increasing, and that _all_ events between the start and end IDs are to be forgotten. + + Raises: + ValueError: If the optional fields are not instantiated in a valid configuration. 
+    """

+    action: str = ActionType.CONDENSATION
+
+    forgotten_event_ids: list[int] | None = None
+    """The IDs of the events that are being forgotten (removed from the `View` given to the LLM)."""
+
+    forgotten_events_start_id: int | None = None
+    """The ID of the first event to be forgotten in a range of events."""
+
+    forgotten_events_end_id: int | None = None
+    """The ID of the last event to be forgotten in a range of events."""
+
+    summary: str | None = None
+    """An optional summary of the events being forgotten."""
+
+    summary_offset: int | None = None
+    """An optional offset to the start of the resulting view indicating where the summary should be inserted."""
+
+    def _validate_field_polymorphism(self) -> bool:
+        """Check if the optional fields are instantiated in a valid configuration."""
+        # For the forgotten events, there are only two valid configurations:
+        # 1. We're forgetting events based on the list of provided IDs, or
+        using_event_ids = self.forgotten_event_ids is not None
+        # 2. We're forgetting events based on the range of IDs.
+        using_event_range = (
+            self.forgotten_events_start_id is not None
+            and self.forgotten_events_end_id is not None
+        )
+
+        # Either way, we can only have one of the two valid configurations.
+        forgotten_event_configuration = using_event_ids ^ using_event_range
+
+        # We also need to check that if the summary is provided, so is the
+        # offset (and vice versa).
+        summary_configuration = (
+            self.summary is None and self.summary_offset is None
+        ) or (self.summary is not None and self.summary_offset is not None)
+
+        return forgotten_event_configuration and summary_configuration
+
+    def __post_init__(self):
+        if not self._validate_field_polymorphism():
+            raise ValueError('Invalid configuration of the optional fields.')
+
+    @property
+    def forgotten(self) -> list[int]:
+        """The list of event IDs that should be forgotten."""
+        # Start by making sure the fields are instantiated in a valid
+        # configuration. We check this whenever the event is initialized, but we
+        # can't make the dataclass immutable so we need to check it again here
+        # to make sure the configuration is still valid.
+        if not self._validate_field_polymorphism():
+            raise ValueError('Invalid configuration of the optional fields.')
+
+        if self.forgotten_event_ids is not None:
+            return self.forgotten_event_ids
+
+        # If we've gotten this far, the start/end IDs are not None.
+        assert self.forgotten_events_start_id is not None
+        assert self.forgotten_events_end_id is not None
+        return list(
+            range(self.forgotten_events_start_id, self.forgotten_events_end_id + 1)
+        )
+
+    @property
+    def message(self) -> str:
+        if self.summary:
+            return f'Summary: {self.summary}'
+        return f'Condenser is dropping the events: {self.forgotten}.'
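The two configurations described in the docstring above can be exercised directly; the `forgotten` property normalizes both into a flat list of event IDs, which `View.from_events` (introduced later in this patch) filters out of the reconstructed view. A minimal sketch of the validation behavior (the event IDs and summary text here are illustrative, not taken from the PR)::

    from openhands.events.action.agent import CondensationAction

    # Range form: forget events 5 through 9 (inclusive) and splice a summary
    # into the reconstructed view at offset 1.
    by_range = CondensationAction(
        forgotten_events_start_id=5,
        forgotten_events_end_id=9,
        summary='Ran the test suite and fixed two failures.',
        summary_offset=1,
    )
    assert by_range.forgotten == [5, 6, 7, 8, 9]

    # Explicit-ID form: forget a non-contiguous set of events, no summary.
    by_ids = CondensationAction(forgotten_event_ids=[2, 4, 7])
    assert by_ids.forgotten == [2, 4, 7]

    # Supplying a summary without an offset (or both ID forms at once) fails
    # the check in __post_init__.
    try:
        CondensationAction(forgotten_event_ids=[1], summary='no offset given')
    except ValueError:
        pass  # invalid configuration of the optional fields

This mirrors `_validate_field_polymorphism`: exactly one of the ID-list or ID-range forms must be set, and `summary`/`summary_offset` must be provided together.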
diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py index 5c314b0f80..9e6d366cb6 100644 --- a/openhands/events/serialization/action.py +++ b/openhands/events/serialization/action.py @@ -8,6 +8,7 @@ from openhands.events.action.agent import ( AgentRejectAction, AgentThinkAction, ChangeAgentStateAction, + CondensationAction, RecallAction, ) from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction @@ -39,6 +40,7 @@ actions = ( RecallAction, ChangeAgentStateAction, MessageAction, + CondensationAction, ) ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined] diff --git a/openhands/memory/condenser/__init__.py b/openhands/memory/condenser/__init__.py index b7bd0244ef..8c20a2b257 100644 --- a/openhands/memory/condenser/__init__.py +++ b/openhands/memory/condenser/__init__.py @@ -1,4 +1,15 @@ import openhands.memory.condenser.impl # noqa F401 (we import this to get the condensers registered) -from openhands.memory.condenser.condenser import Condenser, get_condensation_metadata +from openhands.memory.condenser.condenser import ( + Condenser, + get_condensation_metadata, + View, + Condensation, +) -__all__ = ['Condenser', 'get_condensation_metadata', 'CONDENSER_REGISTRY'] +__all__ = [ + 'Condenser', + 'get_condensation_metadata', + 'CONDENSER_REGISTRY', + 'View', + 'Condensation', +] diff --git a/openhands/memory/condenser/condenser.py b/openhands/memory/condenser/condenser.py index 411ed39386..e2b8213d76 100644 --- a/openhands/memory/condenser/condenser.py +++ b/openhands/memory/condenser/condenser.py @@ -2,13 +2,15 @@ from __future__ import annotations from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any +from typing import Any, overload -from typing_extensions import override +from pydantic import BaseModel from openhands.controller.state.state import State from openhands.core.config.condenser_config import CondenserConfig +from openhands.events.action.agent import CondensationAction from openhands.events.event import Event +from openhands.events.observation.agent import AgentCondensationObservation CONDENSER_METADATA_KEY = 'condenser_meta' """Key identifying where metadata is stored in a `State` object's `extra_data` field.""" @@ -32,6 +34,75 @@ CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {} """Registry of condenser configurations to their corresponding condenser classes.""" +class View(BaseModel): + """Linearly ordered view of events. + + Produced by a condenser to indicate the included events are ready to process as LLM input. + """ + + events: list[Event] + + def __len__(self) -> int: + return len(self.events) + + def __iter__(self): + return iter(self.events) + + # To preserve list-like indexing, we ideally support slicing and position-based indexing. + # The only challenge with that is switching the return type based on the input type -- we + # can mark the different signatures for MyPy with `@overload` decorators. + + @overload + def __getitem__(self, key: slice) -> list[Event]: ... + + @overload + def __getitem__(self, key: int) -> Event: ... 
+ + def __getitem__(self, key: int | slice) -> Event | list[Event]: + if isinstance(key, slice): + start, stop, step = key.indices(len(self)) + return [self[i] for i in range(start, stop, step)] + elif isinstance(key, int): + return self.events[key] + else: + raise ValueError(f'Invalid key type: {type(key)}') + + @staticmethod + def from_events(events: list[Event]) -> View: + """Create a view from a list of events, respecting the semantics of any condensation events.""" + forgotten_event_ids: set[int] = set() + for event in events: + if isinstance(event, CondensationAction): + forgotten_event_ids.update(event.forgotten) + + kept_events = [event for event in events if event.id not in forgotten_event_ids] + + # If we have a summary, insert it at the specified offset. + summary: str | None = None + summary_offset: int | None = None + + # The relevant summary is always in the last condensation event (i.e., the most recent one). + for event in reversed(events): + if isinstance(event, CondensationAction): + if event.summary is not None and event.summary_offset is not None: + summary = event.summary + summary_offset = event.summary_offset + break + + if summary is not None and summary_offset is not None: + kept_events.insert( + summary_offset, AgentCondensationObservation(content=summary) + ) + + return View(events=kept_events) + + +class Condensation(BaseModel): + """Produced by a condenser to indicate the history has been condensed.""" + + action: CondensationAction + + class Condenser(ABC): """Abstract condenser interface. @@ -39,10 +110,7 @@ class Condenser(ABC): Agents can use condensers to reduce the amount of events they need to consider when deciding which action to take. To use a condenser, agents can call the `condensed_history` method on the current `State` being considered and use the results instead of the full history. - Example usage:: - - condenser = Condenser.from_config(condenser_config) - events = condenser.condensed_history(state) + If the condenser returns a `Condensation` instead of a `View`, the agent should return `Condensation.action` instead of producing its own action. On the next agent step the condenser will use that condensation event to produce a new `View`. """ def __init__(self): @@ -82,7 +150,7 @@ class Condenser(ABC): self.write_metadata(state) @abstractmethod - def condense(self, events: list[Event]) -> list[Event]: + def condense(self, events: list[Event]) -> View | Condensation: """Condense a sequence of events into a potentially smaller list. New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information. @@ -91,10 +159,10 @@ class Condenser(ABC): events: A list of events representing the entire history of the agent. Returns: - list[Event]: An event sequence representing a condensed history of the agent. + View | Condensation: A condensed view of the events or an event indicating the history has been condensed. """ - def condensed_history(self, state: State) -> list[Event]: + def condensed_history(self, state: State) -> View | Condensation: """Condense the state's history.""" with self.metadata_batch(state): return self.condense(state.history) @@ -140,39 +208,28 @@ class Condenser(ABC): class RollingCondenser(Condenser, ABC): """Base class for a specialized condenser strategy that applies condensation to a rolling history. - The rolling history is computed by appending new events to the most recent condensation. 
For example, the sequence of calls::
+    The rolling history is generated by `View.from_events`, which analyzes all events in the history and produces a `View` object representing what will be sent to the LLM.
-
-        assert state.history == [event1, event2, event3]
-        condensation = condenser.condensed_history(state)
-
-        # ...new events are added to the state...
-
-        assert state.history == [event1, event2, event3, event4, event5]
-        condenser.condensed_history(state)
-
-    will result in second call to `condensed_history` passing `condensation + [event4, event5]` to the `condense` method.
+
+    If `should_condense` indicates that condensation is needed, the condenser is responsible for generating a `Condensation` from the `View`. The resulting condensation event is added to the event history; passing that history back through `View.from_events` then produces the condensed `View` sent to the LLM.
     """

-    def __init__(self) -> None:
-        self._condensation: list[Event] = []
-        self._last_history_length: int = 0
+    @abstractmethod
+    def should_condense(self, view: View) -> bool:
+        """Determine if a view should be condensed."""

-        super().__init__()
+    @abstractmethod
+    def get_condensation(self, view: View) -> Condensation:
+        """Get the condensation from a view."""

-    @override
-    def condensed_history(self, state: State) -> list[Event]:
-        # The history should grow monotonically -- if it doesn't, something has
-        # truncated the history and we need to reset our tracking.
-        if len(state.history) < self._last_history_length:
-            self._condensation = []
-            self._last_history_length = 0
+    def condense(self, events: list[Event]) -> View | Condensation:
+        # Convert the events to a view. This might require some condenser-specific logic.
+        view = View.from_events(events)

-        new_events = state.history[self._last_history_length :]
+        # If we trigger the condenser-specific condensation threshold, compute and return
+        # the condensation.
+        if self.should_condense(view):
+            return self.get_condensation(view)

-        with self.metadata_batch(state):
-            results = self.condense(self._condensation + new_events)
-
-        self._condensation = results
-        self._last_history_length = len(state.history)
-
-        return results
+        # Otherwise we're safe to just return the view.
+ else: + return view diff --git a/openhands/memory/condenser/impl/amortized_forgetting_condenser.py b/openhands/memory/condenser/impl/amortized_forgetting_condenser.py index ee6e26aa06..f9fd54bd49 100644 --- a/openhands/memory/condenser/impl/amortized_forgetting_condenser.py +++ b/openhands/memory/condenser/impl/amortized_forgetting_condenser.py @@ -1,8 +1,12 @@ from __future__ import annotations from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig -from openhands.events.event import Event -from openhands.memory.condenser.condenser import RollingCondenser +from openhands.events.action.agent import CondensationAction +from openhands.memory.condenser.condenser import ( + Condensation, + RollingCondenser, + View, +) class AmortizedForgettingCondenser(RollingCondenser): @@ -32,18 +36,25 @@ class AmortizedForgettingCondenser(RollingCondenser): super().__init__() - def condense(self, events: list[Event]) -> list[Event]: - """Apply the amortized forgetting strategy to the given list of events.""" - if len(events) <= self.max_size: - return events - + def get_condensation(self, view: View) -> Condensation: target_size = self.max_size // 2 - head = events[: self.keep_first] + head = view[: self.keep_first] events_from_tail = target_size - len(head) - tail = events[-events_from_tail:] + tail = view[-events_from_tail:] - return head + tail + event_ids_to_keep = {event.id for event in head + tail} + event_ids_to_forget = {event.id for event in view} - event_ids_to_keep + + event = CondensationAction( + forgotten_events_start_id=min(event_ids_to_forget), + forgotten_events_end_id=max(event_ids_to_forget), + ) + + return Condensation(action=event) + + def should_condense(self, view: View) -> bool: + return len(view) > self.max_size @classmethod def from_config( diff --git a/openhands/memory/condenser/impl/browser_output_condenser.py b/openhands/memory/condenser/impl/browser_output_condenser.py index b21a6c8f46..b1b445aa39 100644 --- a/openhands/memory/condenser/impl/browser_output_condenser.py +++ b/openhands/memory/condenser/impl/browser_output_condenser.py @@ -4,7 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig from openhands.events.event import Event from openhands.events.observation import BrowserOutputObservation from openhands.events.observation.agent import AgentCondensationObservation -from openhands.memory.condenser.condenser import Condenser +from openhands.memory.condenser.condenser import Condensation, Condenser, View class BrowserOutputCondenser(Condenser): @@ -17,7 +17,7 @@ class BrowserOutputCondenser(Condenser): self.attention_window = attention_window super().__init__() - def condense(self, events: list[Event]) -> list[Event]: + def condense(self, events: list[Event]) -> View | Condensation: """Replace the content of browser observations outside of the attention window with a placeholder.""" results: list[Event] = [] cnt: int = 0 @@ -36,7 +36,7 @@ class BrowserOutputCondenser(Condenser): if isinstance(event, BrowserOutputObservation): cnt += 1 - return list(reversed(results)) + return View(events=list(reversed(results))) @classmethod def from_config( diff --git a/openhands/memory/condenser/impl/llm_attention_condenser.py b/openhands/memory/condenser/impl/llm_attention_condenser.py index 9a638c071d..8e869a3a08 100644 --- a/openhands/memory/condenser/impl/llm_attention_condenser.py +++ b/openhands/memory/condenser/impl/llm_attention_condenser.py @@ -4,9 +4,13 @@ from litellm import supports_response_schema from 
pydantic import BaseModel from openhands.core.config.condenser_config import LLMAttentionCondenserConfig -from openhands.events.event import Event +from openhands.events.action.agent import CondensationAction from openhands.llm.llm import LLM -from openhands.memory.condenser.condenser import RollingCondenser +from openhands.memory.condenser.condenser import ( + Condensation, + RollingCondenser, + View, +) class ImportantEventSelection(BaseModel): @@ -43,15 +47,11 @@ class LLMAttentionCondenser(RollingCondenser): super().__init__() - def condense(self, events: list[Event]) -> list[Event]: - """If the history is too long, use an LLM to select the most important events.""" - if len(events) <= self.max_size: - return events - + def get_condensation(self, view: View) -> Condensation: target_size = self.max_size // 2 - head = events[: self.keep_first] + head_event_ids = [event.id for event in view.events[: self.keep_first]] - events_from_tail = target_size - len(head) + events_from_tail = target_size - len(head_event_ids) message: str = """You will be given a list of actions, observations, and thoughts from a coding agent. Each item in the list has an identifier. Please sort the identifiers in order of how important the @@ -66,7 +66,7 @@ class LLMAttentionCondenser(RollingCondenser): 'content': f'{e.id}\n{e.message}', 'role': 'user', } - for e in events + for e in view ], ], response_format={ @@ -82,27 +82,35 @@ class LLMAttentionCondenser(RollingCondenser): response.choices[0].message.content ).ids - self.add_metadata('all_event_ids', [event.id for event in events]) - self.add_metadata('response_ids', response_ids) self.add_metadata('metrics', self.llm.metrics.get()) # Filter out any IDs from the head and trim the results down - head_ids = [event.id for event in head] response_ids = [ - response_id for response_id in response_ids if response_id not in head_ids + response_id + for response_id in response_ids + if response_id not in head_event_ids ][:events_from_tail] # If the response IDs aren't _long_ enough, iterate backwards through the events and add any unfound IDs to the list. - for event in reversed(events): + for event in reversed(view): if len(response_ids) >= events_from_tail: break if event.id not in response_ids: response_ids.append(event.id) - # Grab the events associated with the response IDs - tail = [event for event in events if event.id in response_ids] + # Now that we've found the right number of events to keep, convert this into a list of events to forget. 
+ event = CondensationAction( + forgotten_event_ids=[ + event.id + for event in view + if event.id not in response_ids and event.id not in head_event_ids + ], + ) - return head + tail + return Condensation(action=event) + + def should_condense(self, view: View) -> bool: + return len(view) > self.max_size @classmethod def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser: diff --git a/openhands/memory/condenser/impl/llm_summarizing_condenser.py b/openhands/memory/condenser/impl/llm_summarizing_condenser.py index 4edd3247c2..f3ad5791ee 100644 --- a/openhands/memory/condenser/impl/llm_summarizing_condenser.py +++ b/openhands/memory/condenser/impl/llm_summarizing_condenser.py @@ -2,10 +2,14 @@ from __future__ import annotations from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig from openhands.core.message import Message, TextContent -from openhands.events.event import Event +from openhands.events.action.agent import CondensationAction from openhands.events.observation.agent import AgentCondensationObservation from openhands.llm import LLM -from openhands.memory.condenser.condenser import RollingCondenser +from openhands.memory.condenser.condenser import ( + Condensation, + RollingCondenser, + View, +) class LLMSummarizingCondenser(RollingCondenser): @@ -32,26 +36,22 @@ class LLMSummarizingCondenser(RollingCondenser): super().__init__() - def condense(self, events: list[Event]) -> list[Event]: - """Apply the amortized forgetting strategy with LLM summarization to the given list of events.""" - if len(events) <= self.max_size: - return events - - head = events[: self.keep_first] - + def get_condensation(self, view: View) -> Condensation: + head = view[: self.keep_first] target_size = self.max_size // 2 - events_from_tail = target_size - len(head) - tail = events[-events_from_tail:] + # Number of events to keep from the tail -- target size, minus however many + # prefix events from the head, minus one for the summarization event + events_from_tail = target_size - len(head) - 1 summary_event = ( - events[self.keep_first] - if isinstance(events[self.keep_first], AgentCondensationObservation) + view[self.keep_first] + if isinstance(view[self.keep_first], AgentCondensationObservation) else AgentCondensationObservation('No events summarized') ) # Identify events to be forgotten (those not in head or tail) forgotten_events = [] - for event in events[self.keep_first : -events_from_tail]: + for event in view[self.keep_first : -events_from_tail]: if not isinstance(event, AgentCondensationObservation): forgotten_events.append(event) @@ -101,7 +101,17 @@ INTENT: Fix precision while maintaining FITS compliance""" self.add_metadata('response', response.model_dump()) self.add_metadata('metrics', self.llm.metrics.get()) - return head + [AgentCondensationObservation(summary)] + tail + return Condensation( + action=CondensationAction( + forgotten_events_start_id=min(event.id for event in forgotten_events), + forgotten_events_end_id=max(event.id for event in forgotten_events), + summary=summary, + summary_offset=self.keep_first, + ) + ) + + def should_condense(self, view: View) -> bool: + return len(view) > self.max_size @classmethod def from_config( diff --git a/openhands/memory/condenser/impl/no_op_condenser.py b/openhands/memory/condenser/impl/no_op_condenser.py index 0eb73a4386..6ad304d728 100644 --- a/openhands/memory/condenser/impl/no_op_condenser.py +++ b/openhands/memory/condenser/impl/no_op_condenser.py @@ -2,15 +2,15 @@ from __future__ import 
annotations from openhands.core.config.condenser_config import NoOpCondenserConfig from openhands.events.event import Event -from openhands.memory.condenser.condenser import Condenser +from openhands.memory.condenser.condenser import Condensation, Condenser, View class NoOpCondenser(Condenser): """A condenser that does nothing to the event sequence.""" - def condense(self, events: list[Event]) -> list[Event]: + def condense(self, events: list[Event]) -> View | Condensation: """Returns the list of events unchanged.""" - return events + return View(events=events) @classmethod def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser: diff --git a/openhands/memory/condenser/impl/observation_masking_condenser.py b/openhands/memory/condenser/impl/observation_masking_condenser.py index 17780a224f..f3b9bb652c 100644 --- a/openhands/memory/condenser/impl/observation_masking_condenser.py +++ b/openhands/memory/condenser/impl/observation_masking_condenser.py @@ -4,7 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserCo from openhands.events.event import Event from openhands.events.observation import Observation from openhands.events.observation.agent import AgentCondensationObservation -from openhands.memory.condenser.condenser import Condenser +from openhands.memory.condenser.condenser import Condensation, Condenser, View class ObservationMaskingCondenser(Condenser): @@ -15,7 +15,7 @@ class ObservationMaskingCondenser(Condenser): super().__init__() - def condense(self, events: list[Event]) -> list[Event]: + def condense(self, events: list[Event]) -> View | Condensation: """Replace the content of observations outside of the attention window with a placeholder.""" results: list[Event] = [] for i, event in enumerate(events): @@ -27,7 +27,7 @@ class ObservationMaskingCondenser(Condenser): else: results.append(event) - return results + return View(events=results) @classmethod def from_config( diff --git a/openhands/memory/condenser/impl/recent_events_condenser.py b/openhands/memory/condenser/impl/recent_events_condenser.py index a779048363..5c63b45451 100644 --- a/openhands/memory/condenser/impl/recent_events_condenser.py +++ b/openhands/memory/condenser/impl/recent_events_condenser.py @@ -2,7 +2,7 @@ from __future__ import annotations from openhands.core.config.condenser_config import RecentEventsCondenserConfig from openhands.events.event import Event -from openhands.memory.condenser.condenser import Condenser +from openhands.memory.condenser.condenser import Condensation, Condenser, View class RecentEventsCondenser(Condenser): @@ -14,12 +14,12 @@ class RecentEventsCondenser(Condenser): super().__init__() - def condense(self, events: list[Event]) -> list[Event]: + def condense(self, events: list[Event]) -> View | Condensation: """Keep only the most recent events (up to `max_events`).""" head = events[: self.keep_first] tail_length = max(0, self.max_events - len(head)) tail = events[-tail_length:] - return head + tail + return View(events=head + tail) @classmethod def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser: diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 4b81d8babb..01ab5adbdd 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -324,21 +324,21 @@ def test_mismatched_tool_call_events(mock_state: State): # 2. The action message, and # 3. 
The observation message
     mock_state.history = [action, observation]
-    messages = agent._get_messages(mock_state)
+    messages = agent._get_messages(mock_state.history)
     assert len(messages) == 3

     # The same should hold if the events are presented out-of-order
     mock_state.history = [observation, action]
-    messages = agent._get_messages(mock_state)
+    messages = agent._get_messages(mock_state.history)
     assert len(messages) == 3

     # If only one of the two events is present, then we should just get the system message
     mock_state.history = [action]
-    messages = agent._get_messages(mock_state)
+    messages = agent._get_messages(mock_state.history)
     assert len(messages) == 1

     mock_state.history = [observation]
-    messages = agent._get_messages(mock_state)
+    messages = agent._get_messages(mock_state.history)
     assert len(messages) == 1


diff --git a/tests/unit/test_condenser.py b/tests/unit/test_condenser.py
index 17b569fdec..bf484c7b36 100644
--- a/tests/unit/test_condenser.py
+++ b/tests/unit/test_condenser.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Any
+from typing import Any, Callable, Iterable
 from unittest.mock import MagicMock

 import pytest
@@ -22,7 +22,7 @@ from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.events.observation.observation import Observation
 from openhands.llm import LLM
 from openhands.memory.condenser import Condenser
-from openhands.memory.condenser.condenser import RollingCondenser
+from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
 from openhands.memory.condenser.impl import (
     AmortizedForgettingCondenser,
     BrowserOutputCondenser,
@@ -98,6 +98,73 @@ def mock_state() -> State:
     return mock_state


+class RollingCondenserTestHarness:
+    """Test harness for rolling condensers.
+
+    Simulates the behavior of a simple agent loop (appropriately handling the distinction between `View` and `Condensation` results) and provides utilities for testing the results.
+    """
+
+    def __init__(self, condenser: RollingCondenser):
+        self.condenser = condenser
+        self.callbacks: list[Callable[[list[Event]], None]] = []
+
+    def add_callback(self, callback: Callable[[list[Event]], None]):
+        """Add a callback to the test harness.
+
+        This callback will be called on the history after each event is added, but before the condenser is applied. You can use this to export information about the event that was just added, or to set LLM responses based on the state.
+        """
+        self.callbacks.append(callback)
+
+    def views(self, events: Iterable[Event]) -> Iterable[View]:
+        """Generate a sequence of views simulating the condenser's behavior over the given event stream.
+
+        This generator assumes we're starting from an empty history.
+        """
+        mock_state = MagicMock()
+        mock_state.extra_data = {}
+        mock_state.history = []
+
+        for event in events:
+            mock_state.history.append(event)
+            for callback in self.callbacks:
+                callback(mock_state.history)
+
+            match self.condenser.condensed_history(mock_state):
+                case View() as view:
+                    yield view
+
+                case Condensation(action=condensation_event):
+                    mock_state.history.append(condensation_event)
+
+    def expected_size(self, index: int, max_size: int) -> int:
+        """Calculate the expected size of the view at the given index.
+
+        Assumes the condenser triggers condensation when the view is _longer_ than the max size, and that the target size is half the max size.
+        """
+        # Until we hit the max size, the views should grow monotonically.
+ if index < max_size: + return index + 1 + + # Once we hit the max size, the next view should be reduced to the target size. + target_size = max_size // 2 + + # So when the index is the same as the max size, we should have target size + 1 events in the view. + # And the maximum value we will ever see is the max size (approximately 2 * target size). + # Put together, we get the following formula: + return ((index - max_size) % target_size) + target_size + 1 + + def expected_condensations(self, index: int, max_size: int) -> int: + """Calculate the expected number of condensation events at the given index. + + Assumes the condenser triggers condensation when the view is _longer_ than the max size, and that the target size is half the max size. + """ + if index < max_size: + return 0 + + target_size = max_size // 2 + return ((index - max_size) // target_size) + 1 + + def test_noop_condenser_from_config(): """Test that the NoOpCondenser objects can be made from config.""" config = NoOpCondenserConfig() @@ -120,7 +187,7 @@ def test_noop_condenser(): condenser = NoOpCondenser() result = condenser.condensed_history(mock_state) - assert result == events + assert result == View(events=events) def test_observation_masking_condenser_from_config(): @@ -231,7 +298,7 @@ def test_recent_events_condenser(): condenser = RecentEventsCondenser(max_events=len(events)) result = condenser.condensed_history(mock_state) - assert result == events + assert result == View(events=events) # If the max_events are smaller than the number of events, only keep the last few. max_events = 3 @@ -265,7 +332,7 @@ def test_recent_events_condenser(): assert result[2]._message == 'Event 5' # kept from max_events -def test_llm_summarization_condenser_from_config(): +def test_llm_summarizing_condenser_from_config(): """Test that LLMSummarizingCondenser objects can be made from config.""" config = LLMSummarizingCondenserConfig( max_size=50, @@ -284,7 +351,7 @@ def test_llm_summarization_condenser_from_config(): assert condenser.keep_first == 10 -def test_llm_amortized_summarization_condenser_invalid_config(): +def test_llm_summarizing_condenser_invalid_config(): """Test that LLMSummarizingCondenser raises error when keep_first > max_size.""" pytest.raises( ValueError, @@ -297,135 +364,49 @@ def test_llm_amortized_summarization_condenser_invalid_config(): pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1) -def test_llm_summarizing_condenser_grows_to_max_size(mock_llm, mock_state): - """Test that LLMSummarizingCondenser correctly maintains an event context up to max size.""" - max_size = 15 +def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm): + """Test that LLMSummarizingCondenser maintains the correct view size.""" + max_size = 10 condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm) - for i in range(max_size): - event = create_test_event(f'Event {i}') - mock_state.history.append(event) - results = condenser.condensed_history(mock_state) - assert len(results) == i + 1 - - -def test_llm_summarizing_condenser_forgets_and_summarizes(mock_llm, mock_state): - """Test that the LLMSummarizingCondenser forgets events and maintains a summary.""" - max_size = 4 - keep_first = 1 - condenser = LLMSummarizingCondenser( - max_size=max_size, keep_first=keep_first, llm=mock_llm - ) - - # Add initial event - first_event = create_test_event('Event 0') - mock_state.history.append(first_event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] # Set up mock 
LLM response mock_llm.set_mock_response_content('Summary of forgotten events') - # Add enough events to trigger forgetting - for i in range(max_size + 3): # +3 to ensure we're well past max_size - event = create_test_event(f'Event {i+1}') - mock_state.history.append(event) + harness = RollingCondenserTestHarness(condenser) - # Get the condensed history - results = condenser.condensed_history(mock_state) - - # We should have exactly 3 events: - # 1. First event (keep_first = 1) - # 2. Summary event - # 3. Most recent event - assert len(results) == 3, f'Expected 3 events, got {len(results)}: {results}' - assert ( - results[0] == first_event - ), f'First event should be {first_event}, got {results[0]}' - assert isinstance( - results[1], AgentCondensationObservation - ), f'Second event should be a summary, got {results[1]}' - assert ( - results[1].content == 'Summary of forgotten events' - ), f"Summary content should be 'Summary of forgotten events', got {results[1].content}" - assert results[2] == event, f'Last event should be {event}, got {results[2]}' + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) -def test_llm_summarizing_condenser_llm_call(mock_llm, mock_state): - """Test that the LLM is called correctly when forgetting events.""" - max_size = 4 - keep_first = 1 +def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm, mock_state): + """Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events.""" + max_size = 10 + keep_first = 3 condenser = LLMSummarizingCondenser( max_size=max_size, keep_first=keep_first, llm=mock_llm ) - # Add initial event - first_event = create_test_event('Event 0') - mock_state.history.append(first_event) - - # Set up mock LLM response - mock_llm.set_mock_response_content('Summary of forgotten events') - mock_llm.metrics.get.return_value = {'test_metric': 1.0} - - # Add enough events to trigger forgetting - for i in range(max_size): - event = create_test_event(f'Event {i+1}') - mock_state.history.append(event) - condenser.condensed_history(mock_state) - - # Verify LLM was called with correct prompt - mock_llm.completion.assert_called_once() - call_args = mock_llm.completion.call_args[1] - assert 'messages' in call_args - assert len(call_args['messages']) == 1 - - # Verify metrics were added to state - assert 'condenser_meta' in mock_state.extra_data - assert len(mock_state.extra_data['condenser_meta']) == 1 - assert mock_state.extra_data['condenser_meta'][0]['metrics'] == {'test_metric': 1.0} - - -def test_llm_summarizing_condenser_resets_when_given_truncated_history( - mock_llm, mock_state -): - """Test that the condenser, when it sees a shorter history than it has in the past (due to truncation), will reset its tracking.""" - max_size = 4 - keep_first = 1 - condenser = LLMSummarizingCondenser( - max_size=max_size, keep_first=keep_first, llm=mock_llm - ) - - # Add initial event - first_event = create_test_event('Event 0') - mock_state.history.append(first_event) - - # Set up mock LLM response mock_llm.set_mock_response_content('Summary of forgotten events') - # Add enough events to trigger forgetting - for i in range(max_size + 3): # +3 to ensure we're well past max_size - event = create_test_event(f'Event {i+1}') - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + harness = RollingCondenserTestHarness(condenser) - # Get the condensed history - results = 
condenser.condensed_history(mock_state)
-
-    # We should have exactly 3 events:
-    # 1. First event (keep_first = 1)
-    # 2. Summary event
-    # 3. Most recent event
-    assert len(results) == 3, f'Expected 3 events, got {len(results)}: {results}'
+    for i, view in enumerate(harness.views(events)):
+        assert len(view) == harness.expected_size(i, max_size)

-    # Now, call condensation on a small history that contains only two events.
-    alternate_history = [
-        create_test_event('Alt. Event 0'),
-        create_test_event('Alt. Event 1'),
-    ]
-    mock_state.history = alternate_history
+        # Ensure that we've called the summarizing LLM once per condensation
+        assert mock_llm.completion.call_count == harness.expected_condensations(
+            i, max_size
+        )

-    # When we do this, the condenser should start tracking the alternative history
-    # as the de-facto history. That means we lose the summarization event and any
-    # other events that were in the previous history.
-    results = condenser.condensed_history(mock_state)
-    assert results == alternate_history
+        # Ensure that the prefix is appropriately maintained
+        assert view[:keep_first] == events[: min(keep_first, i + 1)]
+
+        # If we've condensed, ensure that the summary event is present
+        if i > max_size:
+            assert isinstance(view[keep_first], AgentCondensationObservation)


 def test_amortized_forgetting_condenser_from_config():
@@ -449,69 +430,46 @@ def test_amortized_forgetting_condenser_invalid_config():
     pytest.raises(ValueError, AmortizedForgettingCondenser, keep_first=-1)


-def test_amortized_forgetting_condenser_grows_to_max_size():
-    """Test that AmortizedForgettingCondenser correctly maintains an event context up to max size."""
-    max_size = 15
+def test_amortized_forgetting_condenser_gives_expected_view_size():
+    """Test that AmortizedForgettingCondenser maintains a context view of the correct size."""
+    max_size = 12
     condenser = AmortizedForgettingCondenser(max_size=max_size)

-    mock_state = MagicMock()
-    mock_state.extra_data = {}
-    mock_state.history = []
+    events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

-    for i in range(max_size):
-        event = create_test_event(f'Event {i}')
-        mock_state.history.append(event)
-        results = condenser.condensed_history(mock_state)
-        assert len(results) == i + 1
+    harness = RollingCondenserTestHarness(condenser)
+
+    for i, view in enumerate(harness.views(events)):
+        assert len(view) == harness.expected_size(i, max_size)


-def test_amortized_forgetting_condenser_forgets_when_larger_than_max_size():
-    """Test that the AmortizedForgettingCondenser forgets events when the context grows too large."""
-    max_size = 2
-    condenser = AmortizedForgettingCondenser(max_size=max_size)
-
-    mock_state = MagicMock()
-    mock_state.extra_data = {}
-    mock_state.history = []
-
-    for i in range(max_size * 10):
-        event = create_test_event(f'Event {i}')
-        mock_state.history.append(event)
-        results = condenser.condensed_history(mock_state)
-
-        # The last event in the results is always the event we just added.
-        assert results[-1] == event
-
-        # The number of results should bounce back and forth between 1, 2, 1, 2, ...
-        assert len(results) == (i % 2) + 1
-
-
-def test_amortized_forgetting_condenser_keeps_first_events():
-    """Test that the AmortizedForgettingCondenser keeps the right number of initial events when forgetting."""
-    max_size = 4
-    keep_first = 1
+def test_amortized_forgetting_condenser_keeps_first_and_last_events():
+    """Test that the AmortizedForgettingCondenser keeps the prefix and suffix events, even when condensing."""
+    max_size = 12
+    keep_first = 4
     condenser = AmortizedForgettingCondenser(max_size=max_size, keep_first=keep_first)

-    first_event = create_test_event('Event 0')
+    events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

-    mock_state = MagicMock()
-    mock_state.extra_data = {}
-    mock_state.history = [first_event]
+    # To ensure the most recent event is always recorded, track it in a non-local variable updated
+    # with a closure we'll pass to the view generator as a callback.
+    most_recent_event: Event | None = None

-    for i in range(max_size * 10):
-        event = create_test_event(f'Event {i+1}', datetime(2024, 1, 1, 10, i + 1))
-        mock_state.history.append(event)
-        results = condenser.condensed_history(mock_state)
+    def set_most_recent_event(history: list[Event]):
+        nonlocal most_recent_event
+        most_recent_event = history[-1]

-        # The last event is always the event we just added.
-        assert results[-1] == event
+    harness = RollingCondenserTestHarness(condenser)
+    harness.add_callback(set_most_recent_event)

-        # The first event is always the first event.
-        assert results[0] == first_event
+    for i, view in enumerate(harness.views(events)):
+        assert len(view) == harness.expected_size(i, max_size)

-        # The number of results should bounce back between 2, 3, 4, 2, 3, 4, ...
-        print(len(results))
-        assert len(results) == (i % 3) + 2
+        # The last event should always be the most-recently added.
+        assert view[-1] == most_recent_event
+
+        # The prefix should always match the list of events, up to the keep_first limit.
+ assert view[:keep_first] == events[: min(keep_first, i + 1)] def test_llm_attention_condenser_from_config(): @@ -547,128 +505,46 @@ def test_llm_attention_condenser_invalid_config(): pytest.raises(ValueError, LLMAttentionCondenser.from_config, config) -def test_rolling_condenser_handles_truncation(mock_state: State): - """Test that RollingCondenser correctly handles history truncation.""" - - class TestRollingCondenser(RollingCondenser): - """Test implementation of RollingCondenser that just returns all events.""" - - def condense(self, events: list[Event]) -> list[Event]: - return events - - condenser = TestRollingCondenser() - - # Initial history with 3 events - events = [ - create_test_event('Event 1', id=1), - create_test_event('Event 2', id=2), - create_test_event('Event 3', id=3), - ] - mock_state.history = events - - # First condensation - should return all events - results = condenser.condensed_history(mock_state) - assert len(results) == 3 - assert [e._id for e in results] == [1, 2, 3] - - # Simulate truncation - history is now shorter, and the condensation should - # just include the truncated history - mock_state.history = mock_state.history[-1:] - - results = condenser.condensed_history(mock_state) - assert len(results) == 1 - assert results[0]._id == 3 - - # Adding more events and condensing should "rebase" us from the truncated history - mock_state.history += [ - create_test_event('Event 4', id=4), - create_test_event('Event 5', id=5), - ] - - results = condenser.condensed_history(mock_state) - assert len(results) == 3 - assert [e._id for e in results] == [3, 4, 5] - - -def test_llm_attention_condenser_keeps_first_events(mock_llm, mock_state): - """Test that the LLMAttentionCondenser keeps the right number of initial events when forgetting.""" - max_size = 4 - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=1, llm=mock_llm) - - first_event = create_test_event('Event 0', id=0) - mock_state.history.append(first_event) - - for i in range(max_size * 10): - event = create_test_event(f'Event {i+1}', id=i + 1) - mock_state.history.append(event) - - mock_llm.set_mock_response_content( - ImportantEventSelection( - ids=[event.id for event in mock_state.history] - ).model_dump_json() - ) - results = condenser.condensed_history(mock_state) - - # The first event is always the first event. 
- assert results[0] == first_event - - -def test_llm_attention_condenser_grows_to_max_size(mock_llm, mock_state): - """Test that LLMAttentionCondenser correctly maintains an event context up to max size.""" - max_size = 15 - condenser = LLMAttentionCondenser(max_size=max_size, llm=mock_llm) - - for i in range(max_size): - event = create_test_event(f'Event {i}') - mock_state.history.append(event) - mock_llm.set_mock_response_content( - ImportantEventSelection(ids=[event.id for event in mock_state.history]) - ) - results = condenser.condensed_history(mock_state) - assert len(results) == i + 1 - - -def test_llm_attention_condenser_forgets_when_larger_than_max_size( - mock_llm, mock_state -): - """Test that the LLMAttentionCondenser forgets events when the context grows too large.""" - max_size = 2 +def test_llm_attention_condenser_gives_expected_view_size(mock_llm): + """Test that the LLMAttentionCondenser gives views of the expected size.""" + max_size = 10 condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) - for i in range(max_size * 10): - event = create_test_event(f'Event {i}', id=i) - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + def set_response_content(history: list[Event]): mock_llm.set_mock_response_content( ImportantEventSelection( - ids=[event.id for event in mock_state.history] + ids=[event.id for event in history] ).model_dump_json() ) - results = condenser.condensed_history(mock_state) + harness = RollingCondenserTestHarness(condenser) + harness.add_callback(set_response_content) - # The number of results should bounce back and forth between 1, 2, 1, 2, ... - assert len(results) == (i % 2) + 1 + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_events_outside_history(mock_llm, mock_state): +def test_llm_attention_condenser_handles_events_outside_history(mock_llm): """Test that the LLMAttentionCondenser handles event IDs that aren't from the event history.""" max_size = 2 condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) - for i in range(max_size * 10): - event = create_test_event(f'Event {i}', id=i) - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + def set_response_content(history: list[Event]): mock_llm.set_mock_response_content( ImportantEventSelection( - ids=[event.id for event in mock_state.history] + [-1, -2, -3, -4] + ids=[event.id for event in history] + [-1, -2, -3, -4] ).model_dump_json() ) - results = condenser.condensed_history(mock_state) - # The number of results should bounce back and forth between 1, 2, 1, 2, ... 
- assert len(results) == (i % 2) + 1 + harness = RollingCondenserTestHarness(condenser) + harness.add_callback(set_response_content) + + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state): @@ -676,67 +552,61 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state): max_size = 2 condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) - for i in range(max_size * 10): - event = create_test_event(f'Event {i}', id=i) - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + + def set_response_content(history: list[Event]): mock_llm.set_mock_response_content( ImportantEventSelection( - ids=[event.id for event in mock_state.history] - + [event.id for event in mock_state.history] + ids=[event.id for event in history] + [event.id for event in history] ).model_dump_json() ) - results = condenser.condensed_history(mock_state) - # The number of results should bounce back and forth between 1, 2, 1, 2, ... - assert len(results) == (i % 2) + 1 + harness = RollingCondenserTestHarness(condenser) + harness.add_callback(set_response_content) + + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_state): +def test_llm_attention_condenser_handles_too_few_events(mock_llm): """Test that the LLMAttentionCondenser handles when the response contains too few event IDs.""" max_size = 2 # Developer note: We must specify keep_first=0 because # keep_first (1) >= max_size//2 (1) is invalid. condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) - for i in range(max_size * 10): - event = create_test_event(f'Event {i}', id=i) - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + def set_response_content(history: list[Event]): mock_llm.set_mock_response_content( ImportantEventSelection(ids=[]).model_dump_json() ) - results = condenser.condensed_history(mock_state) + harness = RollingCondenserTestHarness(condenser) + harness.add_callback(set_response_content) - # The number of results should bounce back and forth between 1, 2, 1, 2, ... - assert len(results) == (i % 2) + 1 - - # Add a new test verifying that keep_first=1 works with max_size > 2 + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) -def test_llm_attention_condenser_handles_keep_first_for_larger_max_size( - mock_llm, mock_state -): +def test_llm_attention_condenser_handles_keep_first_events(mock_llm): """Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size).""" - max_size = 4 # so keep_first=1 < (max_size // 2) = 2 - condenser = LLMAttentionCondenser(max_size=max_size, keep_first=1, llm=mock_llm) + max_size = 12 + keep_first = 4 + condenser = LLMAttentionCondenser( + max_size=max_size, keep_first=keep_first, llm=mock_llm + ) - for i in range(max_size * 2): - # We append new events, then ensure some are pruned. 
- event = create_test_event(f'Event {i}', id=i) - mock_state.history.append(event) + events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] + def set_response_content(history: list[Event]): mock_llm.set_mock_response_content( ImportantEventSelection(ids=[]).model_dump_json() ) - results = condenser.condensed_history(mock_state) + harness = RollingCondenserTestHarness(condenser) + harness.add_callback(set_response_content) - # We expect that the first event is always kept, and the tail grows until max_size - if len(mock_state.history) <= max_size: - # No condensation needed yet - assert len(results) == len(mock_state.history) - else: - # The first event is kept, plus some from the tail - assert results[0].id == 0 - assert len(results) <= max_size + for i, view in enumerate(harness.views(events)): + assert len(view) == harness.expected_size(i, max_size) + assert view[:keep_first] == events[: min(keep_first, i + 1)] diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 193363fd4b..fdb9f1f2fb 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -1,5 +1,3 @@ -from unittest.mock import Mock - import pytest from litellm import ModelResponse @@ -74,9 +72,7 @@ def test_get_messages(codeact_agent: CodeActAgent): history.append(message_action_5) codeact_agent.reset() - messages = codeact_agent._get_messages( - Mock(history=history, max_iterations=5, iteration=0, extra_data={}) - ) + messages = codeact_agent._get_messages(history) assert ( len(messages) == 6 @@ -110,9 +106,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): history.append(message_action_agent) codeact_agent.reset() - messages = codeact_agent._get_messages( - Mock(history=history, max_iterations=10, iteration=5, extra_data={}) - ) + messages = codeact_agent._get_messages(history) # Check that only the last two user messages have cache_prompt=True cached_user_messages = [