diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py
index 997d424c51..d1f67eae9c 100644
--- a/openhands/agenthub/codeact_agent/codeact_agent.py
+++ b/openhands/agenthub/codeact_agent/codeact_agent.py
@@ -93,17 +93,16 @@ class CodeActAgent(Agent):
             if config.micro_agent_name
             else None
         )
-        if (
-            self.config.function_calling
-            and not self.llm.config.supports_function_calling
-        ):
+
+        self.function_calling_active = self.config.function_calling
+        if self.function_calling_active and not self.llm.is_function_calling_active():
             logger.warning(
                 f'Function calling not supported for model {self.llm.config.model}. '
                 'Disabling function calling.'
             )
-            self.config.function_calling = False
+            self.function_calling_active = False
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             # Function calling mode
             self.tools = codeact_function_calling.get_tools(
                 codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
@@ -172,7 +171,7 @@ class CodeActAgent(Agent):
                 FileEditAction,
             ),
         ) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
-            if self.config.function_calling:
+            if self.function_calling_active:
                 tool_metadata = action.tool_call_metadata
                 assert tool_metadata is not None, (
                     'Tool call metadata should NOT be None when function calling is enabled. Action: '
@@ -286,7 +285,7 @@ class CodeActAgent(Agent):
             # when the LLM tries to return the next message
             raise ValueError(f'Unknown observation type: {type(obs)}')
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             # Update the message as tool response properly
             if (tool_call_metadata := obs.tool_call_metadata) is not None:
                 tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
@@ -334,7 +333,7 @@ class CodeActAgent(Agent):
         params: dict = {
             'messages': self.llm.format_messages_for_llm(messages),
         }
-        if self.config.function_calling:
+        if self.function_calling_active:
             params['tools'] = self.tools
         else:
             params['stop'] = [
@@ -345,7 +344,7 @@ class CodeActAgent(Agent):
             ]
         response = self.llm.completion(**params)
 
-        if self.config.function_calling:
+        if self.function_calling_active:
             actions = codeact_function_calling.response_to_actions(response)
             for action in actions:
                 self.pending_actions.append(action)
@@ -479,7 +478,7 @@ class CodeActAgent(Agent):
             else:
                 break
 
-        if not self.config.function_calling:
+        if not self.function_calling_active:
             # The latest user message is important:
             # we want to remind the agent of the environment constraints
             latest_user_message = next(
diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index 2bcc0bd391..ced45b905b 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -42,7 +42,6 @@ class LLMConfig:
         log_completions: Whether to log LLM completions to the state.
         log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
-        supports_function_calling: Whether the model supports function calling.
     """
 
     model: str = 'claude-3-5-sonnet-20241022'
@@ -77,7 +76,6 @@ class LLMConfig:
     log_completions: bool = False
    log_completions_folder: str | None = None
    draft_editor: Optional['LLMConfig'] = None
-    supports_function_calling: bool = False
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 26e910d7e0..ba887a3542 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -53,6 +53,14 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
     'claude-3-opus-20240229',
 ]
 
+# function calling supporting models
+FUNCTION_CALLING_SUPPORTED_MODELS = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-5-sonnet-20241022',
+    'gpt-4o',
+    'gpt-4o-mini',
+]
+
 
 class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.
@@ -163,11 +171,6 @@ class LLM(RetryMixin, DebugMixin):
         ):
             self.config.max_output_tokens = self.model_info['max_tokens']
 
-        self.config.supports_function_calling = (
-            self.model_info is not None
-            and self.model_info.get('supports_function_calling', False)
-        )
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -186,7 +189,7 @@ class LLM(RetryMixin, DebugMixin):
             logger.debug('LLM: model has vision enabled')
         if self.is_caching_prompt_active():
             logger.debug('LLM: caching prompt enabled')
-        if self.config.supports_function_calling:
+        if self.is_function_calling_active():
             logger.debug('LLM: model supports function calling')
 
         completion_unwrapped = self._completion
@@ -327,6 +330,18 @@ class LLM(RetryMixin, DebugMixin):
             )
         )
 
+    def is_function_calling_active(self) -> bool:
+        # Check if model name is in supported list before checking model_info
+        model_name_supported = (
+            self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
+            or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
+            or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
+        )
+        return model_name_supported and (
+            self.model_info is not None
+            and self.model_info.get('supports_function_calling', False)
+        )
+
     def _post_completion(self, response: ModelResponse) -> None:
         """Post-process the completion response.
 
diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py
index 7adbd9119d..50c42bf662 100644
--- a/tests/unit/test_prompt_caching.py
+++ b/tests/unit/test_prompt_caching.py
@@ -137,6 +137,9 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
 
 
 def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
+    if codeact_agent.config.function_calling:
+        pytest.skip('Skipping this test for function calling')
+
     # Add a mix of actions and observations
     message_action_1 = MessageAction(
         "Let's list the contents of the current directory."