Fix non-function calls messages (#5026)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
commit d08886f30e
parent 68e52a9c62
@@ -56,6 +56,7 @@ class Message(BaseModel):
     cache_enabled: bool = False
     vision_enabled: bool = False
     # function calling
+    function_calling_enabled: bool = False
     # - tool calls (from LLM)
     tool_calls: list[ChatCompletionMessageToolCall] | None = None
     # - tool execution result (to LLM)
@@ -72,22 +73,22 @@ class Message(BaseModel):
         # - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
         # - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
         # NOTE: remove this when litellm or providers support the new API
-        if (
-            self.cache_enabled
-            or self.vision_enabled
-            or self.tool_call_id is not None
-            or self.tool_calls is not None
-        ):
+        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
             return self._list_serializer()
         # some providers, like HF and Groq/llama, don't support a list here, but a single string
         return self._string_serializer()

-    def _string_serializer(self):
+    def _string_serializer(self) -> dict:
         # convert content to a single string
         content = '\n'.join(
             item.text for item in self.content if isinstance(item, TextContent)
         )
-        return {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}
+
+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)

-    def _list_serializer(self):
+    def _list_serializer(self) -> dict:
         content: list[dict] = []
         role_tool_with_prompt_caching = False
         for item in self.content:
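The dispatch above now routes on the model's declared capabilities instead of sniffing tool-call fields on the message itself. A minimal standalone sketch of the pattern, assuming hypothetical MessageSketch/TextContentSketch stand-ins for the repo's Message/TextContent models:

from pydantic import BaseModel, model_serializer

# Standalone illustration only: MessageSketch/TextContentSketch are
# hypothetical stand-ins; the real models live in the OpenHands repo.
class TextContentSketch(BaseModel):
    type: str = 'text'
    text: str

class MessageSketch(BaseModel):
    role: str
    content: list[TextContentSketch] = []
    cache_enabled: bool = False
    vision_enabled: bool = False
    function_calling_enabled: bool = False

    @model_serializer
    def serialize_model(self) -> dict:
        # same dispatch as the diff: capability flags pick the wire format
        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
            return {
                'role': self.role,
                'content': [{'type': 'text', 'text': item.text} for item in self.content],
            }
        # single-string form for providers that reject content lists
        return {'role': self.role, 'content': '\n'.join(i.text for i in self.content)}

msg = MessageSketch(role='user', content=[TextContentSketch(text='hi')])
print(msg.model_dump())    # {'role': 'user', 'content': 'hi'}
msg.function_calling_enabled = True
print(msg.model_dump())    # {'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}

Keying off the flag matters because a provider without native function calling gets its tool calls converted to plain text, so its messages should still take the single-string path even when a tool_call_id was once present.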
@@ -102,24 +103,37 @@ class Message(BaseModel):
             elif isinstance(item, ImageContent) and self.vision_enabled:
                 content.extend(d)

-        ret: dict = {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}

         # pop content if it's empty
         if not content or (
             len(content) == 1
             and content[0]['type'] == 'text'
             and content[0]['text'] == ''
         ):
-            ret.pop('content')
+            message_dict.pop('content')

         if role_tool_with_prompt_caching:
-            ret['cache_control'] = {'type': 'ephemeral'}
+            message_dict['cache_control'] = {'type': 'ephemeral'}

+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)
+
+    def _add_tool_call_keys(self, message_dict: dict) -> dict:
+        """Add tool call keys if we have a tool call or response.
+
+        NOTE: this is necessary for both native and non-native tool calling"""
+        # an assistant message calling a tool
+        if self.tool_calls is not None:
+            message_dict['tool_calls'] = self.tool_calls
+
         # an observation message with tool response
         if self.tool_call_id is not None:
             assert (
                 self.name is not None
             ), 'name is required when tool_call_id is not None'
-            ret['tool_call_id'] = self.tool_call_id
-            ret['name'] = self.name
-        if self.tool_calls:
-            ret['tool_calls'] = self.tool_calls
-        return ret
+            message_dict['tool_call_id'] = self.tool_call_id
+            message_dict['name'] = self.name
+
+        return message_dict
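Factoring the key attachment into _add_tool_call_keys lets both serializers share it. A hedged sketch of that step as a free function over plain dicts (the tool_calls/tool_call_id/name fields mirror the OpenAI-style message shape the serializer emits):

# Illustration only; the real method reads these fields off self.
def add_tool_call_keys(message_dict: dict, *, tool_calls=None,
                       tool_call_id=None, name=None) -> dict:
    if tool_calls is not None:       # assistant message invoking tools
        message_dict['tool_calls'] = tool_calls
    if tool_call_id is not None:     # tool message answering an invocation
        assert name is not None, 'name is required when tool_call_id is not None'
        message_dict['tool_call_id'] = tool_call_id
        message_dict['name'] = name
    return message_dict

# assistant turn that calls a tool
print(add_tool_call_keys({'role': 'assistant'},
                         tool_calls=[{'id': 'call_1', 'type': 'function',
                                      'function': {'name': 'run', 'arguments': '{}'}}]))
# matching tool-result turn, keyed back to the call by tool_call_id
print(add_tool_call_keys({'role': 'tool', 'content': 'ok'},
                         tool_call_id='call_1', name='run'))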
@@ -320,9 +320,8 @@ def convert_fncall_messages_to_non_fncall_messages(
     converted_messages = []
     first_user_message_encountered = False
     for message in messages:
-        role, content = message['role'], message['content']
-        if content is None:
-            content = ''
+        role = message['role']
+        content = message.get('content', '')

         # 1. SYSTEM MESSAGES
         # append system prompt suffix to content
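This is the core of the fix named in the commit title: an assistant message that carries only tool_calls may omit the 'content' key entirely, and the old tuple unpacking crashed on it. A short illustration:

# Message with no 'content' key, as a tool-calling assistant turn can be:
message = {'role': 'assistant', 'tool_calls': [{'id': 'call_1'}]}

# old: role, content = message['role'], message['content']  -> KeyError
role = message['role']
content = message.get('content', '')   # new: absent key falls back to ''
print(role, repr(content))             # assistant ''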
@@ -339,6 +338,7 @@ def convert_fncall_messages_to_non_fncall_messages(
                     f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                 )
             converted_messages.append({'role': 'system', 'content': content})
+
         # 2. USER MESSAGES (no change)
         elif role == 'user':
             # Add in-context learning example for the first user message
@@ -447,10 +447,12 @@ def convert_fncall_messages_to_non_fncall_messages(
                     f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                 )
             converted_messages.append({'role': 'assistant', 'content': content})

         # 4. TOOL MESSAGES (tool outputs)
         elif role == 'tool':
-            # Convert tool result as assistant message
-            prefix = f'EXECUTION RESULT of [{message["name"]}]:\n'
+            # Convert tool result as user message
+            tool_name = message.get('name', 'function')
+            prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
             # and omit "tool_call_id" AND "name"
             if isinstance(content, str):
                 content = prefix + content
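For providers without native tool calling, a tool result has to travel as an ordinary chat turn; the branch above now emits it as a user message and tolerates a missing 'name'. A hedged sketch of just that branch, as a standalone function:

# Illustration of the tool branch; the real converter handles list
# content and the other roles as well.
def convert_tool_message(message: dict) -> dict:
    content = message.get('content', '')          # tolerate absent content
    tool_name = message.get('name', 'function')   # tolerate absent name
    prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
    if isinstance(content, str):
        content = prefix + content
    # tool_call_id and name are deliberately omitted from the output
    return {'role': 'user', 'content': content}

print(convert_tool_message({'role': 'tool', 'tool_call_id': 'call_1',
                            'name': 'run', 'content': 'exit code 0'}))
# {'role': 'user', 'content': 'EXECUTION RESULT of [run]:\nexit code 0'}
print(convert_tool_message({'role': 'tool', 'tool_call_id': 'call_2'}))
# {'role': 'user', 'content': 'EXECUTION RESULT of [function]:\n'}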
@@ -122,6 +122,9 @@ class LLM(RetryMixin, DebugMixin):
             drop_params=self.config.drop_params,
         )

+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            self.init_model_info()
         if self.vision_is_active():
             logger.debug('LLM: model has vision enabled')
         if self.is_caching_prompt_active():
@@ -143,16 +146,6 @@ class LLM(RetryMixin, DebugMixin):
             drop_params=self.config.drop_params,
         )

-        with warnings.catch_warnings():
-            warnings.simplefilter('ignore')
-            self.init_model_info()
-        if self.vision_is_active():
-            logger.debug('LLM: model has vision enabled')
-        if self.is_caching_prompt_active():
-            logger.debug('LLM: caching prompt enabled')
-        if self.is_function_calling_active():
-            logger.debug('LLM: model supports function calling')
-
         self._completion_unwrapped = self._completion

         @self.retry_decorator(
@@ -342,6 +335,13 @@ class LLM(RetryMixin, DebugMixin):
                 pass
         logger.debug(f'Model info: {self.model_info}')

+        if self.config.model.startswith('huggingface'):
+            # HF doesn't support the OpenAI default value for top_p (1)
+            logger.debug(
+                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
+            )
+            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
+
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
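The guard above only rewrites top_p when it still sits at the OpenAI default of 1; an explicitly configured value passes through. A minimal sketch over plain variables (the model id is a hypothetical example; the real code mutates self.config):

model = 'huggingface/meta-llama/Llama-3.1-8B-Instruct'  # hypothetical id
top_p = 1.0  # OpenAI-style default that HF endpoints reject

if model.startswith('huggingface'):
    # nudge only the untouched default just below 1
    top_p = 0.9 if top_p == 1 else top_p

print(top_p)  # 0.9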
@@ -566,6 +566,7 @@ class LLM(RetryMixin, DebugMixin):
         for message in messages:
             message.cache_enabled = self.is_caching_prompt_active()
             message.vision_enabled = self.vision_is_active()
+            message.function_calling_enabled = self.is_function_calling_active()

         # let pydantic handle the serialization
         return [message.model_dump() for message in messages]
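This is where the new flag gets populated: before serialization, the LLM stamps each message with its capabilities and defers the rest to pydantic. A hedged sketch of that flow, reusing the MessageSketch/TextContentSketch stand-ins from the dispatch sketch earlier (the keyword flags stand in for is_caching_prompt_active/vision_is_active/is_function_calling_active):

def format_messages_for_llm(messages: list, *, caching: bool,
                            vision: bool, fn_calling: bool) -> list[dict]:
    for message in messages:
        message.cache_enabled = caching
        message.vision_enabled = vision
        message.function_calling_enabled = fn_calling
    # let pydantic handle the serialization (model_serializer picks the shape)
    return [message.model_dump() for message in messages]

msgs = [MessageSketch(role='user', content=[TextContentSketch(text='hi')])]
print(format_messages_for_llm(msgs, caching=False, vision=False, fn_calling=True))
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}]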