fix: serialize tool calls (#5553)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Engel Nyst 2024-12-13 20:51:03 +01:00 committed by GitHub
parent d782bdf691
commit d733bc6bdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 156 additions and 2 deletions

View File

@ -166,6 +166,7 @@ class CodeActAgent(Agent):
# Add the LLM message (assistant) that initiated the tool calls
# (overwrites any previous message with the same response_id)
logger.debug(f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}')
pending_tool_call_action_messages[llm_response.id] = Message(
role=assistant_msg.role,
# tool call content SHOULD BE a string

View File

@ -114,11 +114,21 @@ class Message(BaseModel):
def _add_tool_call_keys(self, message_dict: dict) -> dict:
"""Add tool call keys if we have a tool call or response.
NOTE: this is necessary for both native and non-native tool calling"""
NOTE: this is necessary for both native and non-native tool calling."""
# an assistant message calling a tool
if self.tool_calls is not None:
message_dict['tool_calls'] = self.tool_calls
message_dict['tool_calls'] = [
{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
},
}
for tool_call in self.tool_calls
]
# an observation message with tool response
if self.tool_call_id is not None:

View File

@ -0,0 +1,89 @@
# OpenHands Message Format and litellm Integration
## Overview
OpenHands uses its own `Message` class (`openhands/core/message.py`) which provides rich content support while maintaining compatibility with litellm's message handling system.
## Class Structure
Our `Message` class (`openhands/core/message.py`):
```python
class Message(BaseModel):
role: Literal['user', 'system', 'assistant', 'tool']
content: list[TextContent | ImageContent] = Field(default_factory=list)
cache_enabled: bool = False
vision_enabled: bool = False
condensable: bool = True
function_calling_enabled: bool = False
tool_calls: list[ChatCompletionMessageToolCall] | None = None
tool_call_id: str | None = None
name: str | None = None
event_id: int = -1
```
litellm's `Message` class (`litellm/types/utils.py`):
```python
class Message(OpenAIObject):
content: Optional[str]
role: Literal["assistant", "user", "system", "tool", "function"]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
function_call: Optional[FunctionCall]
audio: Optional[ChatCompletionAudioResponse] = None
```
## How It Works
1. **Message Creation**: Our `Message` class is a Pydantic model that supports rich content (text and images) through its `content` field.
2. **Serialization**: The class uses Pydantic's `@model_serializer` to convert messages into dictionaries that litellm can understand. We have two serialization methods:
```python
def _string_serializer(self) -> dict:
# convert content to a single string
content = '\n'.join(item.text for item in self.content if isinstance(item, TextContent))
message_dict: dict = {'content': content, 'role': self.role}
return self._add_tool_call_keys(message_dict)
def _list_serializer(self) -> dict:
content: list[dict] = []
for item in self.content:
d = item.model_dump()
if isinstance(item, TextContent):
content.append(d)
elif isinstance(item, ImageContent) and self.vision_enabled:
content.extend(d)
return {'content': content, 'role': self.role}
```
The appropriate serializer is chosen based on the message's capabilities:
```python
@model_serializer
def serialize_model(self) -> dict:
if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
return self._list_serializer()
return self._string_serializer()
```
3. **Tool Call Handling**: Tool calls require special attention in serialization because:
- They need to work with litellm's API calls (which accept both dicts and objects)
- They need to be properly serialized for token counting
- They need to maintain compatibility with different LLM providers' formats
4. **litellm Integration**: When we pass our messages to `litellm.completion()`, litellm doesn't care about the message class type - it works with the dictionary representation. This works because:
- litellm's transformation code (e.g., `litellm/llms/anthropic/chat/transformation.py`) processes messages based on their structure, not their type
- our serialization produces dictionaries that match litellm's expected format
- litellm handles rich content by looking at the message structure, supporting both simple string content and lists of content items
5. **Provider-Specific Handling**: litellm then transforms these messages into provider-specific formats (e.g., Anthropic, OpenAI) through its transformation layers, which know how to handle both simple and rich content structures.
### Token Counting
To use litellm's token counter, we need to make sure that all message components (including tool calls) are properly serialized to dictionaries. This is because:
- litellm's token counter expects dictionary structures
- Tool calls need to be included in the token count
- Different providers may count tokens differently for structured content
## Note
- We don't need to inherit from litellm's `Message` class because litellm works with dictionary representations, not class types
- Our rich content model is more sophisticated than litellm's basic string content, but litellm handles it correctly through its transformation layers
- The compatibility is maintained through proper serialization rather than inheritance

View File

@ -1,3 +1,5 @@
from litellm import ChatCompletionMessageToolCall
from openhands.core.message import ImageContent, Message, TextContent
@ -114,3 +116,55 @@ def test_message_with_mixed_content_and_vision_disabled():
assert serialized_message == expected_serialized_message
# Assert that images exist in the original message
assert message.contains_image
def test_message_tool_call_serialization():
"""Test that tool calls are properly serialized into dicts for token counting."""
# Create a tool call
tool_call = ChatCompletionMessageToolCall(
id='call_123',
type='function',
function={'name': 'test_function', 'arguments': '{"arg1": "value1"}'},
)
# Create a message with the tool call
message = Message(
role='assistant',
content=[TextContent(text='Test message')],
tool_calls=[tool_call],
)
# Serialize the message
serialized = message.model_dump()
# Check that tool calls are properly serialized
assert 'tool_calls' in serialized
assert isinstance(serialized['tool_calls'], list)
assert len(serialized['tool_calls']) == 1
tool_call_dict = serialized['tool_calls'][0]
assert isinstance(tool_call_dict, dict)
assert tool_call_dict['id'] == 'call_123'
assert tool_call_dict['type'] == 'function'
assert tool_call_dict['function']['name'] == 'test_function'
assert tool_call_dict['function']['arguments'] == '{"arg1": "value1"}'
def test_message_tool_response_serialization():
"""Test that tool responses are properly serialized."""
# Create a message with tool response
message = Message(
role='tool',
content=[TextContent(text='Function result')],
tool_call_id='call_123',
name='test_function',
)
# Serialize the message
serialized = message.model_dump()
# Check that tool response fields are properly serialized
assert 'tool_call_id' in serialized
assert serialized['tool_call_id'] == 'call_123'
assert 'name' in serialized
assert serialized['name'] == 'test_function'