From 9436de4bb3caa63fef05e1f94059021c50ee78dc Mon Sep 17 00:00:00 2001 From: Richardson Gunde <152559661+richard-devbot@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:26:26 +0530 Subject: [PATCH] Update utils.py added fetch LLM models --- src/utils/utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/utils/utils.py b/src/utils/utils.py index 6fbbd6c..2b1da52 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -106,6 +106,49 @@ def get_llm_model(provider: str, **kwargs): else: raise ValueError(f'Unsupported provider: {provider}') +from openai import OpenAI, AzureOpenAI +from google.generativeai import configure, list_models +from langchain_anthropic import AnthropicLLM +from langchain_ollama.llms import OllamaLLM + +def fetch_available_models(llm_provider: str, api_key: str = None, base_url: str = None) -> list[str]: + try: + if llm_provider == "anthropic": + client = AnthropicLLM(api_key=api_key) + # Handle model fetching appropriately for Anthropic + return ["claude-3-5-sonnet-20240620"] # Replace with actual model fetching logic + + elif llm_provider == "openai": + client = OpenAI(api_key=api_key, base_url=base_url) + models = client.models.list() + return [model.id for model in models.data] + + elif llm_provider == "deepseek": + # For Deepseek, we'll return the default model for now + return ["deepseek-chat"] + + elif llm_provider == "gemini": + configure(api_key=api_key) + models = list_models() + return [model.name for model in models] + + elif llm_provider == "ollama": + client = OllamaLLM(model="default_model_name") # Replace with the actual model name + models = client.models.list() + return [model.name for model in models] + + elif llm_provider == "azure_openai": + client = AzureOpenAI(api_key=api_key, base_url=base_url) + models = client.models.list() + return [model.id for model in models.data] + + else: + print(f"Unsupported LLM provider: {llm_provider}") + return [] + + except Exception as e: + print(f"Error fetching models from {llm_provider}: {e}") + return [] def encode_image(img_path): if not img_path: