diff --git a/.env.example b/.env.example index 0eb799a..8d7ceff 100644 --- a/.env.example +++ b/.env.example @@ -34,9 +34,13 @@ IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com IBM_API_KEY= IBM_PROJECT_ID= +GROK_ENDPOINT=https://api.x.ai/v1 +GROK_API_KEY= + #set default LLM DEFAULT_LLM=openai + # Set to false to disable anonymized telemetry ANONYMIZED_TELEMETRY=false diff --git a/src/utils/config.py b/src/utils/config.py index 509bc82..de82bb9 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -7,7 +7,8 @@ PROVIDER_DISPLAY_NAMES = { "alibaba": "Alibaba", "moonshot": "MoonShot", "unbound": "Unbound AI", - "ibm": "IBM" + "ibm": "IBM", + "grok": "Grok", } # Predefined model names for common providers @@ -25,6 +26,15 @@ model_names = { "alibaba": ["qwen-plus", "qwen-max", "qwen-vl-max", "qwen-vl-plus", "qwen-turbo", "qwen-long"], "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"], "unbound": ["gemini-2.0-flash", "gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"], + "grok": [ + "grok-3", + "grok-3-fast", + "grok-3-mini", + "grok-3-mini-fast", + "grok-2-vision", + "grok-2-image", + "grok-2", + ], "siliconflow": [ "deepseek-ai/DeepSeek-R1", "deepseek-ai/DeepSeek-V3", diff --git a/src/utils/llm_provider.py b/src/utils/llm_provider.py index beadb1f..36da553 100644 --- a/src/utils/llm_provider.py +++ b/src/utils/llm_provider.py @@ -205,6 +205,18 @@ def get_llm_model(provider: str, **kwargs): base_url=base_url, api_key=api_key, ) + elif provider == "grok": + if not kwargs.get("base_url", ""): + base_url = os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1") + else: + base_url = kwargs.get("base_url") + + return ChatOpenAI( + model=kwargs.get("model_name", "grok-3"), + temperature=kwargs.get("temperature", 0.0), + base_url=base_url, + api_key=api_key, + ) elif provider == "deepseek": if not kwargs.get("base_url", ""): base_url = os.getenv("DEEPSEEK_ENDPOINT", "")