Mirror of https://github.com/OpenHands/OpenHands.git, synced 2025-12-26 05:48:36 +08:00
Small fix and addition for token counting (#5550)
Co-authored-by: openhands <openhands@all-hands.dev>
parent 4716955960
commit 590ebb6e47
@@ -172,6 +172,10 @@ model = "gpt-4o"
 # If model is vision capable, this option allows to disable image processing (useful for cost reduction).
 #disable_vision = true

+# Custom tokenizer to use for token counting
+# https://docs.litellm.ai/docs/completion/token_usage
+#custom_tokenizer = ""
+
 [llm.gpt4o-mini]
 api_key = "your-api-key"
 model = "gpt-4o"
@@ -43,6 +43,7 @@ class LLMConfig:
         log_completions: Whether to log LLM completions to the state.
         log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
+        custom_tokenizer: A custom tokenizer to use for token counting.
     """

     model: str = 'claude-3-5-sonnet-20241022'
@@ -77,6 +78,7 @@ class LLMConfig:
     log_completions: bool = False
     log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
     draft_editor: Optional['LLMConfig'] = None
+    custom_tokenizer: str | None = None

     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
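The new field is an ordinary LLMConfig dataclass attribute, so it can be set like any other LLM option. A minimal sketch, assuming a HuggingFace-style tokenizer identifier; the repo id below is hypothetical, and the litellm link in the config template describes the accepted values:

from openhands.core.config import LLMConfig

# 'Xenova/gpt-4o' is a hypothetical tokenizer repo id, for illustration only
config = LLMConfig(
    model='gpt-4o',
    custom_tokenizer='Xenova/gpt-4o',
)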
@@ -25,6 +25,7 @@ from litellm.exceptions import (
     ServiceUnavailableError,
 )
 from litellm.types.utils import CostPerToken, ModelResponse, Usage
+from litellm.utils import create_pretrained_tokenizer

 from openhands.core.exceptions import CloudFlareBlockageError
 from openhands.core.logger import openhands_logger as logger
@@ -122,6 +123,13 @@ class LLM(RetryMixin, DebugMixin):
         if self.is_function_calling_active():
             logger.debug('LLM: model supports function calling')
+
+        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
+        if self.config.custom_tokenizer is not None:
+            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
+        else:
+            self.tokenizer = None
+
         # set up the completion function
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
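Standalone, the two litellm calls the constructor now wires together look like this. A sketch using only functions that appear in the diff; the tokenizer id is again a hypothetical placeholder:

import litellm
from litellm.utils import create_pretrained_tokenizer

# hypothetical HuggingFace repo id, not part of the commit
tokenizer = create_pretrained_tokenizer('Xenova/gpt-4o')
count = litellm.token_counter(
    model='gpt-4o',
    messages=[{'role': 'user', 'content': 'Hello!'}],
    custom_tokenizer=tokenizer,
)
print(count)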
@@ -491,19 +499,43 @@ class LLM(RetryMixin, DebugMixin):

         return cur_cost

-    def get_token_count(self, messages) -> int:
-        """Get the number of tokens in a list of messages.
+    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
+        """Get the number of tokens in a list of messages. Use dicts for better token counting.

         Args:
-            messages (list): A list of messages.
+            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.

         Returns:
             int: The number of tokens.
         """
+        # attempt to convert Message objects to dicts, litellm expects dicts
+        if (
+            isinstance(messages, list)
+            and len(messages) > 0
+            and isinstance(messages[0], Message)
+        ):
+            logger.info(
+                'Message objects now include serialized tool calls in token counting'
+            )
+            messages = self.format_messages_for_llm(messages)  # type: ignore
+
+        # try to get the token count with the default litellm tokenizers
+        # or the custom tokenizer if set for this LLM configuration
         try:
-            return litellm.token_counter(model=self.config.model, messages=messages)
-        except Exception:
-            # TODO: this is to limit logspam in case token count is not supported
+            return litellm.token_counter(
+                model=self.config.model,
+                messages=messages,
+                custom_tokenizer=self.tokenizer,
+            )
+        except Exception as e:
+            # limit logspam in case token count is not supported
+            logger.error(
+                f'Error getting token count for\n model {self.config.model}\n{e}'
+                + (
+                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
+                    if self.config.custom_tokenizer is not None
+                    else ''
+                )
+            )
             return 0

     def _is_local(self) -> bool:
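Both input shapes are accepted by the new signature; plain dicts skip the conversion branch entirely. A usage sketch, built from the same imports the tests below use:

from openhands.core.config import LLMConfig
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM

llm = LLM(LLMConfig(model='gpt-4o'))

# dicts go straight to litellm.token_counter
n_dict = llm.get_token_count([{'role': 'user', 'content': 'Hello!'}])

# Message objects are serialized via format_messages_for_llm first
n_obj = llm.get_token_count(
    [Message(role='user', content=[TextContent(text='Hello!')])]
)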
@@ -1,6 +1,6 @@
 import asyncio
-from copy import deepcopy
 import time
+from copy import deepcopy

 import socketio
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
 from openhands.core.const.guide_url import TROUBLESHOOTING_URL
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema import AgentState
-from openhands.core.schema.config import ConfigType
 from openhands.events.action import MessageAction, NullAction
 from openhands.events.event import Event, EventSource
 from openhands.events.observation import (
@@ -68,15 +67,28 @@ class Session:
         )
         # Extract the agent-relevant arguments from the request
         agent_cls = session_init_data.agent or self.config.default_agent
-        self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
-        self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
+        self.config.security.confirmation_mode = (
+            self.config.security.confirmation_mode
+            if session_init_data.confirmation_mode is None
+            else session_init_data.confirmation_mode
+        )
+        self.config.security.security_analyzer = (
+            session_init_data.security_analyzer
+            or self.config.security.security_analyzer
+        )
+        max_iterations = session_init_data.max_iterations or self.config.max_iterations
         # override default LLM config

         default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = session_init_data.llm_model or default_llm_config.model
-        default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
-        default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
+        default_llm_config.model = (
+            session_init_data.llm_model or default_llm_config.model
+        )
+        default_llm_config.api_key = (
+            session_init_data.llm_api_key or default_llm_config.api_key
+        )
+        default_llm_config.base_url = (
+            session_init_data.llm_base_url or default_llm_config.base_url
+        )

         # TODO: override other LLM config & agent config groups (#2075)
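Why this hunk keeps two different fallback patterns: `or` treats every falsy override as unset, which is fine for optional strings, but it would silently drop an explicit confirmation_mode=False, so that field is checked against None instead. A small illustration (not from the commit):

override, default = False, True

# `or` fallback: the explicit False override is lost
assert (override or default) is True

# `is None` check: False survives as a real override
assert (default if override is None else override) is False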
@@ -1,5 +1,3 @@
-
-
 from dataclasses import dataclass
@@ -8,6 +6,7 @@ class SessionInitData:
     """
     Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """

     language: str | None = None
     agent: str | None = None
+    max_iterations: int | None = None
@@ -428,6 +428,7 @@ def test_api_keys_repr_str():
         'aws_secret_access_key',
         'input_cost_per_token',
         'output_cost_per_token',
+        'custom_tokenizer',
     ]
     for attr_name in dir(LLMConfig):
         if (
@@ -11,6 +11,7 @@ from litellm.exceptions import (

 from openhands.core.config import LLMConfig
 from openhands.core.exceptions import OperationCancelled
+from openhands.core.message import Message, TextContent
 from openhands.llm.llm import LLM
 from openhands.llm.metrics import Metrics
@@ -21,6 +22,7 @@ def mock_logger(monkeypatch):
     mock_logger = MagicMock()
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
     return mock_logger
@@ -397,3 +399,79 @@ def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):

     # Ensure the completion was called
     mock_litellm_completion.assert_called_once()
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
+    mock_token_counter.return_value = 42
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_token_counter.assert_called_once_with(
+        model=default_config.model, messages=messages, custom_tokenizer=None
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_message_objects(
+    mock_token_counter, default_config, mock_logger
+):
+    llm = LLM(default_config)
+
+    # Create a Message object and its equivalent dict
+    message_obj = Message(role='user', content=[TextContent(text='Hello!')])
+    message_dict = {'role': 'user', 'content': 'Hello!'}
+
+    # Mock token counter to return different values for each call
+    mock_token_counter.side_effect = [42, 42]  # Same value for both cases
+
+    # Get token counts for both formats
+    token_count_obj = llm.get_token_count([message_obj])
+    token_count_dict = llm.get_token_count([message_dict])
+
+    # Verify both formats get the same token count
+    assert token_count_obj == token_count_dict
+    assert mock_token_counter.call_count == 2
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+@patch('openhands.llm.llm.create_pretrained_tokenizer')
+def test_get_token_count_with_custom_tokenizer(
+    mock_create_tokenizer, mock_token_counter, default_config
+):
+    mock_tokenizer = MagicMock()
+    mock_create_tokenizer.return_value = mock_tokenizer
+    mock_token_counter.return_value = 42
+
+    config = copy.deepcopy(default_config)
+    config.custom_tokenizer = 'custom/tokenizer'
+    llm = LLM(config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
+    mock_token_counter.assert_called_once_with(
+        model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_error_handling(
+    mock_token_counter, default_config, mock_logger
+):
+    mock_token_counter.side_effect = Exception('Token counting failed')
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 0
+    mock_token_counter.assert_called_once()
+    mock_logger.error.assert_called_once_with(
+        'Error getting token count for\n model gpt-4o\nToken counting failed'
+    )
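To run just the new tests, one option is pytest's -k filter. The file path here is assumed from the patch targets, not stated in the commit:

import pytest

# path is an assumption; adjust to wherever these tests live in the checkout
pytest.main(['tests/unit/test_llm.py', '-k', 'get_token_count'])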