From 590ebb6e4730301b5644fb92b894ea391f3025c2 Mon Sep 17 00:00:00 2001
From: Engel Nyst
Date: Sun, 15 Dec 2024 15:12:05 +0100
Subject: [PATCH] Small fix and addition for token counting (#5550)

Co-authored-by: openhands
---
 config.template.toml                          |  4 +
 openhands/core/config/llm_config.py           |  2 +
 openhands/llm/llm.py                          | 46 +++++++++--
 openhands/server/session/session.py           | 26 +++++--
 openhands/server/session/session_init_data.py |  3 +-
 tests/unit/test_config.py                     |  1 +
 tests/unit/test_llm.py                        | 78 +++++++++++++++++++
 7 files changed, 144 insertions(+), 16 deletions(-)

diff --git a/config.template.toml b/config.template.toml
index 6f626e6bee..e88150cd07 100644
--- a/config.template.toml
+++ b/config.template.toml
@@ -172,6 +172,10 @@ model = "gpt-4o"
 # If model is vision capable, this option allows to disable image processing (useful for cost reduction).
 #disable_vision = true
 
+# Custom tokenizer to use for token counting
+# https://docs.litellm.ai/docs/completion/token_usage
+#custom_tokenizer = ""
+
 [llm.gpt4o-mini]
 api_key = "your-api-key"
 model = "gpt-4o"
diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index 477b47ccdb..4e60d4a281 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -43,6 +43,7 @@ class LLMConfig:
         log_completions: Whether to log LLM completions to the state.
         log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
+        custom_tokenizer: A custom tokenizer to use for token counting.
     """
 
     model: str = 'claude-3-5-sonnet-20241022'
@@ -77,6 +78,7 @@ class LLMConfig:
     log_completions: bool = False
     log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
     draft_editor: Optional['LLMConfig'] = None
+    custom_tokenizer: str | None = None
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 6b87b33009..d7c7309eff 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -25,6 +25,7 @@ from litellm.exceptions import (
     ServiceUnavailableError,
 )
 from litellm.types.utils import CostPerToken, ModelResponse, Usage
+from litellm.utils import create_pretrained_tokenizer
 
 from openhands.core.exceptions import CloudFlareBlockageError
 from openhands.core.logger import openhands_logger as logger
@@ -122,6 +123,13 @@ class LLM(RetryMixin, DebugMixin):
         if self.is_function_calling_active():
             logger.debug('LLM: model supports function calling')
 
+        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
+        if self.config.custom_tokenizer is not None:
+            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
+        else:
+            self.tokenizer = None
+
+        # set up the completion function
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -491,19 +499,43 @@ class LLM(RetryMixin, DebugMixin):
 
         return cur_cost
 
-    def get_token_count(self, messages) -> int:
-        """Get the number of tokens in a list of messages.
+    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
+        """Get the number of tokens in a list of messages. Use dicts for better token counting.
 
         Args:
-            messages (list): A list of messages.
+            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.
 
         Returns:
             int: The number of tokens.
         """
+        # attempt to convert Message objects to dicts, litellm expects dicts
+        if (
+            isinstance(messages, list)
+            and len(messages) > 0
+            and isinstance(messages[0], Message)
+        ):
+            logger.info(
+                'Message objects now include serialized tool calls in token counting'
+            )
+            messages = self.format_messages_for_llm(messages)  # type: ignore
+
+        # try to get the token count with the default litellm tokenizers
+        # or the custom tokenizer if set for this LLM configuration
         try:
-            return litellm.token_counter(model=self.config.model, messages=messages)
-        except Exception:
-            # TODO: this is to limit logspam in case token count is not supported
+            return litellm.token_counter(
+                model=self.config.model,
+                messages=messages,
+                custom_tokenizer=self.tokenizer,
+            )
+        except Exception as e:
+            # limit logspam in case token count is not supported
+            logger.error(
+                f'Error getting token count for\n model {self.config.model}\n{e}'
+                + (
+                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
+                    if self.config.custom_tokenizer is not None
+                    else ''
+                )
+            )
             return 0
 
     def _is_local(self) -> bool:
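
Note: to see the pieces above end to end, a minimal sketch (not part of the patch) of enabling the option and counting tokens with it. The tokenizer id below is hypothetical; any identifier accepted by litellm's create_pretrained_tokenizer should work the same way, and uncommenting custom_tokenizer in config.toml achieves the same through configuration:

    from openhands.core.config import LLMConfig
    from openhands.llm.llm import LLM

    config = LLMConfig(
        model='gpt-4o',
        custom_tokenizer='deepseek-ai/DeepSeek-V2',  # hypothetical id, illustration only
    )
    llm = LLM(config)

    # counted with the pretrained tokenizer instead of litellm's default model mapping
    print(llm.get_token_count([{'role': 'user', 'content': 'Hello!'}]))
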
diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index 04f5ba7254..b8c7049e39 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -1,6 +1,6 @@
 import asyncio
-from copy import deepcopy
 import time
+from copy import deepcopy
 
 import socketio
 
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
 from openhands.core.const.guide_url import TROUBLESHOOTING_URL
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema import AgentState
-from openhands.core.schema.config import ConfigType
 from openhands.events.action import MessageAction, NullAction
 from openhands.events.event import Event, EventSource
 from openhands.events.observation import (
@@ -68,15 +67,28 @@ class Session:
         )
         # Extract the agent-relevant arguments from the request
         agent_cls = session_init_data.agent or self.config.default_agent
-        self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
-        self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
+        self.config.security.confirmation_mode = (
+            self.config.security.confirmation_mode
+            if session_init_data.confirmation_mode is None
+            else session_init_data.confirmation_mode
+        )
+        self.config.security.security_analyzer = (
+            session_init_data.security_analyzer
+            or self.config.security.security_analyzer
+        )
         max_iterations = session_init_data.max_iterations or self.config.max_iterations
 
         # override default LLM config
         default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = session_init_data.llm_model or default_llm_config.model
-        default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
-        default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
+        default_llm_config.model = (
+            session_init_data.llm_model or default_llm_config.model
+        )
+        default_llm_config.api_key = (
+            session_init_data.llm_api_key or default_llm_config.api_key
+        )
+        default_llm_config.base_url = (
+            session_init_data.llm_base_url or default_llm_config.base_url
+        )
 
         # TODO: override other LLM config & agent config groups (#2075)
diff --git a/openhands/server/session/session_init_data.py b/openhands/server/session/session_init_data.py
index 1308598b14..f269b1c74a 100644
--- a/openhands/server/session/session_init_data.py
+++ b/openhands/server/session/session_init_data.py
@@ -1,5 +1,3 @@
-
-
 from dataclasses import dataclass
 
 
@@ -8,6 +6,7 @@ class SessionInitData:
     """
    Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """
+
     language: str | None = None
     agent: str | None = None
     max_iterations: int | None = None
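
Note: the session.py hunk is a pure reformat, but the precedence it encodes is easy to miss: string overrides fall back through 'or', while confirmation_mode checks 'is None' so an explicit False from the client still beats a True default. A standalone sketch of the two rules (helper names are illustrative, not from the codebase):

    # 'or' keeps the default whenever the override is falsy...
    def resolve(override, default):
        return override or default

    # ...so nullable boolean flags need an explicit None check instead
    def resolve_flag(override, default):
        return default if override is None else override

    assert resolve(None, 'gpt-4o') == 'gpt-4o'
    assert resolve_flag(False, True) is False  # explicit False wins
    assert resolve(False, True) is True        # plain 'or' would drop it
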
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 6d6681a983..d4ef11c4ce 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -428,6 +428,7 @@ def test_api_keys_repr_str():
         'aws_secret_access_key',
         'input_cost_per_token',
         'output_cost_per_token',
+        'custom_tokenizer',
     ]
     for attr_name in dir(LLMConfig):
         if (
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 3dc1e1a797..4973fcf353 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -11,6 +11,7 @@ from litellm.exceptions import (
 
 from openhands.core.config import LLMConfig
 from openhands.core.exceptions import OperationCancelled
+from openhands.core.message import Message, TextContent
 from openhands.llm.llm import LLM
 from openhands.llm.metrics import Metrics
 
@@ -21,6 +22,7 @@ def mock_logger(monkeypatch):
     mock_logger = MagicMock()
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
     return mock_logger
 
 
@@ -397,3 +399,79 @@ def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
 
     # Ensure the completion was called
     mock_litellm_completion.assert_called_once()
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
+    mock_token_counter.return_value = 42
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_token_counter.assert_called_once_with(
+        model=default_config.model, messages=messages, custom_tokenizer=None
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_message_objects(
+    mock_token_counter, default_config, mock_logger
+):
+    llm = LLM(default_config)
+
+    # Create a Message object and its equivalent dict
+    message_obj = Message(role='user', content=[TextContent(text='Hello!')])
+    message_dict = {'role': 'user', 'content': 'Hello!'}
+
+    # Mock token counter to return different values for each call
+    mock_token_counter.side_effect = [42, 42]  # Same value for both cases
+
+    # Get token counts for both formats
+    token_count_obj = llm.get_token_count([message_obj])
+    token_count_dict = llm.get_token_count([message_dict])
+
+    # Verify both formats get the same token count
+    assert token_count_obj == token_count_dict
+    assert mock_token_counter.call_count == 2
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+@patch('openhands.llm.llm.create_pretrained_tokenizer')
+def test_get_token_count_with_custom_tokenizer(
+    mock_create_tokenizer, mock_token_counter, default_config
+):
+    mock_tokenizer = MagicMock()
+    mock_create_tokenizer.return_value = mock_tokenizer
+    mock_token_counter.return_value = 42
+
+    config = copy.deepcopy(default_config)
+    config.custom_tokenizer = 'custom/tokenizer'
+    llm = LLM(config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
+    mock_token_counter.assert_called_once_with(
+        model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_error_handling(
+    mock_token_counter, default_config, mock_logger
+):
+    mock_token_counter.side_effect = Exception('Token counting failed')
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 0
+    mock_token_counter.assert_called_once()
+    mock_logger.error.assert_called_once_with(
+        'Error getting token count for\n model gpt-4o\nToken counting failed'
+    )
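
Note: a quick manual check mirroring test_get_token_count_with_message_objects above; both message shapes should yield the same count, because Message objects are serialized through format_messages_for_llm() before litellm.token_counter runs (the model name here is illustrative):

    from openhands.core.config import LLMConfig
    from openhands.core.message import Message, TextContent
    from openhands.llm.llm import LLM

    llm = LLM(LLMConfig(model='gpt-4o'))

    as_dicts = llm.get_token_count([{'role': 'user', 'content': 'Hello!'}])
    as_objects = llm.get_token_count(
        [Message(role='user', content=[TextContent(text='Hello!')])]
    )
    assert as_dicts == as_objects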