Vision and prompt caching fixes (#4014)

This commit is contained in:
Engel Nyst
2024-09-28 14:37:29 +02:00
committed by GitHub
parent f427f9d8d4
commit e582806004
11 changed files with 262 additions and 121 deletions

View File

@@ -50,6 +50,8 @@ class ImageContent(Content):
class Message(BaseModel):
role: Literal['user', 'system', 'assistant']
content: list[TextContent | ImageContent] = Field(default=list)
cache_enabled: bool = False
vision_enabled: bool = False
@property
def contains_image(self) -> bool:
@@ -58,23 +60,22 @@ class Message(BaseModel):
@model_serializer
def serialize_model(self) -> dict:
# Serialize this message to a dict with exactly two keys: 'content' and 'role'.
# 'content' is either a list of serialized content items or a single
# newline-joined string built from the text items, depending on the branch taken.
#
# NOTE(review): this listing appears to be a diff hunk rendered without +/-
# markers and without indentation. The role-based branches (system / assistant /
# else) look like the pre-change version, and the cache_enabled / vision_enabled
# branches like their replacement -- confirm against the actual file before
# relying on the combined control flow shown here.
content: list[dict] | str
if self.role == 'system':
# System role: join the text of every TextContent item with '\n';
# items that are not TextContent are left out of the string.
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
elif self.role == 'assistant' and not self.contains_image:
# Assistant role whose content holds no image: same newline-joined
# string of the TextContent items as for the system role.
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
else:
# User role, or assistant role whose content contains an image:
# serialize each content item individually.
# Two kinds of serializer are used:
# 1. vision serializer: when prompt caching or vision is enabled
# 2. single text serializer: for all other cases
# TODO: remove this once LiteLLM or the providers handle this format translation themselves
if self.cache_enabled or self.vision_enabled:
# Prompt caching or vision is enabled: use the vision serializer,
# which emits a list with one or more entries per content item.
content = []
for item in self.content:
if isinstance(item, TextContent):
content.append(item.model_dump())
elif isinstance(item, ImageContent):
# extend (not append): presumably ImageContent.model_dump() returns a
# list of dicts -- verify against ImageContent's serializer.
content.extend(item.model_dump())
else:
# All other cases: concatenate the text of the TextContent items
# into a single string per message; non-text items are dropped.
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
return {'content': content, 'role': self.role}