diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 9e2a1d6..405308e 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -16,6 +16,25 @@ import sys sys.path.append(".") +def test_mistral_model(): + from langchain_core.messages import HumanMessage + from src.utils import utils + + llm = utils.get_llm_model( + provider="mistral", + model_name="mistral-large-latest", + temperature=0.8, + base_url=os.getenv("MISTRAL_ENDPOINT", ""), + api_key=os.getenv("MISTRAL_API_KEY", "") + ) + message = HumanMessage( + content=[ + {"type": "text", "text": "who are you?"} + ] + ) + ai_msg = llm.invoke([message]) + print(ai_msg.content) + def test_openai_model(): from langchain_core.messages import HumanMessage from src.utils import utils @@ -128,4 +147,5 @@ if __name__ == '__main__': # test_gemini_model() # test_azure_openai_model() # test_deepseek_model() - test_ollama_model() + # test_ollama_model() + test_mistral_model()