Fix non-function-calling messages (#5026)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
Engel Nyst 2024-11-21 19:18:49 +01:00 committed by GitHub
parent 68e52a9c62
commit d08886f30e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 32 deletions

View File

@ -56,6 +56,7 @@ class Message(BaseModel):
cache_enabled: bool = False
vision_enabled: bool = False
# function calling
function_calling_enabled: bool = False
# - tool calls (from LLM)
tool_calls: list[ChatCompletionMessageToolCall] | None = None
# - tool execution result (to LLM)
@ -72,22 +73,22 @@ class Message(BaseModel):
# - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
# - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
# NOTE: remove this when litellm or providers support the new API
if (
self.cache_enabled
or self.vision_enabled
or self.tool_call_id is not None
or self.tool_calls is not None
):
if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
return self._list_serializer()
# some providers, like HF and Groq/llama, don't support a list here, but a single string
return self._string_serializer()
def _string_serializer(self) -> dict:
    """Serialize this message's content into a single string payload.

    Used for providers that don't support a list of content items
    (e.g. HF and Groq/llama), which expect `content` to be one string.

    Returns:
        A dict with 'role' and joined 'content', plus tool-call keys
        when a tool call or tool response is present.
    """
    # join all text items into one string; non-text content (e.g. images)
    # is dropped in string mode
    content = '\n'.join(
        item.text for item in self.content if isinstance(item, TextContent)
    )
    message_dict: dict = {'content': content, 'role': self.role}

    # add tool call keys if we have a tool call or response
    return self._add_tool_call_keys(message_dict)
def _list_serializer(self) -> dict:
content: list[dict] = []
role_tool_with_prompt_caching = False
for item in self.content:
@ -102,24 +103,37 @@ class Message(BaseModel):
elif isinstance(item, ImageContent) and self.vision_enabled:
content.extend(d)
ret: dict = {'content': content, 'role': self.role}
message_dict: dict = {'content': content, 'role': self.role}
# pop content if it's empty
if not content or (
len(content) == 1
and content[0]['type'] == 'text'
and content[0]['text'] == ''
):
ret.pop('content')
message_dict.pop('content')
if role_tool_with_prompt_caching:
ret['cache_control'] = {'type': 'ephemeral'}
message_dict['cache_control'] = {'type': 'ephemeral'}
# add tool call keys if we have a tool call or response
return self._add_tool_call_keys(message_dict)
def _add_tool_call_keys(self, message_dict: dict) -> dict:
"""Add tool call keys if we have a tool call or response.
NOTE: this is necessary for both native and non-native tool calling"""
# an assistant message calling a tool
if self.tool_calls is not None:
message_dict['tool_calls'] = self.tool_calls
# an observation message with tool response
if self.tool_call_id is not None:
assert (
self.name is not None
), 'name is required when tool_call_id is not None'
ret['tool_call_id'] = self.tool_call_id
ret['name'] = self.name
if self.tool_calls:
ret['tool_calls'] = self.tool_calls
return ret
message_dict['tool_call_id'] = self.tool_call_id
message_dict['name'] = self.name
return message_dict

View File

@ -320,9 +320,8 @@ def convert_fncall_messages_to_non_fncall_messages(
converted_messages = []
first_user_message_encountered = False
for message in messages:
role, content = message['role'], message['content']
if content is None:
content = ''
role = message['role']
content = message.get('content', '')
# 1. SYSTEM MESSAGES
# append system prompt suffix to content
@ -339,6 +338,7 @@ def convert_fncall_messages_to_non_fncall_messages(
f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
)
converted_messages.append({'role': 'system', 'content': content})
# 2. USER MESSAGES (no change)
elif role == 'user':
# Add in-context learning example for the first user message
@ -447,10 +447,12 @@ def convert_fncall_messages_to_non_fncall_messages(
f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
)
converted_messages.append({'role': 'assistant', 'content': content})
# 4. TOOL MESSAGES (tool outputs)
elif role == 'tool':
# Convert tool result as assistant message
prefix = f'EXECUTION RESULT of [{message["name"]}]:\n'
# Convert tool result as user message
tool_name = message.get('name', 'function')
prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
# and omit "tool_call_id" AND "name"
if isinstance(content, str):
content = prefix + content

View File

@ -122,6 +122,9 @@ class LLM(RetryMixin, DebugMixin):
drop_params=self.config.drop_params,
)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
self.init_model_info()
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
@ -143,16 +146,6 @@ class LLM(RetryMixin, DebugMixin):
drop_params=self.config.drop_params,
)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
self.init_model_info()
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
self._completion_unwrapped = self._completion
@self.retry_decorator(
@ -342,6 +335,13 @@ class LLM(RetryMixin, DebugMixin):
pass
logger.debug(f'Model info: {self.model_info}')
if self.config.model.startswith('huggingface'):
# HF doesn't support the OpenAI default value for top_p (1)
logger.debug(
f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
)
self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
@ -566,6 +566,7 @@ class LLM(RetryMixin, DebugMixin):
for message in messages:
message.cache_enabled = self.is_caching_prompt_active()
message.vision_enabled = self.vision_is_active()
message.function_calling_enabled = self.is_function_calling_active()
# let pydantic handle the serialization
return [message.model_dump() for message in messages]