feat: Structured summary generation for history condensation (#7696)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Parent: a4ebb5bf85
Commit: f74243542d
@@ -126,6 +126,33 @@ class LLMAttentionCondenserConfig(BaseModel):
    model_config = {'extra': 'forbid'}


class StructuredSummaryCondenserConfig(BaseModel):
    """Configuration for StructuredSummaryCondenser instances."""

    type: Literal['structured'] = Field('structured')
    llm_config: LLMConfig = Field(
        ..., description='Configuration for the LLM to use for condensing.'
    )

    # at least one event by default, because the best guess is that it's the user task
    keep_first: int = Field(
        default=1,
        description='Number of initial events to always keep in history.',
        ge=0,
    )
    max_size: int = Field(
        default=100,
        description='Maximum size of the condensed history before triggering forgetting.',
        ge=2,
    )
    max_event_length: int = Field(
        default=10_000,
        description='Maximum length of the event representations to be passed to the LLM.',
    )

    model_config = {'extra': 'forbid'}


# Type alias for convenience
CondenserConfig = (
    NoOpCondenserConfig

@@ -135,6 +162,7 @@ CondenserConfig = (
    | LLMSummarizingCondenserConfig
    | AmortizedForgettingCondenserConfig
    | LLMAttentionCondenserConfig
    | StructuredSummaryCondenserConfig
)


@@ -237,6 +265,7 @@ def create_condenser_config(condenser_type: str, data: dict) -> CondenserConfig:
        'llm': LLMSummarizingCondenserConfig,
        'amortized': AmortizedForgettingCondenserConfig,
        'llm_attention': LLMAttentionCondenserConfig,
        'structured': StructuredSummaryCondenserConfig,
    }

    if condenser_type not in condenser_classes:
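For orientation, a brief sketch (not part of the diff) of constructing the new config variant directly; the model name below is a placeholder.

from openhands.core.config.condenser_config import StructuredSummaryCondenserConfig
from openhands.core.config.llm_config import LLMConfig

# Defaults from the class above: keep_first=1, max_size=100, max_event_length=10_000.
config = StructuredSummaryCondenserConfig(
    llm_config=LLMConfig(model='gpt-4o'),  # placeholder model
)

# The 'structured' key registered above maps to this same class, so
# create_condenser_config('structured', ...) resolves to a
# StructuredSummaryCondenserConfig as well.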
@@ -18,6 +18,9 @@ from openhands.memory.condenser.impl.observation_masking_condenser import (
from openhands.memory.condenser.impl.recent_events_condenser import (
    RecentEventsCondenser,
)
from openhands.memory.condenser.impl.structured_summary_condenser import (
    StructuredSummaryCondenser,
)

__all__ = [
    'AmortizedForgettingCondenser',

@@ -28,4 +31,5 @@ __all__ = [
    'ObservationMaskingCondenser',
    'BrowserOutputCondenser',
    'RecentEventsCondenser',
    'StructuredSummaryCondenser',
]
openhands/memory/condenser/impl/structured_summary_condenser.py (new file, 322 lines)

@@ -0,0 +1,322 @@
from __future__ import annotations

import json
from typing import Any

from pydantic import BaseModel, Field

from openhands.core.config.condenser_config import (
    StructuredSummaryCondenserConfig,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
from openhands.llm import LLM
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
    View,
)


class StateSummary(BaseModel):
    """A structured representation summarizing the state of the agent and the task."""

    # Required core fields
    user_context: str = Field(
        default='',
        description='Essential user requirements, goals, and clarifications in concise form.',
    )
    completed_tasks: str = Field(
        default='', description='List of tasks completed so far with brief results.'
    )
    pending_tasks: str = Field(
        default='', description='List of tasks that still need to be done.'
    )
    current_state: str = Field(
        default='',
        description='Current variables, data structures, or other relevant state information.',
    )

    # Code state fields
    files_modified: str = Field(
        default='', description='List of files that have been created or modified.'
    )
    function_changes: str = Field(
        default='', description='List of functions that have been created or modified.'
    )
    data_structures: str = Field(
        default='', description='List of key data structures in use or modified.'
    )

    # Test status fields
    tests_written: str = Field(
        default='',
        description='Whether tests have been written for the changes. True, false, or unknown.',
    )
    tests_passing: str = Field(
        default='',
        description='Whether all tests are currently passing. True, false, or unknown.',
    )
    failing_tests: str = Field(
        default='', description='List of names or descriptions of any failing tests.'
    )
    error_messages: str = Field(
        default='', description='List of key error messages encountered.'
    )

    # Version control fields
    branch_created: str = Field(
        default='',
        description='Whether a branch has been created for this work. True, false, or unknown.',
    )
    branch_name: str = Field(
        default='', description='Name of the current working branch if known.'
    )
    commits_made: str = Field(
        default='',
        description='Whether any commits have been made. True, false, or unknown.',
    )
    pr_created: str = Field(
        default='',
        description='Whether a pull request has been created. True, false, or unknown.',
    )
    pr_status: str = Field(
        default='',
        description="Status of any pull request: 'draft', 'open', 'merged', 'closed', or 'unknown'.",
    )

    # Other fields
    dependencies: str = Field(
        default='',
        description='List of dependencies or imports that have been added or modified.',
    )
    other_relevant_context: str = Field(
        default='',
        description="Any other important information that doesn't fit into the categories above.",
    )

    @classmethod
    def tool_description(cls) -> dict[str, Any]:
        """Description of a tool whose arguments are the fields of this class.

        Can be given to an LLM to force structured generation.
        """
        properties = {}

        # Build properties dictionary from field information
        for field_name, field in cls.model_fields.items():
            description = field.description or ''

            properties[field_name] = {'type': 'string', 'description': description}

        return {
            'type': 'function',
            'function': {
                'name': 'create_state_summary',
                'description': 'Creates a comprehensive summary of the current state of the interaction to preserve context when history grows too large. You must include non-empty values for user_context, completed_tasks, and pending_tasks.',
                'parameters': {
                    'type': 'object',
                    'properties': properties,
                    'required': ['user_context', 'completed_tasks', 'pending_tasks'],
                },
            },
        }

    def __str__(self) -> str:
        """Format the state summary in a clear way for Claude 3.7 Sonnet."""
        sections = [
            '# State Summary',
            '## Core Information',
            f'**User Context**: {self.user_context}',
            f'**Completed Tasks**: {self.completed_tasks}',
            f'**Pending Tasks**: {self.pending_tasks}',
            f'**Current State**: {self.current_state}',
            '## Code Changes',
            f'**Files Modified**: {self.files_modified}',
            f'**Function Changes**: {self.function_changes}',
            f'**Data Structures**: {self.data_structures}',
            f'**Dependencies**: {self.dependencies}',
            '## Testing Status',
            f'**Tests Written**: {self.tests_written}',
            f'**Tests Passing**: {self.tests_passing}',
            f'**Failing Tests**: {self.failing_tests}',
            f'**Error Messages**: {self.error_messages}',
            '## Version Control',
            f'**Branch Created**: {self.branch_created}',
            f'**Branch Name**: {self.branch_name}',
            f'**Commits Made**: {self.commits_made}',
            f'**PR Created**: {self.pr_created}',
            f'**PR Status**: {self.pr_status}',
            '## Additional Context',
            f'**Other Relevant Context**: {self.other_relevant_context}',
        ]

        # Join all sections with double newlines
        return '\n\n'.join(sections)


class StructuredSummaryCondenser(RollingCondenser):
    """A condenser that summarizes forgotten events.

    Maintains a condensed history and forgets old events when it grows too large. Uses structured generation via function-calling to produce summaries that replace forgotten events.
    """

    def __init__(
        self,
        llm: LLM,
        max_size: int = 100,
        keep_first: int = 1,
        max_event_length: int = 10_000,
    ):
        if keep_first >= max_size // 2:
            raise ValueError(
                f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
            )
        if keep_first < 0:
            raise ValueError(f'keep_first ({keep_first}) cannot be negative')
        if max_size < 1:
            raise ValueError(f'max_size ({max_size}) cannot be non-positive')

        if not llm.is_function_calling_active():
            raise ValueError(
                'LLM must support function calling to use StructuredSummaryCondenser'
            )

        self.max_size = max_size
        self.keep_first = keep_first
        self.max_event_length = max_event_length
        self.llm = llm

        super().__init__()

    def _truncate(self, content: str) -> str:
        """Truncate the content to fit within the specified maximum event length."""
        return truncate_content(content, max_chars=self.max_event_length)

    def get_condensation(self, view: View) -> Condensation:
        head = view[: self.keep_first]
        target_size = self.max_size // 2
        # Number of events to keep from the tail -- target size, minus however many
        # prefix events from the head, minus one for the summarization event
        events_from_tail = target_size - len(head) - 1
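        # Illustrative arithmetic (hypothetical values, not from this change):
        # with max_size=10 and keep_first=3, target_size = 10 // 2 = 5 and
        # events_from_tail = 5 - 3 - 1 = 1, so a condensation keeps the 3 head
        # events, 1 summary event, and 1 tail event.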

        summary_event = (
            view[self.keep_first]
            if isinstance(view[self.keep_first], AgentCondensationObservation)
            else AgentCondensationObservation('No events summarized')
        )

        # Identify events to be forgotten (those not in head or tail)
        forgotten_events = []
        for event in view[self.keep_first : -events_from_tail]:
            if not isinstance(event, AgentCondensationObservation):
                forgotten_events.append(event)

        # Construct prompt for summarization
        prompt = """You are maintaining a context-aware state summary for an interactive software agent. This summary is critical because it:
1. Preserves essential context when conversation history grows too large
2. Prevents lost work when the session length exceeds token limits
3. Helps maintain continuity across multiple interactions

You will be given:
- A list of events (actions taken by the agent)
- The most recent previous summary (if one exists)

Capture all relevant information, especially:
- User requirements that were explicitly stated
- Work that has been completed
- Tasks that remain pending
- Current state of code, variables, and data structures
- The status of any version control operations"""

        prompt += '\n\n'

        # Add the previous summary if it exists. We'll always have a summary
        # event, but the types aren't precise enough to guarantee that it has a
        # message attribute.
        summary_event_content = self._truncate(
            summary_event.message if summary_event.message else ''
        )
        prompt += f'<PREVIOUS SUMMARY>\n{summary_event_content}\n</PREVIOUS SUMMARY>\n'

        prompt += '\n\n'

        # Add all events that are being forgotten. We use the string
        # representation defined by the event, and truncate it if necessary.
        for forgotten_event in forgotten_events:
            event_content = self._truncate(str(forgotten_event))
            prompt += f'<EVENT id={forgotten_event.id}>\n{event_content}\n</EVENT>\n'

        messages = [Message(role='user', content=[TextContent(text=prompt)])]

        response = self.llm.completion(
            messages=self.llm.format_messages_for_llm(messages),
            tools=[StateSummary.tool_description()],
            tool_choice={
                'type': 'function',
                'function': {'name': 'create_state_summary'},
            },
        )

        try:
            # Extract the message containing tool calls
            message = response.choices[0].message

            # Check if there are tool calls
            if not hasattr(message, 'tool_calls') or not message.tool_calls:
                raise ValueError('No tool calls found in response')

            # Find the create_state_summary tool call
            summary_tool_call = None
            for tool_call in message.tool_calls:
                if tool_call.function.name == 'create_state_summary':
                    summary_tool_call = tool_call
                    break

            if not summary_tool_call:
                raise ValueError('create_state_summary tool call not found')

            # Parse the arguments
            args_json = summary_tool_call.function.arguments
            args_dict = json.loads(args_json)

            # Create a StateSummary object
            summary = StateSummary.model_validate(args_dict)

        except (ValueError, AttributeError, KeyError, json.JSONDecodeError) as e:
            logger.warning(
                f'Failed to parse summary tool call: {e}. Using empty summary.'
            )
            summary = StateSummary()

        self.add_metadata('response', response.model_dump())
        self.add_metadata('metrics', self.llm.metrics.get())

        return Condensation(
            action=CondensationAction(
                forgotten_events_start_id=min(event.id for event in forgotten_events),
                forgotten_events_end_id=max(event.id for event in forgotten_events),
                summary=str(summary),
                summary_offset=self.keep_first,
            )
        )

    def should_condense(self, view: View) -> bool:
        return len(view) > self.max_size

    @classmethod
    def from_config(
        cls, config: StructuredSummaryCondenserConfig
    ) -> StructuredSummaryCondenser:
        return StructuredSummaryCondenser(
            llm=LLM(config=config.llm_config),
            max_size=config.max_size,
            keep_first=config.keep_first,
            max_event_length=config.max_event_length,
        )


StructuredSummaryCondenser.register_config(StructuredSummaryCondenserConfig)
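As a usage sketch of the new module (not part of the diff): this assumes a View over the agent's event history is obtained elsewhere, and the model name and sizes below are placeholders.

from openhands.core.config.condenser_config import StructuredSummaryCondenserConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.memory.condenser.condenser import Condensation, View
from openhands.memory.condenser.impl import StructuredSummaryCondenser

# Build the condenser from its config (placeholder model and sizes).
condenser = StructuredSummaryCondenser.from_config(
    StructuredSummaryCondenserConfig(
        llm_config=LLMConfig(model='gpt-4o'),
        keep_first=3,
        max_size=80,
    )
)


def maybe_condense(view: View) -> Condensation | None:
    """Run one condensation step if the history has grown past max_size."""
    if condenser.should_condense(view):
        # The returned Condensation wraps a CondensationAction whose summary
        # field holds the rendered StateSummary produced via function calling.
        return condenser.get_condensation(view)
    return None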
@@ -9,6 +9,7 @@ from openhands.controller.agent import Agent
from openhands.core.config import AppConfig
from openhands.core.config.condenser_config import (
    LLMSummarizingCondenserConfig,
    StructuredSummaryCondenserConfig,
)
from openhands.core.logger import OpenHandsLoggerAdapter
from openhands.core.schema import AgentState

@@ -19,7 +20,6 @@ from openhands.events.observation import (
    CmdOutputObservation,
    NullObservation,
)
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber

@@ -128,9 +128,21 @@ class Session:
        agent_config = self.config.get_agent_config(agent_cls)

        if settings.enable_default_condenser:
            default_condenser_config = LLMSummarizingCondenserConfig(
                llm_config=llm.config, keep_first=3, max_size=40
            )
            # If function-calling is active we can use the structured summary
            # condenser for more reliable summaries.
            if llm.is_function_calling_active():
                default_condenser_config = StructuredSummaryCondenserConfig(
                    llm_config=llm.config, keep_first=3, max_size=80
                )

            # Otherwise, we'll fall back to the unstructured summary condenser.
            # This is a good default but struggles more than the structured
            # summary condenser with long messages.
            else:
                default_condenser_config = LLMSummarizingCondenserConfig(
                    llm_config=llm.config, keep_first=3, max_size=80
                )

            self.logger.info(f'Enabling default condenser: {default_condenser_config}')
            agent_config.condenser = default_condenser_config
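Read as a decision rule, the default-condenser selection above amounts to the following sketch; the helper name is hypothetical and only mirrors the logic in this hunk.

from openhands.core.config.condenser_config import (
    CondenserConfig,
    LLMSummarizingCondenserConfig,
    StructuredSummaryCondenserConfig,
)
from openhands.llm import LLM


def pick_default_condenser_config(llm: LLM) -> CondenserConfig:
    """Hypothetical helper mirroring Session's default-condenser choice."""
    if llm.is_function_calling_active():
        # Function calling available: use the structured summary condenser.
        return StructuredSummaryCondenserConfig(
            llm_config=llm.config, keep_first=3, max_size=80
        )
    # Otherwise fall back to the unstructured LLM summarizing condenser.
    return LLMSummarizingCondenserConfig(
        llm_config=llm.config, keep_first=3, max_size=80
    )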
@@ -200,7 +212,7 @@ class Session:
            await self.send(event_to_dict(event))
        # NOTE: ipython observations are not sent here currently
        elif event.source == EventSource.ENVIRONMENT and isinstance(
            event, (CmdOutputObservation, AgentStateChangedObservation, RecallObservation)
            event, (CmdOutputObservation, AgentStateChangedObservation)
        ):
            # feedback from the environment to agent actions is understood as agent events by the UI
            event_dict = event_to_dict(event)

@@ -13,6 +13,7 @@ from openhands.core.config.condenser_config import (
    NoOpCondenserConfig,
    ObservationMaskingCondenserConfig,
    RecentEventsCondenserConfig,
    StructuredSummaryCondenserConfig,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.message import Message, TextContent

@@ -32,6 +33,7 @@ from openhands.memory.condenser.impl import (
    NoOpCondenser,
    ObservationMaskingCondenser,
    RecentEventsCondenser,
    StructuredSummaryCondenser,
)


@@ -85,6 +87,8 @@ def mock_llm() -> LLM:
        Message(role='user', content=[TextContent(text=str(event))]) for event in events
    ]

    mock_llm.is_function_calling_active.return_value = True

    return mock_llm
@@ -600,3 +604,93 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
    for i, view in enumerate(harness.views(events)):
        assert len(view) == harness.expected_size(i, max_size)
        assert view[:keep_first] == events[: min(keep_first, i + 1)]


def test_structured_summary_condenser_from_config():
    """Test that StructuredSummaryCondenser objects can be made from config."""
    config = StructuredSummaryCondenserConfig(
        max_size=50,
        keep_first=10,
        llm_config=LLMConfig(
            model='gpt-4o',
            api_key='test_key',
        ),
    )
    condenser = Condenser.from_config(config)

    assert isinstance(condenser, StructuredSummaryCondenser)
    assert condenser.llm.config.model == 'gpt-4o'
    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
    assert condenser.max_size == 50
    assert condenser.keep_first == 10


def test_structured_summary_condenser_invalid_config():
    """Test that StructuredSummaryCondenser raises an error when given an invalid configuration."""
    # Since the condenser only works when function calling is on, we need to
    # mock up the check for that.
    llm = MagicMock()
    llm.is_function_calling_active.return_value = True

    pytest.raises(
        ValueError,
        StructuredSummaryCondenser,
        llm=llm,
        max_size=4,
        keep_first=2,
    )

    pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0)
    pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1)

    # If all other parameters are good but there's no function calling, the
    # condenser still counts as improperly configured.
    llm.is_function_calling_active.return_value = False
    pytest.raises(
        ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2
    )


def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
    """Test that StructuredSummaryCondenser maintains the correct view size."""
    max_size = 10
    condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)

    events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

    # Set up mock LLM response
    mock_llm.set_mock_response_content('Summary of forgotten events')

    harness = RollingCondenserTestHarness(condenser)

    for i, view in enumerate(harness.views(events)):
        assert len(view) == harness.expected_size(i, max_size)


def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
    """Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
    max_size = 10
    keep_first = 3
    condenser = StructuredSummaryCondenser(
        max_size=max_size, keep_first=keep_first, llm=mock_llm
    )

    mock_llm.set_mock_response_content('Summary of forgotten events')

    events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
    harness = RollingCondenserTestHarness(condenser)

    for i, view in enumerate(harness.views(events)):
        assert len(view) == harness.expected_size(i, max_size)

        # Ensure that we've called the summarizing LLM once per condensation
        assert mock_llm.completion.call_count == harness.expected_condensations(
            i, max_size
        )

        # Ensure that the prefix is appropriately maintained
        assert view[:keep_first] == events[: min(keep_first, i + 1)]

        # If we've condensed, ensure that the summary event is present
        if i > max_size:
            assert isinstance(view[keep_first], AgentCondensationObservation)