diff --git a/src/utils/utils.py b/src/utils/utils.py
index cfe24d3..da0197b 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from typing import Dict, Optional
 from langchain_anthropic import ChatAnthropic
+from langchain_mistralai import ChatMistralAI
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
@@ -40,6 +41,22 @@ def get_llm_model(provider: str, **kwargs):
             base_url=base_url,
             api_key=api_key,
         )
+    elif provider == "mistral":
+        if not kwargs.get("base_url", ""):
+            base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
+        else:
+            base_url = kwargs.get("base_url")
+        if not kwargs.get("api_key", ""):
+            api_key = os.getenv("MISTRAL_API_KEY", "")
+        else:
+            api_key = kwargs.get("api_key")
+
+        return ChatMistralAI(
+            model=kwargs.get("model_name", "mistral-large-latest"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
     elif provider == "openai":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
@@ -117,7 +134,8 @@ model_names = {
     "deepseek": ["deepseek-chat"],
     "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
     "ollama": ["qwen2.5:7b", "llama2:7b"],
-    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
+    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
+    "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"]
 }
 
 # Callback to update the model name dropdown based on the selected provider