OpenHands/openhands/core/message.py
Kaushik Deka 5bb931e4d6
Add prompt caching (Sonnet, Haiku only) (#3411)
* Add prompt caching

* remove anthropic-version from extra_headers

* change supports_prompt_caching method to attribute

* change caching strat and log cache statistics

* add reminder as a new message to fix caching

* fix unit test

* append reminder to the end of the last message content

* move token logs to post completion function

* fix unit test failure

* fix reminder and prompt caching

* unit tests for prompt caching

* add test

* clean up tests

* separate reminder, use latest two messages

* fix tests

---------

Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2024-08-26 20:46:44 -04:00

69 lines
1.9 KiB
Python

from enum import Enum
from pydantic import BaseModel, Field, model_serializer
from typing_extensions import Literal
class ContentType(Enum):
    """Kinds of content a message part can carry.

    Each member's value is the string written to the ``'type'`` key when a
    content part is serialized.
    """

    # Plain text part.
    TEXT = 'text'
    # Image part referenced by URL.
    IMAGE_URL = 'image_url'
class Content(BaseModel):
    """Common base for a single piece of message content.

    Holds the fields shared by every content kind; the serialized form is
    defined by each concrete subclass.
    """

    # Which kind of content this part represents.
    type: ContentType
    # Opt-in flag; the subclasses in this module emit a 'cache_control'
    # entry in their serialized output when it is set.
    cache_prompt: bool = False

    @model_serializer
    def serialize_model(self):
        """Serialize this content part.

        Raises:
            NotImplementedError: always; the base class has no serialized
                form of its own.
        """
        raise NotImplementedError('Subclasses should implement this method.')
class TextContent(Content):
    """A text part of a message."""

    type: ContentType = ContentType.TEXT
    # The text payload itself.
    text: str

    @model_serializer
    def serialize_model(self):
        """Return the dict form of this text part.

        The result always has ``'type'`` and ``'text'`` keys. When
        ``cache_prompt`` is set, a ``'cache_control'`` entry of type
        ``'ephemeral'`` is added as well.
        """
        payload: dict[str, str | dict[str, str]] = {
            'type': self.type.value,
            'text': self.text,
        }
        if not self.cache_prompt:
            return payload
        # Caching was requested for this part: attach the marker.
        payload['cache_control'] = {'type': 'ephemeral'}
        return payload
class ImageContent(Content):
    """One or more images, referenced by URL, attached to a message."""

    type: ContentType = ContentType.IMAGE_URL
    # URLs of the images, in the order they should be serialized.
    image_urls: list[str]

    @model_serializer
    def serialize_model(self):
        """Return a list with one dict per image URL.

        Each entry has the shape
        ``{'type': <type value>, 'image_url': {'url': <url>}}``. When
        ``cache_prompt`` is set and there is at least one image, only the
        last entry receives a ``'cache_control'`` marker of type
        ``'ephemeral'``.
        """
        entries: list[dict[str, str | dict[str, str]]] = [
            {'type': self.type.value, 'image_url': {'url': link}}
            for link in self.image_urls
        ]
        # Guard on non-empty so indexing [-1] is safe for an empty URL list.
        if entries and self.cache_prompt:
            entries[-1]['cache_control'] = {'type': 'ephemeral'}
        return entries
class Message(BaseModel):
    """A chat message: a role plus an ordered list of content parts."""

    # Who authored the message.
    role: Literal['user', 'system', 'assistant']
    # default_factory gives every Message its own fresh empty list.
    # `default=list` would make the default value the `list` class itself,
    # which cannot be iterated by `contains_image` or `serialize_model`.
    content: list[TextContent | ImageContent] = Field(default_factory=list)

    @property
    def contains_image(self) -> bool:
        """Return True if any content part is an ``ImageContent``."""
        return any(isinstance(part, ImageContent) for part in self.content)

    @model_serializer
    def serialize_model(self) -> dict:
        """Return ``{'role': ..., 'content': [...]}`` with flattened parts.

        A ``TextContent`` contributes a single dict. An ``ImageContent``
        dumps to a list of dicts (one per URL), which is spliced into the
        same flat list so that ``'content'`` is never nested.
        """
        content: list[dict[str, str | dict[str, str]]] = []
        for item in self.content:
            if isinstance(item, TextContent):
                content.append(item.model_dump())
            elif isinstance(item, ImageContent):
                # model_dump() yields a list here; extend keeps it flat.
                content.extend(item.model_dump())
        return {'role': self.role, 'content': content}