Vision and prompt caching fixes (#4014)

Engel Nyst 2024-09-28 14:37:29 +02:00 committed by GitHub
parent f427f9d8d4
commit e582806004
GPG Key ID: B5690EEEBB952194
11 changed files with 262 additions and 121 deletions

View File

@ -204,11 +204,6 @@ class CodeActAgent(Agent):
],
}
if self.llm.is_caching_prompt_active():
params['extra_headers'] = {
'anthropic-beta': 'prompt-caching-2024-07-31',
}
response = self.llm.completion(**params)
return self.action_parser.parse(response)
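
The beta header is no longer the agent's concern: it moves into the LLM wrapper, which injects it whenever is_caching_prompt_active() returns True (see the llm.py hunk further down). A rough caller-side sketch, illustrative only (it assumes LLMConfig and LLM can be constructed as shown, and running it would hit a real provider):

# Illustrative sketch, not part of this commit: with prompt caching enabled in the
# config, the caller passes plain params and the wrapper adds the
# 'anthropic-beta: prompt-caching-2024-07-31' header internally when active.
from openhands.core.config import LLMConfig
from openhands.llm import LLM

llm = LLM(config=LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True))
params = {'messages': [{'role': 'user', 'content': 'Hello'}]}
response = llm.completion(**params)  # no 'extra_headers' needed at the call site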

View File

@ -146,8 +146,8 @@ model = "gpt-4o"
# Drop any unmapped (unsupported) params without causing an exception
#drop_params = false
# Using the prompt caching feature provided by the LLM
#caching_prompt = false
# Using the prompt caching feature if provided by the LLM and supported
#caching_prompt = true
# Base URL for the OLLAMA API
#ollama_base_url = ""

View File

@ -1,3 +1,4 @@
import os
from dataclasses import dataclass, fields
from openhands.core.config.config_utils import get_field_info
@ -36,7 +37,7 @@ class LLMConfig:
ollama_base_url: The base URL for the OLLAMA API.
drop_params: Drop any unmapped (unsupported) params without causing an exception.
disable_vision: If the model is vision capable, this option allows disabling image processing (useful for cost reduction).
caching_prompt: Using the prompt caching feature provided by the LLM.
caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
log_completions: Whether to log LLM completions to the state.
"""
@ -68,7 +69,7 @@ class LLMConfig:
ollama_base_url: str | None = None
drop_params: bool = True
disable_vision: bool | None = None
caching_prompt: bool = False
caching_prompt: bool = True
log_completions: bool = False
def defaults_to_dict(self) -> dict:
@ -78,6 +79,18 @@ class LLMConfig:
result[f.name] = get_field_info(f)
return result
def __post_init__(self):
"""
Post-initialization hook to assign OpenRouter-related variables to environment variables.
This ensures that these values are accessible to litellm at runtime.
"""
# Assign OpenRouter-specific variables to environment variables
if self.openrouter_site_url:
os.environ['OR_SITE_URL'] = self.openrouter_site_url
if self.openrouter_app_name:
os.environ['OR_APP_NAME'] = self.openrouter_app_name
def __str__(self):
attr_str = []
for f in fields(self):
@ -101,9 +114,3 @@ class LLMConfig:
if k in LLM_SENSITIVE_FIELDS:
ret[k] = '******' if v else None
return ret
def set_missing_attributes(self):
"""Set any missing attributes to their default values."""
for field_name, field_obj in self.__dataclass_fields__.items():
if not hasattr(self, field_name):
setattr(self, field_name, field_obj.default)
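
Two behavioral changes land in this file: caching_prompt now defaults to True, and __post_init__ exports the OpenRouter settings as environment variables (previously done in LLM.__init__, as removed in llm.py below). A minimal sketch of the resulting behavior, assuming LLMConfig can be constructed with only these keyword arguments; the URL and app name are placeholder values:

import os

from openhands.core.config import LLMConfig

config = LLMConfig(
    openrouter_site_url='https://example.org',  # placeholder value
    openrouter_app_name='my-app',               # placeholder value
)
assert config.caching_prompt is True            # default flipped from False
assert os.environ['OR_SITE_URL'] == 'https://example.org'
assert os.environ['OR_APP_NAME'] == 'my-app'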

View File

@ -50,6 +50,8 @@ class ImageContent(Content):
class Message(BaseModel):
role: Literal['user', 'system', 'assistant']
content: list[TextContent | ImageContent] = Field(default=list)
cache_enabled: bool = False
vision_enabled: bool = False
@property
def contains_image(self) -> bool:
@ -58,23 +60,22 @@ class Message(BaseModel):
@model_serializer
def serialize_model(self) -> dict:
content: list[dict] | str
if self.role == 'system':
# For system role, concatenate all text content into a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
elif self.role == 'assistant' and not self.contains_image:
# For assistant role without vision, concatenate all text content into a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
else:
# For user role or assistant role with vision enabled, serialize each content item
# two kinds of serializer:
# 1. vision serializer: when prompt caching or vision is enabled
# 2. single text serializer: for other cases
# remove this when liteLLM or providers support this format translation
if self.cache_enabled or self.vision_enabled:
# when prompt caching or vision is enabled, use vision serializer
content = []
for item in self.content:
if isinstance(item, TextContent):
content.append(item.model_dump())
elif isinstance(item, ImageContent):
content.extend(item.model_dump())
else:
# for other cases, concatenate all text content
# into a single string per message
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
return {'content': content, 'role': self.role}
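
A short sketch of the two serialization paths, mirroring the unit tests added at the end of this commit; it assumes TextContent exposes a text field exactly as those tests do:

from openhands.core.message import Message, TextContent

# default path: no caching, no vision -> texts are concatenated into one string
plain = Message(role='user', content=[TextContent(text='hi'), TextContent(text='there')])
assert plain.model_dump() == {'role': 'user', 'content': 'hi\nthere'}

# vision (or cache) enabled -> list-of-dicts format for vision-capable providers
vision = Message(role='user', content=[TextContent(text='hi')], vision_enabled=True)
assert isinstance(vision.model_dump()['content'], list)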

View File

@ -0,0 +1,5 @@
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.llm import LLM
from openhands.llm.streaming_llm import StreamingLLM
__all__ = ['LLM', 'AsyncLLM', 'StreamingLLM']

View File

@ -1,11 +1,12 @@
import asyncio
from functools import partial
from typing import Any
from litellm import completion as litellm_acompletion
from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.llm.llm import LLM, LLM_RETRY_EXCEPTIONS
from openhands.runtime.utils.shutdown_listener import should_continue
@ -33,19 +34,31 @@ class AsyncLLM(LLM):
@self.retry_decorator(
num_retries=self.config.num_retries,
retry_exceptions=self.retry_exceptions,
retry_exceptions=LLM_RETRY_EXCEPTIONS,
retry_min_wait=self.config.retry_min_wait,
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
"""Wrapper for the litellm acompletion function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1] if len(args) > 1 else []
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
# litellm allows positional args, like completion(model, messages, **kwargs)
# see llm.py for more details
if len(args) > 1:
messages = args[1] if len(args) > 1 else args[0]
kwargs['messages'] = messages
# remove the first args, they're sent in kwargs
args = args[2:]
elif 'messages' in kwargs:
messages = kwargs['messages']
# ensure we work with a list of messages
messages = messages if isinstance(messages, list) else [messages]
# if we have no messages, something went very wrong
if not messages:
raise ValueError(
'The messages list is empty. At least one message is required.'
@ -101,7 +114,4 @@ class AsyncLLM(LLM):
@property
def async_completion(self):
"""Decorator for the async litellm acompletion function."""
try:
return self._async_completion
except Exception as e:
raise LLMResponseError(e)
return self._async_completion
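
The argument handling above, shared with small variations by llm.py and streaming_llm.py, normalizes litellm-style positional calls. A standalone sketch of that normalization with a hypothetical helper name; the real code lives inline in the wrappers:

from typing import Any


def _normalize_completion_args(args: tuple, kwargs: dict) -> tuple[tuple, dict, list[dict[str, Any]]]:
    """Hypothetical helper mirroring the wrapper logic above."""
    messages: list[dict[str, Any]] | dict[str, Any] = []
    if len(args) > 1:
        # completion(model, messages, ...): the positional model is ignored,
        # since the configured model is already bound via functools.partial
        messages = args[1]
        kwargs['messages'] = messages
        args = args[2:]
    elif 'messages' in kwargs:
        messages = kwargs['messages']
    messages = messages if isinstance(messages, list) else [messages]
    if not messages:
        raise ValueError('The messages list is empty. At least one message is required.')
    return args, kwargs, messages


# _normalize_completion_args(('ignored-model', [{'role': 'user', 'content': 'hi'}]), {})
# -> ((), {'messages': [{'role': 'user', 'content': 'hi'}]}, [{'role': 'user', 'content': 'hi'}])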

View File

@ -1,3 +1,5 @@
from typing import Any
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
@ -5,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
class DebugMixin:
def log_prompt(self, messages):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
if not messages:
logger.debug('No completion messages!')
return
@ -20,11 +22,11 @@ class DebugMixin:
else:
logger.debug('No completion messages!')
def log_response(self, message_back):
def log_response(self, message_back: str):
if message_back:
llm_response_logger.debug(message_back)
def _format_message_content(self, message):
def _format_message_content(self, message: dict[str, Any]):
content = message['content']
if isinstance(content, list):
return '\n'.join(
@ -32,7 +34,7 @@ class DebugMixin:
)
return str(content)
def _format_content_element(self, element):
def _format_content_element(self, element: dict[str, Any]):
if isinstance(element, dict):
if 'text' in element:
return element['text']
@ -44,10 +46,6 @@ class DebugMixin:
return element['image_url']['url']
return str(element)
def _log_stats(self, stats):
if stats:
logger.info(stats)
# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
raise NotImplementedError
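
For reference, a hedged illustration of what the logging helpers above produce for list-style ("vision") content; the import path and the standalone instantiation of DebugMixin are assumptions:

from openhands.llm.debug_mixin import DebugMixin  # assumed module path

message = {
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'Describe this image'},
        {'type': 'image_url', 'image_url': {'url': 'http://example.com/image1.png'}},
    ],
}
# only the text and the image URL survive into the prompt log, joined by newlines
print(DebugMixin()._format_message_content(message))
# Describe this image
# http://example.com/image1.png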

View File

@ -1,5 +1,4 @@
import copy
import os
import time
import warnings
from functools import partial
@ -10,16 +9,16 @@ from openhands.core.config import LLMConfig
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import ModelInfo
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
APIConnectionError,
ContentPolicyViolationError,
InternalServerError,
OpenAIError,
RateLimitError,
ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
@ -29,9 +28,23 @@ from openhands.llm.retry_mixin import RetryMixin
__all__ = ['LLM']
cache_prompting_supported_models = [
# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
APIConnectionError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
)
# cache prompt supporting models
# remove this when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
'claude-3-5-sonnet-20240620',
'claude-3-haiku-20240307',
'claude-3-opus-20240229',
'anthropic/claude-3-opus-20240229',
'anthropic/claude-3-haiku-20240307',
'anthropic/claude-3-5-sonnet-20240620',
]
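
LLM_RETRY_EXCEPTIONS is now a module-level constant shared by LLM, AsyncLLM and StreamingLLM instead of a per-instance attribute. RetryMixin itself is not shown in this diff, so the sketch below is an assumption about how such a tuple typically feeds a tenacity-based retry policy, with illustrative numbers rather than the project's configured defaults:

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# stand-in for the tuple defined above; swap in the real litellm exceptions
RETRYABLE: tuple[type[Exception], ...] = (ConnectionError, TimeoutError)


@retry(
    reraise=True,
    stop=stop_after_attempt(5),                          # ~ num_retries
    wait=wait_exponential(multiplier=2, min=3, max=60),  # ~ retry_multiplier/min/max
    retry=retry_if_exception_type(RETRYABLE),
)
def call_provider() -> str:
    ...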
@ -55,23 +68,17 @@ class LLM(RetryMixin, DebugMixin):
config: The LLM configuration.
metrics: The metrics to use.
"""
self.metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported = True
self.config = copy.deepcopy(config)
os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
os.environ['OR_APP_NAME'] = self.config.openrouter_app_name
self.metrics: Metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
# - 'messages': list of messages
# - 'response': response from the LLM
self.llm_completions: list[dict[str, Any]] = []
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
# litellm actually uses base Exception here for unknown model
self.model_info = None
self.model_info: ModelInfo | None = None
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
@ -83,15 +90,6 @@ class LLM(RetryMixin, DebugMixin):
except Exception as e:
logger.warning(f'Could not get model info for {config.model}:\n{e}')
# Tuple of exceptions to retry on
self.retry_exceptions = (
APIConnectionError,
ContentPolicyViolationError,
InternalServerError,
OpenAIError,
RateLimitError,
)
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
@ -135,23 +133,39 @@ class LLM(RetryMixin, DebugMixin):
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
completion_unwrapped = self._completion
@self.retry_decorator(
num_retries=self.config.num_retries,
retry_exceptions=self.retry_exceptions,
retry_exceptions=LLM_RETRY_EXCEPTIONS,
retry_min_wait=self.config.retry_min_wait,
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
# litellm allows positional args, like completion(model, messages, **kwargs)
if len(args) > 1:
# ignore the first argument if it's provided (it would be the model)
# design wise: we don't allow overriding the configured values
# implementation wise: the partial function sets the model as a kwarg already
# as well as other kwargs
messages = args[1] if len(args) > 1 else args[0]
kwargs['messages'] = messages
# remove the first args, they're sent in kwargs
args = args[2:]
elif 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1] if len(args) > 1 else []
# ensure we work with a list of messages
messages = messages if isinstance(messages, list) else [messages]
# if we have no messages, something went very wrong
if not messages:
@ -169,7 +183,8 @@ class LLM(RetryMixin, DebugMixin):
'anthropic-beta': 'prompt-caching-2024-07-31',
}
resp = completion_unwrapped(*args, **kwargs)
# we don't support streaming here, thus we get a ModelResponse
resp: ModelResponse = completion_unwrapped(*args, **kwargs)
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
@ -182,7 +197,7 @@ class LLM(RetryMixin, DebugMixin):
}
)
message_back = resp['choices'][0]['message']['content']
message_back: str = resp['choices'][0]['message']['content']
# log the LLM response
self.log_response(message_back)
@ -211,22 +226,29 @@ class LLM(RetryMixin, DebugMixin):
Returns:
bool: True if model is vision capable. If model is not supported by litellm, it will return False.
"""
try:
return litellm.supports_vision(self.config.model)
except Exception:
return False
def is_caching_prompt_active(self) -> bool:
"""Check if prompt caching is enabled and supported for current model.
Returns:
boolean: True if prompt caching is active for the given model.
"""
return self.config.caching_prompt is True and any(
model in self.config.model for model in cache_prompting_supported_models
# litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
# but model_info will have the correct value for some reason.
# we can go with it, but we will need to keep an eye on whether model_info is correct for Vertex or other providers
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
return litellm.supports_vision(self.config.model) or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
)
def _post_completion(self, response) -> None:
def is_caching_prompt_active(self) -> bool:
"""Check if prompt caching is supported and enabled for current model.
Returns:
boolean: True if prompt caching is supported and enabled for the given model.
"""
return (
self.config.caching_prompt is True
and self.model_info is not None
and self.model_info.get('supports_prompt_caching', False)
and self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
)
def _post_completion(self, response: ModelResponse) -> None:
"""Post-process the completion response.
Logs the cost and usage stats of the completion call.
@ -244,7 +266,7 @@ class LLM(RetryMixin, DebugMixin):
self.metrics.accumulated_cost,
)
usage = response.get('usage')
usage: Usage | None = response.get('usage')
if usage:
# keep track of the input and output tokens
@ -366,5 +388,12 @@ class LLM(RetryMixin, DebugMixin):
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
if isinstance(messages, Message):
return [messages.model_dump()]
messages = [messages]
# set flags to know how to serialize the messages
for message in messages:
message.cache_enabled = self.is_caching_prompt_active()
message.vision_enabled = self.vision_is_active()
# let pydantic handle the serialization
return [message.model_dump() for message in messages]
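
Taken together with the Message changes above, the new flow is: format_messages_for_llm stamps each Message with the caching and vision flags, and pydantic's model_dump then selects the matching serializer. A minimal sketch, illustrative only (constructing LLM normally needs valid provider settings):

from openhands.core.config import LLMConfig
from openhands.core.message import Message, TextContent
from openhands.llm import LLM

llm = LLM(config=LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True))
msgs = [Message(role='user', content=[TextContent(text='Hello')])]
payload = llm.format_messages_for_llm(msgs)

# the flags were set on the messages before serialization
assert msgs[0].cache_enabled == llm.is_caching_prompt_active()
assert msgs[0].vision_enabled == llm.vision_is_active()
# payload is a list[dict] ready to be passed as completion(messages=payload)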

View File

@ -1,9 +1,10 @@
import asyncio
from functools import partial
from typing import Any
from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM
class StreamingLLM(AsyncLLM):
@ -31,18 +32,30 @@ class StreamingLLM(AsyncLLM):
@self.retry_decorator(
num_retries=self.config.num_retries,
retry_exceptions=self.retry_exceptions,
retry_exceptions=LLM_RETRY_EXCEPTIONS,
retry_min_wait=self.config.retry_min_wait,
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_streaming_completion_wrapper(*args, **kwargs):
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1] if len(args) > 1 else []
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
# litellm allows positional args, like completion(model, messages, **kwargs)
# see llm.py for more details
if len(args) > 1:
messages = args[1] if len(args) > 1 else args[0]
kwargs['messages'] = messages
# remove the first args, they're sent in kwargs
args = args[2:]
elif 'messages' in kwargs:
messages = kwargs['messages']
# ensure we work with a list of messages
messages = messages if isinstance(messages, list) else [messages]
# if we have no messages, something went very wrong
if not messages:
raise ValueError(
'The messages list is empty. At least one message is required.'
@ -90,7 +103,4 @@ class StreamingLLM(AsyncLLM):
@property
def async_streaming_completion(self):
"""Decorator for the async litellm acompletion function with streaming."""
try:
return self._async_streaming_completion
except Exception as e:
raise LLMResponseError(e)
return self._async_streaming_completion
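
A rough usage sketch for the streaming wrapper; the body of async_streaming_completion_wrapper is not shown in this diff, so it is assumed here that the wrapper yields litellm's streaming chunks when stream=True, and running this would perform a real API call:

import asyncio

from openhands.core.config import LLMConfig
from openhands.llm import StreamingLLM


async def main() -> None:
    llm = StreamingLLM(config=LLMConfig(model='claude-3-5-sonnet-20240620'))
    async for chunk in llm.async_streaming_completion(
        messages=[{'role': 'user', 'content': 'Hello'}],
        stream=True,
    ):
        # each chunk carries an incremental delta in litellm's streaming format
        print(chunk)


# asyncio.run(main())  # illustrative only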

View File

@ -3,10 +3,9 @@ from unittest.mock import MagicMock, patch
import pytest
from litellm.exceptions import (
APIConnectionError,
ContentPolicyViolationError,
InternalServerError,
OpenAIError,
RateLimitError,
ServiceUnavailableError,
)
from openhands.core.config import LLMConfig
@ -138,17 +137,16 @@ def test_completion_with_mocked_logger(
{'llm_provider': 'test_provider', 'model': 'test_model'},
2,
),
(
ContentPolicyViolationError,
{'model': 'test_model', 'llm_provider': 'test_provider'},
2,
),
(
InternalServerError,
{'llm_provider': 'test_provider', 'model': 'test_model'},
2,
),
(OpenAIError, {}, 2),
(
ServiceUnavailableError,
{'llm_provider': 'test_provider', 'model': 'test_model'},
2,
),
(RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
],
)
@ -298,3 +296,39 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
assert call_args['model'] == default_config.model
assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
assert not call_args['stream']
@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
mock_response = {
'choices': [{'message': {'content': 'Response to positional args.'}}]
}
mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config)
response = test_llm.completion(
'some-model-to-be-ignored',
[{'role': 'user', 'content': 'Hello from positional args!'}],
stream=False,
)
# Assertions
assert (
response['choices'][0]['message']['content'] == 'Response to positional args.'
)
mock_litellm_completion.assert_called_once()
# Check if the correct arguments were passed to litellm_completion
call_args, call_kwargs = mock_litellm_completion.call_args
assert (
call_kwargs['model'] == default_config.model
) # Should use the model from config, not the first arg
assert call_kwargs['messages'] == [
{'role': 'user', 'content': 'Hello from positional args!'}
]
assert not call_kwargs['stream']
# Ensure the first positional argument (model) was ignored
assert (
len(call_args) == 0
) # No positional args should be passed to litellm_completion here

View File

@ -1,7 +1,7 @@
from openhands.core.message import ImageContent, Message, TextContent
def test_message_serialization():
def test_message_with_vision_enabled():
text_content1 = TextContent(text='This is a text message')
image_content1 = ImageContent(
image_urls=['http://example.com/image1.png', 'http://example.com/image2.png']
@ -11,11 +11,12 @@ def test_message_serialization():
image_urls=['http://example.com/image3.png', 'http://example.com/image4.png']
)
message = Message(
message: Message = Message(
role='user',
content=[text_content1, image_content1, text_content2, image_content2],
vision_enabled=True,
)
serialized_message = message.serialize_model()
serialized_message: dict = message.serialize_model()
expected_serialized_message = {
'role': 'user',
@ -45,12 +46,14 @@ def test_message_serialization():
assert message.contains_image is True
def test_message_with_only_text_content():
def test_message_with_only_text_content_and_vision_enabled():
text_content1 = TextContent(text='This is a text message')
text_content2 = TextContent(text='This is another text message')
message = Message(role='user', content=[text_content1, text_content2])
serialized_message = message.serialize_model()
message: Message = Message(
role='user', content=[text_content1, text_content2], vision_enabled=True
)
serialized_message: dict = message.serialize_model()
expected_serialized_message = {
'role': 'user',
@ -62,3 +65,52 @@ def test_message_with_only_text_content():
assert serialized_message == expected_serialized_message
assert message.contains_image is False
def test_message_with_only_text_content_and_vision_disabled():
text_content1 = TextContent(text='This is a text message')
text_content2 = TextContent(text='This is another text message')
message: Message = Message(
role='user', content=[text_content1, text_content2], vision_enabled=False
)
serialized_message: dict = message.serialize_model()
expected_serialized_message = {
'role': 'user',
'content': 'This is a text message\nThis is another text message',
}
assert serialized_message == expected_serialized_message
assert message.contains_image is False
def test_message_with_mixed_content_and_vision_disabled():
# Create a message with both text and image content
text_content1 = TextContent(text='This is a text message')
image_content1 = ImageContent(
image_urls=['http://example.com/image1.png', 'http://example.com/image2.png']
)
text_content2 = TextContent(text='This is another text message')
image_content2 = ImageContent(
image_urls=['http://example.com/image3.png', 'http://example.com/image4.png']
)
# Initialize Message with vision disabled
message: Message = Message(
role='user',
content=[text_content1, image_content1, text_content2, image_content2],
vision_enabled=False,
)
serialized_message: dict = message.serialize_model()
# Expected serialization ignores images and concatenates text
expected_serialized_message = {
'role': 'user',
'content': 'This is a text message\nThis is another text message',
}
# Assert serialized message matches expectation
assert serialized_message == expected_serialized_message
# Assert that images exist in the original message
assert message.contains_image