mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix(llm): fallback when model is out of function calling supported list (#4617)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -93,17 +93,16 @@ class CodeActAgent(Agent):
|
||||
if config.micro_agent_name
|
||||
else None
|
||||
)
|
||||
if (
|
||||
self.config.function_calling
|
||||
and not self.llm.config.supports_function_calling
|
||||
):
|
||||
|
||||
self.function_calling_active = self.config.function_calling
|
||||
if self.function_calling_active and not self.llm.is_function_calling_active():
|
||||
logger.warning(
|
||||
f'Function calling not supported for model {self.llm.config.model}. '
|
||||
'Disabling function calling.'
|
||||
)
|
||||
self.config.function_calling = False
|
||||
self.function_calling_active = False
|
||||
|
||||
if self.config.function_calling:
|
||||
if self.function_calling_active:
|
||||
# Function calling mode
|
||||
self.tools = codeact_function_calling.get_tools(
|
||||
codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
|
||||
@@ -172,7 +171,7 @@ class CodeActAgent(Agent):
|
||||
FileEditAction,
|
||||
),
|
||||
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
|
||||
if self.config.function_calling:
|
||||
if self.function_calling_active:
|
||||
tool_metadata = action.tool_call_metadata
|
||||
assert tool_metadata is not None, (
|
||||
'Tool call metadata should NOT be None when function calling is enabled. Action: '
|
||||
@@ -286,7 +285,7 @@ class CodeActAgent(Agent):
|
||||
# when the LLM tries to return the next message
|
||||
raise ValueError(f'Unknown observation type: {type(obs)}')
|
||||
|
||||
if self.config.function_calling:
|
||||
if self.function_calling_active:
|
||||
# Update the message as tool response properly
|
||||
if (tool_call_metadata := obs.tool_call_metadata) is not None:
|
||||
tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
|
||||
@@ -334,7 +333,7 @@ class CodeActAgent(Agent):
|
||||
params: dict = {
|
||||
'messages': self.llm.format_messages_for_llm(messages),
|
||||
}
|
||||
if self.config.function_calling:
|
||||
if self.function_calling_active:
|
||||
params['tools'] = self.tools
|
||||
else:
|
||||
params['stop'] = [
|
||||
@@ -345,7 +344,7 @@ class CodeActAgent(Agent):
|
||||
]
|
||||
response = self.llm.completion(**params)
|
||||
|
||||
if self.config.function_calling:
|
||||
if self.function_calling_active:
|
||||
actions = codeact_function_calling.response_to_actions(response)
|
||||
for action in actions:
|
||||
self.pending_actions.append(action)
|
||||
@@ -479,7 +478,7 @@ class CodeActAgent(Agent):
|
||||
else:
|
||||
break
|
||||
|
||||
if not self.config.function_calling:
|
||||
if not self.function_calling_active:
|
||||
# The latest user message is important:
|
||||
# we want to remind the agent of the environment constraints
|
||||
latest_user_message = next(
|
||||
|
||||
@@ -42,7 +42,6 @@ class LLMConfig:
|
||||
log_completions: Whether to log LLM completions to the state.
|
||||
log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
|
||||
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
|
||||
supports_function_calling: Whether the model supports function calling.
|
||||
"""
|
||||
|
||||
model: str = 'claude-3-5-sonnet-20241022'
|
||||
@@ -77,7 +76,6 @@ class LLMConfig:
|
||||
log_completions: bool = False
|
||||
log_completions_folder: str | None = None
|
||||
draft_editor: Optional['LLMConfig'] = None
|
||||
supports_function_calling: bool = False
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
|
||||
@@ -53,6 +53,14 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
|
||||
'claude-3-opus-20240229',
|
||||
]
|
||||
|
||||
# Models that support function calling
|
||||
FUNCTION_CALLING_SUPPORTED_MODELS = [
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'gpt-4o',
|
||||
'gpt-4o-mini',
|
||||
]
|
||||
|
||||
|
||||
class LLM(RetryMixin, DebugMixin):
|
||||
"""The LLM class represents a Language Model instance.
|
||||
@@ -163,11 +171,6 @@ class LLM(RetryMixin, DebugMixin):
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_tokens']
|
||||
|
||||
self.config.supports_function_calling = (
|
||||
self.model_info is not None
|
||||
and self.model_info.get('supports_function_calling', False)
|
||||
)
|
||||
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
model=self.config.model,
|
||||
@@ -186,7 +189,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
logger.debug('LLM: model has vision enabled')
|
||||
if self.is_caching_prompt_active():
|
||||
logger.debug('LLM: caching prompt enabled')
|
||||
if self.config.supports_function_calling:
|
||||
if self.is_function_calling_active():
|
||||
logger.debug('LLM: model supports function calling')
|
||||
|
||||
completion_unwrapped = self._completion
|
||||
@@ -327,6 +330,18 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
)
|
||||
|
||||
def is_function_calling_active(self) -> bool:
    """Report whether function calling should be used for the configured model.

    Function calling is considered active only when both checks pass:

    1. ``self.config.model`` matches an entry of
       ``FUNCTION_CALLING_SUPPORTED_MODELS`` (exact match, match on the
       last ``/``-separated segment, or substring match).
    2. ``self.model_info`` is not None and its
       ``supports_function_calling`` entry is truthy.

    Returns:
        True when both checks pass, False otherwise.
    """
    model = self.config.model

    # Check if model name is in supported list before checking model_info
    model_name_supported = (
        model in FUNCTION_CALLING_SUPPORTED_MODELS
        or model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
        or any(m in model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
    )
    if not model_name_supported or self.model_info is None:
        return False

    # .get() hands back whatever is stored under the key; coerce it so the
    # method always honors its ``-> bool`` annotation.
    return bool(self.model_info.get('supports_function_calling', False))
|
||||
|
||||
def _post_completion(self, response: ModelResponse) -> None:
|
||||
"""Post-process the completion response.
|
||||
|
||||
|
||||
@@ -137,6 +137,9 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
|
||||
|
||||
|
||||
def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
if codeact_agent.config.function_calling:
|
||||
pytest.skip('Skipping this test for function calling')
|
||||
|
||||
# Add a mix of actions and observations
|
||||
message_action_1 = MessageAction(
|
||||
"Let's list the contents of the current directory."
|
||||
|
||||
Reference in New Issue
Block a user