Merge pull request #981 from gezipppp-ux/master-jdk17

支持 GROK
This commit is contained in:
芋道源码 2025-11-24 10:13:55 +08:00 committed by GitHub
commit 6d10f136a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 0 deletions

View File

@ -40,6 +40,7 @@ public enum AiPlatformEnum implements ArrayValuable<String> {
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
SUNO("Suno", "Suno"), // Suno AI
GROK("Grok","Grok"), // Grok
;

View File

@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.model.AiModelFactoryImpl;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.gemini.GeminiChatModel;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.grok.GrokChatModel;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
@ -17,6 +18,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.bocha.AiBoChaWebSearchClient;
import cn.iocoder.yudao.module.ai.tool.method.PersonService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
@ -40,6 +42,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import java.util.Optional;
/**
* 芋道 AI 自动配置
@ -286,4 +289,25 @@ public class AiAutoConfiguration {
return List.of(ToolCallbacks.from(personService));
}
public ChatModel buildGrokChatClient(YudaoAiProperties.Grok properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(GrokChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(Optional.ofNullable(properties.getBaseUrl())
.orElse(GrokChatModel.BASE_URL))
.completionsPath(GrokChatModel.COMPLETE_PATH)
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new DouBaoChatModel(openAiChatModel);
}
}

View File

@ -160,6 +160,20 @@ public class YudaoAiProperties {
}
@Data
public static class Grok {
private String enable;
private String apiKey;
private String baseUrl;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class WebSearch {

View File

@ -178,6 +178,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildGeminiChatModel(apiKey);
case OLLAMA:
return buildOllamaChatModel(url);
case GROK:
return buildGrokChatModel(apiKey,url);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -405,6 +407,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
.build();
}
private ChatModel buildGrokChatModel(String apiKey,String url) {
YudaoAiProperties.Grok properties = new YudaoAiProperties.Grok()
.setBaseUrl(url)
.setApiKey(apiKey);
return new AiAutoConfiguration().buildGrokChatClient(properties);
}
/**
* 可参考 {@link AiAutoConfiguration#douBaoChatClient(YudaoAiProperties)}
*/

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.module.ai.framework.ai.core.model.grok;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
/**
* Grok {@link ChatModel} 实现类
*
*
*/
@Slf4j
@RequiredArgsConstructor
public class GrokChatModel implements ChatModel {
public static final String BASE_URL = "https://api.x.ai";
public static final String COMPLETE_PATH = "/v1/chat/completions";
public static final String MODEL_DEFAULT = "grok-4-fast-reasoning";
/**
* 兼容 OpenAI 接口进行复用
*/
private final ChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return openAiChatModel.getDefaultOptions();
}
}

View File

@ -79,6 +79,9 @@ public class AiUtils {
case OLLAMA:
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case GROK:
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}