Fix issue #4604: '[Bug]: Disable function calling for DeepSeek'

This commit is contained in:
openhands
2024-10-29 19:46:52 +00:00
parent 981b05fc2b
commit b19ce48bb9
4 changed files with 64 additions and 6 deletions

View File

@@ -90,7 +90,11 @@ class LLMConfig:
"""
Post-initialization hook to assign OpenRouter-related variables to environment variables.
This ensures that these values are accessible to litellm at runtime.
Also sets model-specific capabilities.
"""
# Set function calling support for DeepSeek models
if 'deepseek' in self.model.lower():
self.supports_function_calling = True
# Assign OpenRouter-specific variables to environment variables
if self.openrouter_site_url:
@@ -136,3 +140,4 @@ class LLMConfig:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
return cls(**args)

View File

@@ -80,17 +80,32 @@ class Message(BaseModel):
elif isinstance(item, ImageContent) and self.vision_enabled:
content.extend(d)
# For DeepSeek, we need to ensure content is a string when using tool calls
if len(content) == 1 and isinstance(content[0], dict) and 'text' in content[0]:
content = content[0]['text']
ret: dict = {'content': content, 'role': self.role}
if role_tool_with_prompt_caching:
ret['cache_control'] = {'type': 'ephemeral'}
# Handle tool calls for DeepSeek compatibility
if self.tool_call_id is not None:
assert (
self.name is not None
), 'name is required when tool_call_id is not None'
assert self.name is not None, 'name is required when tool_call_id is not None'
ret['tool_call_id'] = self.tool_call_id
ret['name'] = self.name
if self.tool_calls:
ret['tool_calls'] = self.tool_calls
# Ensure tool_calls is properly serialized for DeepSeek
ret['tool_calls'] = [
{
'id': tc.id,
'type': tc.type,
'function': {
'name': tc.function.name,
'arguments': tc.function.arguments
}
} for tc in self.tool_calls
]
return ret

View File

@@ -159,11 +159,12 @@ class LLM(RetryMixin, DebugMixin):
self.config.max_output_tokens = self.model_info['max_tokens']
self.config.supports_function_calling = (
self.model_info is not None
and self.model_info.get('supports_function_calling', False)
(self.model_info is not None and self.model_info.get('supports_function_calling', False))
or 'deepseek' in self.config.model.lower() # DeepSeek models support function calling
)
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
@@ -473,3 +474,9 @@ class LLM(RetryMixin, DebugMixin):
# let pydantic handle the serialization
return [message.model_dump() for message in messages]

View File

@@ -0,0 +1,31 @@
import pytest
from openhands.core.config import LLMConfig
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM
def test_deepseek_tool_calling_message_serialization():
# Create a message with tool calls
message = Message(
role="assistant",
content=[TextContent(text="Let me help you with that.")],
tool_calls=[{
"id": "call_123",
"type": "function",
"function": {
"name": "test_function",
"arguments": '{"arg1": "value1"}'
}
}]
)
# Verify serialization
serialized = message.model_dump()
assert "tool_calls" in serialized
assert isinstance(serialized["content"], str)
assert serialized["tool_calls"][0]["id"] == "call_123"
assert serialized["tool_calls"][0]["function"]["name"] == "test_function"
def test_deepseek_model_supports_function_calling():
config = LLMConfig(model="deepseek-chat")
llm = LLM(config)
assert llm.config.supports_function_calling is True