Mirror of https://github.com/OpenHands/OpenHands.git (synced 2025-12-26 05:48:36 +08:00)

Vision and prompt caching fixes (#4014)

This commit is contained in:
parent f427f9d8d4
commit e582806004
@@ -204,11 +204,6 @@ class CodeActAgent(Agent):
            ],
        }

        if self.llm.is_caching_prompt_active():
            params['extra_headers'] = {
                'anthropic-beta': 'prompt-caching-2024-07-31',
            }

        response = self.llm.completion(**params)

        return self.action_parser.parse(response)
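For reference, the header gated above is the one Anthropic's prompt-caching beta expects. An illustrative sketch (not part of this commit) of sending it directly through litellm; the model name, prompt text, and the cache_control placement are assumptions for the example:

# Illustrative only: not the repository's code. Shows how the beta header and a
# cache_control-marked block can be passed through litellm for Anthropic models.
from litellm import completion

response = completion(
    model='claude-3-5-sonnet-20240620',  # example model
    messages=[
        {
            'role': 'system',
            'content': [
                {
                    'type': 'text',
                    'text': 'Long, stable system prompt goes here...',
                    'cache_control': {'type': 'ephemeral'},  # block Anthropic may cache
                }
            ],
        },
        {'role': 'user', 'content': 'Hello!'},
    ],
    extra_headers={'anthropic-beta': 'prompt-caching-2024-07-31'},
)
print(response['choices'][0]['message']['content'])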
@@ -146,8 +146,8 @@ model = "gpt-4o"
# Drop any unmapped (unsupported) params without causing an exception
#drop_params = false

# Using the prompt caching feature provided by the LLM
#caching_prompt = false
# Using the prompt caching feature if provided by the LLM and supported
#caching_prompt = true

# Base URL for the OLLAMA API
#ollama_base_url = ""
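The TOML option above maps onto LLMConfig.caching_prompt. A hypothetical usage sketch (the model name is an example; whether caching actually activates also depends on the checks in llm.py further down):

# Hypothetical usage sketch of the caching_prompt option from Python.
from openhands.core.config import LLMConfig
from openhands.llm import LLM

config = LLMConfig(
    model='claude-3-5-sonnet-20240620',  # example model from the supported list
    caching_prompt=True,                 # same default the template now documents
)
llm = LLM(config=config)

# True only when the model's litellm info also reports prompt-caching support
print(llm.is_caching_prompt_active())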
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass, fields

from openhands.core.config.config_utils import get_field_info
@@ -36,7 +37,7 @@ class LLMConfig:
        ollama_base_url: The base URL for the OLLAMA API.
        drop_params: Drop any unmapped (unsupported) params without causing an exception.
        disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
        caching_prompt: Using the prompt caching feature provided by the LLM.
        caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
        log_completions: Whether to log LLM completions to the state.
    """
@@ -68,7 +69,7 @@ class LLMConfig:
    ollama_base_url: str | None = None
    drop_params: bool = True
    disable_vision: bool | None = None
    caching_prompt: bool = False
    caching_prompt: bool = True
    log_completions: bool = False

    def defaults_to_dict(self) -> dict:
@@ -78,6 +79,18 @@ class LLMConfig:
            result[f.name] = get_field_info(f)
        return result

    def __post_init__(self):
        """
        Post-initialization hook to assign OpenRouter-related variables to environment variables.
        This ensures that these values are accessible to litellm at runtime.
        """

        # Assign OpenRouter-specific variables to environment variables
        if self.openrouter_site_url:
            os.environ['OR_SITE_URL'] = self.openrouter_site_url
        if self.openrouter_app_name:
            os.environ['OR_APP_NAME'] = self.openrouter_app_name

    def __str__(self):
        attr_str = []
        for f in fields(self):
@@ -101,9 +114,3 @@ class LLMConfig:
            if k in LLM_SENSITIVE_FIELDS:
                ret[k] = '******' if v else None
        return ret

    def set_missing_attributes(self):
        """Set any missing attributes to their default values."""
        for field_name, field_obj in self.__dataclass_fields__.items():
            if not hasattr(self, field_name):
                setattr(self, field_name, field_obj.default)
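To illustrate the new __post_init__ hook (the model id, URL, and app name below are made-up values):

# Illustrative check of the new __post_init__ behaviour; all values are placeholders.
import os

from openhands.core.config import LLMConfig

config = LLMConfig(
    model='openrouter/anthropic/claude-3.5-sonnet',  # example OpenRouter model id
    openrouter_site_url='https://example.com',
    openrouter_app_name='my-app',
)

# The dataclass hook has already exported the values for litellm to pick up.
assert os.environ['OR_SITE_URL'] == 'https://example.com'
assert os.environ['OR_APP_NAME'] == 'my-app'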
@@ -50,6 +50,8 @@ class ImageContent(Content):
class Message(BaseModel):
    role: Literal['user', 'system', 'assistant']
    content: list[TextContent | ImageContent] = Field(default=list)
    cache_enabled: bool = False
    vision_enabled: bool = False

    @property
    def contains_image(self) -> bool:
@@ -58,23 +60,22 @@ class Message(BaseModel):
    @model_serializer
    def serialize_model(self) -> dict:
        content: list[dict] | str
        if self.role == 'system':
            # For system role, concatenate all text content into a single string
            content = '\n'.join(
                item.text for item in self.content if isinstance(item, TextContent)
            )
        elif self.role == 'assistant' and not self.contains_image:
            # For assistant role without vision, concatenate all text content into a single string
            content = '\n'.join(
                item.text for item in self.content if isinstance(item, TextContent)
            )
        else:
            # For user role or assistant role with vision enabled, serialize each content item
        # two kinds of serializer:
        # 1. vision serializer: when prompt caching or vision is enabled
        # 2. single text serializer: for other cases
        # remove this when liteLLM or providers support this format translation
        if self.cache_enabled or self.vision_enabled:
            # when prompt caching or vision is enabled, use vision serializer
            content = []
            for item in self.content:
                if isinstance(item, TextContent):
                    content.append(item.model_dump())
                elif isinstance(item, ImageContent):
                    content.extend(item.model_dump())

        else:
            # for other cases, concatenate all text content
            # into a single string per message
            content = '\n'.join(
                item.text for item in self.content if isinstance(item, TextContent)
            )
        return {'content': content, 'role': self.role}
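A small sketch of how the two serializers behave, mirroring the unit tests at the end of this commit (the message text is illustrative):

# Sketch of the two serialization paths; mirrors the tests at the end of this diff.
from openhands.core.message import Message, TextContent

parts = [TextContent(text='first part'), TextContent(text='second part')]

# vision/caching disabled: text blocks collapse into one string per message
flat = Message(role='user', content=parts, vision_enabled=False).serialize_model()
assert flat == {'role': 'user', 'content': 'first part\nsecond part'}

# vision (or caching) enabled: each block becomes its own dict, in the
# list-of-content-parts format vision-capable APIs expect
rich = Message(role='user', content=parts, vision_enabled=True).serialize_model()
print(rich['content'])  # a list of per-item dicts rather than a single string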
openhands/llm/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.llm import LLM
from openhands.llm.streaming_llm import StreamingLLM

__all__ = ['LLM', 'AsyncLLM', 'StreamingLLM']
@@ -1,11 +1,12 @@
import asyncio
from functools import partial
from typing import Any

from litellm import completion as litellm_acompletion

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.llm.llm import LLM, LLM_RETRY_EXCEPTIONS
from openhands.runtime.utils.shutdown_listener import should_continue
@@ -33,19 +34,31 @@ class AsyncLLM(LLM):

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=self.retry_exceptions,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
@@ -101,7 +114,4 @@ class AsyncLLM(LLM):
    @property
    def async_completion(self):
        """Decorator for the async litellm acompletion function."""
        try:
            return self._async_completion
        except Exception as e:
            raise LLMResponseError(e)
        return self._async_completion
@@ -1,3 +1,5 @@
from typing import Any

from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger

@@ -5,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'


class DebugMixin:
    def log_prompt(self, messages):
    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
        if not messages:
            logger.debug('No completion messages!')
            return
@@ -20,11 +22,11 @@ class DebugMixin:
        else:
            logger.debug('No completion messages!')

    def log_response(self, message_back):
    def log_response(self, message_back: str):
        if message_back:
            llm_response_logger.debug(message_back)

    def _format_message_content(self, message):
    def _format_message_content(self, message: dict[str, Any]):
        content = message['content']
        if isinstance(content, list):
            return '\n'.join(
@@ -32,7 +34,7 @@ class DebugMixin:
            )
        return str(content)

    def _format_content_element(self, element):
    def _format_content_element(self, element: dict[str, Any]):
        if isinstance(element, dict):
            if 'text' in element:
                return element['text']
@@ -44,10 +46,6 @@ class DebugMixin:
                return element['image_url']['url']
        return str(element)

    def _log_stats(self, stats):
        if stats:
            logger.info(stats)

    # This method should be implemented in the class that uses DebugMixin
    def vision_is_active(self):
        raise NotImplementedError
@@ -1,5 +1,4 @@
import copy
import os
import time
import warnings
from functools import partial
@@ -10,16 +9,16 @@ from openhands.core.config import LLMConfig
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm
from litellm import ModelInfo
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from litellm.types.utils import CostPerToken, ModelResponse, Usage

from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
@@ -29,9 +28,23 @@ from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']

cache_prompting_supported_models = [
# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

# cache prompt supporting models
# remove this when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
    'anthropic/claude-3-opus-20240229',
    'anthropic/claude-3-haiku-20240307',
    'anthropic/claude-3-5-sonnet-20240620',
]
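The RetryMixin that consumes this tuple is outside this diff; the following is only a tenacity-based sketch of how a shared exception tuple like LLM_RETRY_EXCEPTIONS can drive retries (retry counts and wait times are example values, not the project's configuration):

# Not the repository's RetryMixin: a tenacity-based sketch under stated assumptions.
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from openhands.llm.llm import LLM_RETRY_EXCEPTIONS


@retry(
    retry=retry_if_exception_type(LLM_RETRY_EXCEPTIONS),  # retry only transient provider errors
    stop=stop_after_attempt(5),                           # example retry budget
    wait=wait_exponential(multiplier=2, min=3, max=60),   # example backoff window
    reraise=True,
)
def call_llm_once() -> str:
    ...  # the actual litellm completion call would go here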
@@ -55,23 +68,17 @@ class LLM(RetryMixin, DebugMixin):
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True
        self.config = copy.deepcopy(config)

        os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
        os.environ['OR_APP_NAME'] = self.config.openrouter_app_name
        self.metrics: Metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)

        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM
        self.llm_completions: list[dict[str, Any]] = []

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        self.model_info: ModelInfo | None = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
@@ -83,15 +90,6 @@ class LLM(RetryMixin, DebugMixin):
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Tuple of exceptions to retry on
        self.retry_exceptions = (
            APIConnectionError,
            ContentPolicyViolationError,
            InternalServerError,
            OpenAIError,
            RateLimitError,
        )

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
@@ -135,23 +133,39 @@ class LLM(RetryMixin, DebugMixin):

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')

        completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=self.retry_exceptions,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function set the model as a kwarg already
                # as well as other kwargs
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
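The comments above describe how positional (model, messages) calls are normalized before delegating to litellm; a standalone sketch of that logic with illustrative names:

# Standalone sketch (illustrative names, not the repository's API): a positional model
# argument is dropped in favour of the configured model, and messages become a list.
from typing import Any


def normalize_call_args(args: tuple, kwargs: dict[str, Any]) -> tuple[tuple, dict[str, Any]]:
    messages: list[dict[str, Any]] | dict[str, Any] = []

    if len(args) > 1:
        # caller used completion(model, messages, ...): ignore the model, keep the messages
        messages = args[1]
        kwargs['messages'] = messages
        args = args[2:]
    elif 'messages' in kwargs:
        messages = kwargs['messages']

    kwargs['messages'] = messages if isinstance(messages, list) else [messages]
    if not kwargs['messages']:
        raise ValueError('The messages list is empty. At least one message is required.')
    return args, kwargs


# e.g. normalize_call_args(('some-model', {'role': 'user', 'content': 'hi'}), {})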
@@ -169,7 +183,8 @@ class LLM(RetryMixin, DebugMixin):
                    'anthropic-beta': 'prompt-caching-2024-07-31',
                }

            resp = completion_unwrapped(*args, **kwargs)
            # we don't support streaming here, thus we get a ModelResponse
            resp: ModelResponse = completion_unwrapped(*args, **kwargs)

            # log for evals or other scripts that need the raw completion
            if self.config.log_completions:
@@ -182,7 +197,7 @@ class LLM(RetryMixin, DebugMixin):
                    }
                )

            message_back = resp['choices'][0]['message']['content']
            message_back: str = resp['choices'][0]['message']['content']

            # log the LLM response
            self.log_response(message_back)
@@ -211,22 +226,29 @@ class LLM(RetryMixin, DebugMixin):
        Returns:
            bool: True if model is vision capable. If model is not supported by litellm, it will return False.
        """
        try:
            return litellm.supports_vision(self.config.model)
        except Exception:
            return False

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is enabled and supported for current model.

        Returns:
            boolean: True if prompt caching is active for the given model.
        """
        return self.config.caching_prompt is True and any(
            model in self.config.model for model in cache_prompting_supported_models
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
        return litellm.supports_vision(self.config.model) or (
            self.model_info is not None
            and self.model_info.get('supports_vision', False)
        )

    def _post_completion(self, response) -> None:
    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given model.
        """
        return (
            self.config.caching_prompt is True
            and self.model_info is not None
            and self.model_info.get('supports_prompt_caching', False)
            and self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
        )

    def _post_completion(self, response: ModelResponse) -> None:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
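An illustrative sketch of the same capability checks performed outside the class (the model name is an example; litellm.get_model_info can raise for unknown models, hence the guard):

# Quick illustration of the vision and prompt-caching checks above.
import litellm

model = 'anthropic/claude-3-5-sonnet-20240620'  # example model id

try:
    model_info = litellm.get_model_info(model)
except Exception:
    model_info = None

has_vision = litellm.supports_vision(model) or (
    model_info is not None and model_info.get('supports_vision', False)
)
supports_caching = model_info is not None and model_info.get(
    'supports_prompt_caching', False
)
print(has_vision, supports_caching)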
@@ -244,7 +266,7 @@ class LLM(RetryMixin, DebugMixin):
            self.metrics.accumulated_cost,
        )

        usage = response.get('usage')
        usage: Usage | None = response.get('usage')

        if usage:
            # keep track of the input and output tokens
@@ -366,5 +388,12 @@ class LLM(RetryMixin, DebugMixin):

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            return [messages.model_dump()]
            messages = [messages]

        # set flags to know how to serialize the messages
        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()

        # let pydantic handle the serialization
        return [message.model_dump() for message in messages]
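A hypothetical end-to-end flow using the new helper (model and prompt are placeholders; a real call needs provider credentials):

# Hypothetical flow: format_messages_for_llm stamps the caching/vision flags onto each
# Message and returns plain dicts that litellm accepts.
from openhands.core.config import LLMConfig
from openhands.core.message import Message, TextContent
from openhands.llm import LLM

llm = LLM(config=LLMConfig(model='claude-3-5-sonnet-20240620'))  # example config
messages = [Message(role='user', content=[TextContent(text='Summarize this repo.')])]

serialized = llm.format_messages_for_llm(messages)  # list[dict], flags already applied
response = llm.completion(messages=serialized)
print(response['choices'][0]['message']['content'])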
@@ -1,9 +1,10 @@
import asyncio
from functools import partial
from typing import Any

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM


class StreamingLLM(AsyncLLM):
@@ -31,18 +32,30 @@ class StreamingLLM(AsyncLLM):

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=self.retry_exceptions,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_streaming_completion_wrapper(*args, **kwargs):
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
@@ -90,7 +103,4 @@ class StreamingLLM(AsyncLLM):
    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming."""
        try:
            return self._async_streaming_completion
        except Exception as e:
            raise LLMResponseError(e)
        return self._async_streaming_completion
@@ -3,10 +3,9 @@ from unittest.mock import MagicMock, patch
import pytest
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)

from openhands.core.config import LLMConfig
@@ -138,17 +137,16 @@ def test_completion_with_mocked_logger(
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            ContentPolicyViolationError,
            {'model': 'test_model', 'llm_provider': 'test_provider'},
            2,
        ),
        (
            InternalServerError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (OpenAIError, {}, 2),
        (
            ServiceUnavailableError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
    ],
)
@@ -298,3 +296,39 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
    assert call_args['model'] == default_config.model
    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
    assert not call_args['stream']


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'Response to positional args.'}}]
    }
    mock_litellm_completion.return_value = mock_response

    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        'some-model-to-be-ignored',
        [{'role': 'user', 'content': 'Hello from positional args!'}],
        stream=False,
    )

    # Assertions
    assert (
        response['choices'][0]['message']['content'] == 'Response to positional args.'
    )
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args, call_kwargs = mock_litellm_completion.call_args
    assert (
        call_kwargs['model'] == default_config.model
    )  # Should use the model from config, not the first arg
    assert call_kwargs['messages'] == [
        {'role': 'user', 'content': 'Hello from positional args!'}
    ]
    assert not call_kwargs['stream']

    # Ensure the first positional argument (model) was ignored
    assert (
        len(call_args) == 0
    )  # No positional args should be passed to litellm_completion here
@@ -1,7 +1,7 @@
from openhands.core.message import ImageContent, Message, TextContent


def test_message_serialization():
def test_message_with_vision_enabled():
    text_content1 = TextContent(text='This is a text message')
    image_content1 = ImageContent(
        image_urls=['http://example.com/image1.png', 'http://example.com/image2.png']
@@ -11,11 +11,12 @@ def test_message_serialization():
        image_urls=['http://example.com/image3.png', 'http://example.com/image4.png']
    )

    message = Message(
    message: Message = Message(
        role='user',
        content=[text_content1, image_content1, text_content2, image_content2],
        vision_enabled=True,
    )
    serialized_message = message.serialize_model()
    serialized_message: dict = message.serialize_model()

    expected_serialized_message = {
        'role': 'user',
@@ -45,12 +46,14 @@ def test_message_serialization():
    assert message.contains_image is True


def test_message_with_only_text_content():
def test_message_with_only_text_content_and_vision_enabled():
    text_content1 = TextContent(text='This is a text message')
    text_content2 = TextContent(text='This is another text message')

    message = Message(role='user', content=[text_content1, text_content2])
    serialized_message = message.serialize_model()
    message: Message = Message(
        role='user', content=[text_content1, text_content2], vision_enabled=True
    )
    serialized_message: dict = message.serialize_model()

    expected_serialized_message = {
        'role': 'user',
@@ -62,3 +65,52 @@ def test_message_with_only_text_content():

    assert serialized_message == expected_serialized_message
    assert message.contains_image is False


def test_message_with_only_text_content_and_vision_disabled():
    text_content1 = TextContent(text='This is a text message')
    text_content2 = TextContent(text='This is another text message')

    message: Message = Message(
        role='user', content=[text_content1, text_content2], vision_enabled=False
    )
    serialized_message: dict = message.serialize_model()

    expected_serialized_message = {
        'role': 'user',
        'content': 'This is a text message\nThis is another text message',
    }

    assert serialized_message == expected_serialized_message
    assert message.contains_image is False


def test_message_with_mixed_content_and_vision_disabled():
    # Create a message with both text and image content
    text_content1 = TextContent(text='This is a text message')
    image_content1 = ImageContent(
        image_urls=['http://example.com/image1.png', 'http://example.com/image2.png']
    )
    text_content2 = TextContent(text='This is another text message')
    image_content2 = ImageContent(
        image_urls=['http://example.com/image3.png', 'http://example.com/image4.png']
    )

    # Initialize Message with vision disabled
    message: Message = Message(
        role='user',
        content=[text_content1, image_content1, text_content2, image_content2],
        vision_enabled=False,
    )
    serialized_message: dict = message.serialize_model()

    # Expected serialization ignores images and concatenates text
    expected_serialized_message = {
        'role': 'user',
        'content': 'This is a text message\nThis is another text message',
    }

    # Assert serialized message matches expectation
    assert serialized_message == expected_serialized_message
    # Assert that images exist in the original message
    assert message.contains_image