diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiPlatformEnum.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiPlatformEnum.java index 47a4d2d719..612a91338a 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiPlatformEnum.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiPlatformEnum.java @@ -40,6 +40,7 @@ public enum AiPlatformEnum implements ArrayValuable { STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney SUNO("Suno", "Suno"), // Suno AI + GROK("Grok","Grok"), // Grok ; diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/AiAutoConfiguration.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/AiAutoConfiguration.java index 26fbe0ad41..fc965c7074 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/AiAutoConfiguration.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/AiAutoConfiguration.java @@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.model.AiModelFactoryImpl; import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel; import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel; import cn.iocoder.yudao.module.ai.framework.ai.core.model.gemini.GeminiChatModel; +import cn.iocoder.yudao.module.ai.framework.ai.core.model.grok.GrokChatModel; import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel; import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants; @@ -17,6 +18,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient; import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.bocha.AiBoChaWebSearchClient; import cn.iocoder.yudao.module.ai.tool.method.PersonService; import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.deepseek.DeepSeekChatModel; import org.springframework.ai.deepseek.DeepSeekChatOptions; import org.springframework.ai.deepseek.api.DeepSeekApi; @@ -40,6 +42,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import java.util.List; +import java.util.Optional; /** * 芋道 AI 自动配置 @@ -286,4 +289,25 @@ public class AiAutoConfiguration { return List.of(ToolCallbacks.from(personService)); } + public ChatModel buildGrokChatClient(YudaoAiProperties.Grok properties) { + if (StrUtil.isEmpty(properties.getModel())) { + properties.setModel(GrokChatModel.MODEL_DEFAULT); + } + OpenAiChatModel openAiChatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl(Optional.ofNullable(properties.getBaseUrl()) + .orElse(GrokChatModel.BASE_URL)) + .completionsPath(GrokChatModel.COMPLETE_PATH) + .apiKey(properties.getApiKey()) + .build()) + .defaultOptions(OpenAiChatOptions.builder() + .model(properties.getModel()) + .temperature(properties.getTemperature()) + .maxTokens(properties.getMaxTokens()) + .topP(properties.getTopP()) + .build()) + .toolCallingManager(getToolCallingManager()) + .build(); + return new DouBaoChatModel(openAiChatModel); + } } \ No newline at end of file diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/YudaoAiProperties.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/YudaoAiProperties.java index 67d3bb5f3a..986c24c18c 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/YudaoAiProperties.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/config/YudaoAiProperties.java @@ -160,6 +160,20 @@ public class YudaoAiProperties { } + @Data + public static class Grok { + + private String enable; + private String apiKey; + private String baseUrl; + + private String model; + private Double temperature; + private Integer maxTokens; + private Double topP; + + } + @Data public static class WebSearch { diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/AiModelFactoryImpl.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/AiModelFactoryImpl.java index 75798ebd2a..0b4c4563ce 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/AiModelFactoryImpl.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/AiModelFactoryImpl.java @@ -178,6 +178,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildGeminiChatModel(apiKey); case OLLAMA: return buildOllamaChatModel(url); + case GROK: + return buildGrokChatModel(apiKey,url); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -405,6 +407,12 @@ public class AiModelFactoryImpl implements AiModelFactory { .build(); } + private ChatModel buildGrokChatModel(String apiKey,String url) { + YudaoAiProperties.Grok properties = new YudaoAiProperties.Grok() + .setBaseUrl(url) + .setApiKey(apiKey); + return new AiAutoConfiguration().buildGrokChatClient(properties); + } /** * 可参考 {@link AiAutoConfiguration#douBaoChatClient(YudaoAiProperties)} */ diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/grok/GrokChatModel.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/grok/GrokChatModel.java new file mode 100644 index 0000000000..06eed2504f --- /dev/null +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/grok/GrokChatModel.java @@ -0,0 +1,44 @@ +package cn.iocoder.yudao.module.ai.framework.ai.core.model.grok; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + +/** + * Grok {@link ChatModel} 实现类 + * + * + */ +@Slf4j +@RequiredArgsConstructor +public class GrokChatModel implements ChatModel { + + public static final String BASE_URL = "https://api.x.ai"; + public static final String COMPLETE_PATH = "/v1/chat/completions"; + public static final String MODEL_DEFAULT = "grok-4-fast-reasoning"; + + /** + * 兼容 OpenAI 接口,进行复用 + */ + private final ChatModel openAiChatModel; + + @Override + public ChatResponse call(Prompt prompt) { + return openAiChatModel.call(prompt); + } + + @Override + public Flux stream(Prompt prompt) { + return openAiChatModel.stream(prompt); + } + + @Override + public ChatOptions getDefaultOptions() { + return openAiChatModel.getDefaultOptions(); + } + +} diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java index d209c62d44..ccda6a3955 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java @@ -79,6 +79,9 @@ public class AiUtils { case OLLAMA: return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens) .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); + case GROK: + return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); }