fix(llm): fall back when the model is not in the function-calling supported list (#4617)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Xingyao Wang
2024-10-30 12:54:50 -05:00
committed by GitHub
parent 87bc35d2c8
commit 2587220b12
4 changed files with 34 additions and 19 deletions

View File

@@ -93,17 +93,16 @@ class CodeActAgent(Agent):
if config.micro_agent_name
else None
)
if (
self.config.function_calling
and not self.llm.config.supports_function_calling
):
self.function_calling_active = self.config.function_calling
if self.function_calling_active and not self.llm.is_function_calling_active():
logger.warning(
f'Function calling not supported for model {self.llm.config.model}. '
'Disabling function calling.'
)
self.config.function_calling = False
self.function_calling_active = False
if self.config.function_calling:
if self.function_calling_active:
# Function calling mode
self.tools = codeact_function_calling.get_tools(
codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
@@ -172,7 +171,7 @@ class CodeActAgent(Agent):
FileEditAction,
),
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
if self.config.function_calling:
if self.function_calling_active:
tool_metadata = action.tool_call_metadata
assert tool_metadata is not None, (
'Tool call metadata should NOT be None when function calling is enabled. Action: '
@@ -286,7 +285,7 @@ class CodeActAgent(Agent):
# when the LLM tries to return the next message
raise ValueError(f'Unknown observation type: {type(obs)}')
if self.config.function_calling:
if self.function_calling_active:
# Update the message as tool response properly
if (tool_call_metadata := obs.tool_call_metadata) is not None:
tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
@@ -334,7 +333,7 @@ class CodeActAgent(Agent):
params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
}
if self.config.function_calling:
if self.function_calling_active:
params['tools'] = self.tools
else:
params['stop'] = [
@@ -345,7 +344,7 @@ class CodeActAgent(Agent):
]
response = self.llm.completion(**params)
if self.config.function_calling:
if self.function_calling_active:
actions = codeact_function_calling.response_to_actions(response)
for action in actions:
self.pending_actions.append(action)
@@ -479,7 +478,7 @@ class CodeActAgent(Agent):
else:
break
if not self.config.function_calling:
if not self.function_calling_active:
# The latest user message is important:
# we want to remind the agent of the environment constraints
latest_user_message = next(

View File

@@ -42,7 +42,6 @@ class LLMConfig:
log_completions: Whether to log LLM completions to the state.
log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
supports_function_calling: Whether the model supports function calling.
"""
model: str = 'claude-3-5-sonnet-20241022'
@@ -77,7 +76,6 @@ class LLMConfig:
log_completions: bool = False
log_completions_folder: str | None = None
draft_editor: Optional['LLMConfig'] = None
supports_function_calling: bool = False
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

View File

@@ -53,6 +53,14 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
'claude-3-opus-20240229',
]
# Models known to support native function/tool calling.
# Used as an allow-list (exact name, provider-stripped suffix, or substring
# match) before trusting litellm's model_info flag; models outside this list
# fall back to non-function-calling mode.
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'gpt-4o',
    'gpt-4o-mini',
]
class LLM(RetryMixin, DebugMixin):
"""The LLM class represents a Language Model instance.
@@ -163,11 +171,6 @@ class LLM(RetryMixin, DebugMixin):
):
self.config.max_output_tokens = self.model_info['max_tokens']
self.config.supports_function_calling = (
self.model_info is not None
and self.model_info.get('supports_function_calling', False)
)
self._completion = partial(
litellm_completion,
model=self.config.model,
@@ -186,7 +189,7 @@ class LLM(RetryMixin, DebugMixin):
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
if self.config.supports_function_calling:
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
completion_unwrapped = self._completion
@@ -327,6 +330,18 @@ class LLM(RetryMixin, DebugMixin):
)
)
def is_function_calling_active(self) -> bool:
    """Return whether native function calling should be used for this model.

    The configured model must both appear in the hand-maintained allow-list
    (FUNCTION_CALLING_SUPPORTED_MODELS) and be reported by litellm's
    model_info as supporting function calling.
    """
    model = self.config.model
    # Allow-list match: exact name, the name after a provider prefix
    # (e.g. 'openai/gpt-4o' -> 'gpt-4o'), or any supported name appearing
    # as a substring of the configured one.
    allow_listed = (
        model in FUNCTION_CALLING_SUPPORTED_MODELS
        or model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
        or any(supported in model for supported in FUNCTION_CALLING_SUPPORTED_MODELS)
    )
    if not allow_listed:
        return False
    # Final word comes from litellm's metadata for the model.
    return (
        self.model_info is not None
        and self.model_info.get('supports_function_calling', False)
    )
def _post_completion(self, response: ModelResponse) -> None:
"""Post-process the completion response.

View File

@@ -137,6 +137,9 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
if codeact_agent.config.function_calling:
pytest.skip('Skipping this test for function calling')
# Add a mix of actions and observations
message_action_1 = MessageAction(
"Let's list the contents of the current directory."