feat: add return type hints to LLM class methods (#5173)

This commit is contained in:
Cheng Yang 2024-11-21 21:00:46 +08:00 committed by GitHub
parent 7e38297732
commit 68e52a9c62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -369,16 +369,16 @@ class LLM(RetryMixin, DebugMixin):
):
self.config.max_output_tokens = self.model_info['max_tokens']
def vision_is_active(self):
def vision_is_active(self) -> bool:
with warnings.catch_warnings():
warnings.simplefilter('ignore')
return not self.config.disable_vision and self._supports_vision()
def _supports_vision(self):
def _supports_vision(self) -> bool:
"""Acquire from litellm if model is vision capable.
Returns:
bool: True if model is vision capable. If model is not supported by litellm, it will return False.
bool: True if model is vision capable. Return False if model not supported by litellm.
"""
# litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
# but model_info will have the correct value for some reason.
@ -476,7 +476,7 @@ class LLM(RetryMixin, DebugMixin):
if stats:
logger.debug(stats)
def get_token_count(self, messages):
def get_token_count(self, messages) -> int:
"""Get the number of tokens in a list of messages.
Args:
@ -491,7 +491,7 @@ class LLM(RetryMixin, DebugMixin):
# TODO: this is to limit logspam in case token count is not supported
return 0
def _is_local(self):
def _is_local(self) -> bool:
"""Determines if the system is using a locally running LLM.
Returns:
@ -506,7 +506,7 @@ class LLM(RetryMixin, DebugMixin):
return True
return False
def _completion_cost(self, response):
def _completion_cost(self, response) -> float:
"""Calculate the cost of a completion response based on the model. Local models are treated as free.
Add the current cost into total cost in metrics.
@ -555,7 +555,7 @@ class LLM(RetryMixin, DebugMixin):
def __repr__(self):
return str(self)
def reset(self):
def reset(self) -> None:
self.metrics.reset()
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]: