Small fix and addition for token counting (#5550)

Co-authored-by: openhands <openhands@all-hands.dev>
Engel Nyst 2024-12-15 15:12:05 +01:00 committed by GitHub
parent 4716955960
commit 590ebb6e47
7 changed files with 144 additions and 16 deletions

View File

@@ -172,6 +172,10 @@ model = "gpt-4o"
# If the model is vision capable, this option allows you to disable image processing (useful for cost reduction).
#disable_vision = true
# Custom tokenizer to use for token counting
# https://docs.litellm.ai/docs/completion/token_usage
#custom_tokenizer = ""
[llm.gpt4o-mini]
api_key = "your-api-key"
model = "gpt-4o"
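
For orientation, the new option feeds directly into litellm's tokenizer utilities (see the docs URL above). A minimal sketch of the underlying litellm calls on their own, outside OpenHands — the Hugging Face tokenizer id is only an illustration, not a recommendation:

import litellm
from litellm.utils import create_pretrained_tokenizer

# Load a Hugging Face tokenizer in the wrapper format litellm expects.
# 'Xenova/llama-3-tokenizer' is an illustrative repo id.
tokenizer = create_pretrained_tokenizer('Xenova/llama-3-tokenizer')

messages = [{'role': 'user', 'content': 'Hello!'}]
num_tokens = litellm.token_counter(
    model='gpt-4o',
    messages=messages,
    custom_tokenizer=tokenizer,  # omit (or pass None) to use litellm's defaults
)
print(num_tokens)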

View File

@@ -43,6 +43,7 @@ class LLMConfig:
log_completions: Whether to log LLM completions to the state.
log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
custom_tokenizer: A custom tokenizer to use for token counting.
"""
model: str = 'claude-3-5-sonnet-20241022'
@@ -77,6 +78,7 @@ class LLMConfig:
log_completions: bool = False
log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
draft_editor: Optional['LLMConfig'] = None
custom_tokenizer: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

View File

@@ -25,6 +25,7 @@ from litellm.exceptions import (
ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.utils import create_pretrained_tokenizer
from openhands.core.exceptions import CloudFlareBlockageError
from openhands.core.logger import openhands_logger as logger
@@ -122,6 +123,13 @@ class LLM(RetryMixin, DebugMixin):
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
# if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
if self.config.custom_tokenizer is not None:
self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
else:
self.tokenizer = None
# set up the completion function
self._completion = partial(
litellm_completion,
model=self.config.model,
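
For reference, create_pretrained_tokenizer wraps a Hugging Face tokenizer in the small dict structure that litellm.token_counter understands. Roughly, as a sketch of litellm's documented behavior (not OpenHands code; the function name below is ours):

from tokenizers import Tokenizer

def create_pretrained_tokenizer_sketch(identifier: str) -> dict:
    # litellm tags the tokenizer with a type so token_counter knows
    # how to encode messages with it.
    return {
        'type': 'huggingface_tokenizer',
        'tokenizer': Tokenizer.from_pretrained(identifier),
    }
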
@@ -491,19 +499,43 @@
return cur_cost
def get_token_count(self, messages) -> int:
"""Get the number of tokens in a list of messages.
def get_token_count(self, messages: list[dict] | list[Message]) -> int:
"""Get the number of tokens in a list of messages. Use dicts for better token counting.
Args:
messages (list): A list of messages.
messages (list): A list of messages, either as a list of dicts or as a list of Message objects.
Returns:
int: The number of tokens.
"""
# attempt to convert Message objects to dicts, litellm expects dicts
if (
isinstance(messages, list)
and len(messages) > 0
and isinstance(messages[0], Message)
):
logger.info(
'Message objects now include serialized tool calls in token counting'
)
messages = self.format_messages_for_llm(messages) # type: ignore
# try to get the token count with the default litellm tokenizers
# or the custom tokenizer if set for this LLM configuration
try:
return litellm.token_counter(model=self.config.model, messages=messages)
except Exception:
# TODO: this is to limit logspam in case token count is not supported
return litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
)
except Exception as e:
# limit logspam in case token count is not supported
logger.error(
f'Error getting token count for\n model {self.config.model}\n{e}'
+ (
f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
if self.config.custom_tokenizer is not None
else ''
)
)
return 0
def _is_local(self) -> bool:
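
Putting the pieces together, a usage sketch of the updated method (Message and TextContent imports as in the tests further down; on any tokenizer failure the method logs an error and returns 0):

from openhands.core.config import LLMConfig
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM

llm = LLM(LLMConfig(model='gpt-4o'))

# Both input shapes are accepted; Message objects are serialized first,
# so tool calls are included in the count.
as_dicts = llm.get_token_count([{'role': 'user', 'content': 'Hello!'}])
as_objects = llm.get_token_count(
    [Message(role='user', content=[TextContent(text='Hello!')])]
)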

View File

@@ -1,6 +1,6 @@
import asyncio
from copy import deepcopy
import time
from copy import deepcopy
import socketio
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.core.schema.config import ConfigType
from openhands.events.action import MessageAction, NullAction
from openhands.events.event import Event, EventSource
from openhands.events.observation import (
@@ -68,15 +67,28 @@ class Session:
)
# Extract the agent-relevant arguments from the request
agent_cls = session_init_data.agent or self.config.default_agent
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
self.config.security.confirmation_mode = (
self.config.security.confirmation_mode
if session_init_data.confirmation_mode is None
else session_init_data.confirmation_mode
)
self.config.security.security_analyzer = (
session_init_data.security_analyzer
or self.config.security.security_analyzer
)
max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config
default_llm_config = self.config.get_llm_config()
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
default_llm_config.model = (
session_init_data.llm_model or default_llm_config.model
)
default_llm_config.api_key = (
session_init_data.llm_api_key or default_llm_config.api_key
)
default_llm_config.base_url = (
session_init_data.llm_base_url or default_llm_config.base_url
)
# TODO: override other LLM config & agent config groups (#2075)

View File

@@ -1,5 +1,3 @@
from dataclasses import dataclass
@@ -8,6 +6,7 @@ class SessionInitData:
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
language: str | None = None
agent: str | None = None
max_iterations: int | None = None
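
The copy-then-override flow described in the docstring, as a hypothetical standalone sketch (apply_session_overrides is illustrative, not part of this commit; field names are from SessionInitData):

from copy import deepcopy

def apply_session_overrides(global_config, init_data):
    # Deep-copy the global config so per-session overrides never
    # leak back into shared state.
    config = deepcopy(global_config)
    # Session values win only when provided (None/empty falls back).
    config.default_agent = init_data.agent or config.default_agent
    config.max_iterations = init_data.max_iterations or config.max_iterations
    return config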

View File

@@ -428,6 +428,7 @@ def test_api_keys_repr_str():
'aws_secret_access_key',
'input_cost_per_token',
'output_cost_per_token',
'custom_tokenizer',
]
for attr_name in dir(LLMConfig):
if (

View File

@@ -11,6 +11,7 @@ from litellm.exceptions import (
from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
@@ -21,6 +22,7 @@ def mock_logger(monkeypatch):
mock_logger = MagicMock()
monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
return mock_logger
@@ -397,3 +399,79 @@ def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
# Ensure the completion was called
mock_litellm_completion.assert_called_once()
@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
mock_token_counter.return_value = 42
llm = LLM(default_config)
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
assert token_count == 42
mock_token_counter.assert_called_once_with(
model=default_config.model, messages=messages, custom_tokenizer=None
)
@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_message_objects(
mock_token_counter, default_config, mock_logger
):
llm = LLM(default_config)
# Create a Message object and its equivalent dict
message_obj = Message(role='user', content=[TextContent(text='Hello!')])
message_dict = {'role': 'user', 'content': 'Hello!'}
# Mock the token counter to return the same value for both calls
mock_token_counter.side_effect = [42, 42]
# Get token counts for both formats
token_count_obj = llm.get_token_count([message_obj])
token_count_dict = llm.get_token_count([message_dict])
# Verify both formats get the same token count
assert token_count_obj == token_count_dict
assert mock_token_counter.call_count == 2
@patch('openhands.llm.llm.litellm.token_counter')
@patch('openhands.llm.llm.create_pretrained_tokenizer')
def test_get_token_count_with_custom_tokenizer(
mock_create_tokenizer, mock_token_counter, default_config
):
mock_tokenizer = MagicMock()
mock_create_tokenizer.return_value = mock_tokenizer
mock_token_counter.return_value = 42
config = copy.deepcopy(default_config)
config.custom_tokenizer = 'custom/tokenizer'
llm = LLM(config)
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
assert token_count == 42
mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
mock_token_counter.assert_called_once_with(
model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
)
@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_error_handling(
mock_token_counter, default_config, mock_logger
):
mock_token_counter.side_effect = Exception('Token counting failed')
llm = LLM(default_config)
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
assert token_count == 0
mock_token_counter.assert_called_once()
mock_logger.error.assert_called_once_with(
'Error getting token count for\n model gpt-4o\nToken counting failed'
)