Add prompt caching (Sonnet, Haiku only) (#3411)
* Add prompt caching
* remove anthropic-version from extra_headers
* change supports_prompt_caching method to attribute
* change caching strat and log cache statistics
* add reminder as a new message to fix caching
* fix unit test
* append reminder to the end of the last message content
* move token logs to post completion function
* fix unit test failure
* fix reminder and prompt caching
* unit tests for prompt caching
* add test
* clean up tests
* separate reminder, use latest two messages
* fix tests

---------

Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
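For context, Anthropic-style prompt caching is opted into per content block by attaching a `cache_control` marker, which is exactly what the serializer changes in this diff emit. A minimal sketch of the resulting message shape (the text payload is illustrative, not from this commit):

```python
# A user message whose content block is marked for Anthropic's ephemeral
# prompt cache; repeated identical prefixes then become cache reads.
message = {
    'role': 'user',
    'content': [
        {
            'type': 'text',
            'text': '<long, stable prompt prefix worth caching>',
            'cache_control': {'type': 'ephemeral'},
        }
    ],
}
```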
@@ -11,6 +11,7 @@ class ContentType(Enum):
 
 class Content(BaseModel):
     type: ContentType
+    cache_prompt: bool = False
 
     @model_serializer
     def serialize_model(self):
@@ -23,7 +24,13 @@ class TextContent(Content):
 
     @model_serializer
     def serialize_model(self):
-        return {'type': self.type.value, 'text': self.text}
+        data: dict[str, str | dict[str, str]] = {
+            'type': self.type.value,
+            'text': self.text,
+        }
+        if self.cache_prompt:
+            data['cache_control'] = {'type': 'ephemeral'}
+        return data
 
 
 class ImageContent(Content):
@@ -35,6 +42,8 @@ class ImageContent(Content):
         images: list[dict[str, str | dict[str, str]]] = []
         for url in self.image_urls:
             images.append({'type': self.type.value, 'image_url': {'url': url}})
+        if self.cache_prompt and images:
+            images[-1]['cache_control'] = {'type': 'ephemeral'}
         return images
 
 
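A self-contained sketch of the new serialization behavior, assuming the classes as defined in these hunks (the `text` field and the `ContentType.TEXT` default are part of the surrounding file, outside the hunk context, so they are reconstructed here as assumptions):

```python
from enum import Enum

from pydantic import BaseModel, model_serializer


class ContentType(Enum):
    TEXT = 'text'


class Content(BaseModel):
    type: ContentType
    cache_prompt: bool = False


class TextContent(Content):
    type: ContentType = ContentType.TEXT
    text: str

    @model_serializer
    def serialize_model(self):
        # Emit the plain block, adding the Anthropic cache marker on demand.
        data: dict[str, str | dict[str, str]] = {
            'type': self.type.value,
            'text': self.text,
        }
        if self.cache_prompt:
            data['cache_control'] = {'type': 'ephemeral'}
        return data


print(TextContent(text='hello', cache_prompt=True).model_dump())
# {'type': 'text', 'text': 'hello', 'cache_control': {'type': 'ephemeral'}}
```

Note that `ImageContent` takes the same approach but tags only the last image block in the list, so a multi-image message contributes a single cache breakpoint.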
@@ -35,6 +35,11 @@ __all__ = ['LLM']
 
 message_separator = '\n\n----------\n\n'
 
+cache_prompting_supported_models = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-haiku-20240307',
+]
+
 
 class LLM:
     """The LLM class represents a Language Model instance.
@@ -58,6 +63,9 @@ class LLM:
         self.config = copy.deepcopy(config)
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
+        self.supports_prompt_caching = (
+            self.config.model in cache_prompting_supported_models
+        )
 
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)
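The flag is then available on every `LLM` instance. A hypothetical call site (the names here are illustrative, not from this diff) might gate caching like so:

```python
# Hypothetical helper: only mark messages for caching when the configured
# model is in cache_prompting_supported_models, so non-Anthropic models
# never receive cache_control blocks.
def maybe_enable_caching(llm, messages) -> None:
    if llm.supports_prompt_caching and messages:
        # e.g. cache the stable system prompt at the head of the list
        messages[0].cache_prompt = True
```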
@@ -184,6 +192,7 @@ class LLM:
 
             # log the response
             message_back = resp['choices'][0]['message']['content']
+
             llm_response_logger.debug(message_back)
 
             # post-process to log costs
@@ -421,19 +430,51 @@
     def supports_vision(self):
         return litellm.supports_vision(self.config.model)
 
-    def _post_completion(self, response: str) -> None:
+    def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:
             cur_cost = self.completion_cost(response)
         except Exception:
             cur_cost = 0
 
+        stats = ''
         if self.cost_metric_supported:
-            logger.info(
-                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
+            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
             )
+
+        usage = response.get('usage')
+
+        if usage:
+            input_tokens = usage.get('prompt_tokens')
+            output_tokens = usage.get('completion_tokens')
+
+            if input_tokens:
+                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+
+            if output_tokens:
+                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+
+            model_extra = usage.get('model_extra', {})
+
+            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
+            if cache_creation_input_tokens:
+                stats += (
+                    'Input tokens (cache write): '
+                    + str(cache_creation_input_tokens)
+                    + '\n'
+                )
+
+            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
+            if cache_read_input_tokens:
+                stats += (
+                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
+                )
+
+        if stats:
+            logger.info(stats)
 
     def get_token_count(self, messages):
         """Get the number of tokens in a list of messages.
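For reference, `cache_creation_input_tokens` and `cache_read_input_tokens` are Anthropic-specific usage counters surfaced under the usage object's `model_extra`. A minimal standalone sketch of the same extraction, assuming a dict-shaped response like the one `_post_completion` traverses:

```python
def summarize_usage(response) -> dict:
    # Same traversal as the new _post_completion: standard token counts
    # come from usage, Anthropic cache counters from usage.model_extra.
    usage = response.get('usage') or {}
    model_extra = usage.get('model_extra') or {}
    counters = {
        'input_tokens': usage.get('prompt_tokens'),
        'output_tokens': usage.get('completion_tokens'),
        'cache_write_tokens': model_extra.get('cache_creation_input_tokens'),
        'cache_read_tokens': model_extra.get('cache_read_input_tokens'),
    }
    # Drop counters that are absent or zero, mirroring the log's behavior.
    return {k: v for k, v in counters.items() if v}
```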