feat:【ai 大模型】增加联网搜索功能

This commit is contained in:
YunaiV 2025-08-25 23:47:55 +08:00
parent ca34d1650e
commit 9b2f2f581b
5 changed files with 105 additions and 22 deletions

View File

@ -39,14 +39,27 @@ Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"conversationId": "1781604279872581797",
"conversationId": "1781604279872581799",
"content": "说下图片里,有哪些字?",
"useContext": true
}
### 获得指定对话的消息列表
GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
### 发送消息(流式)【联网搜索】
POST {{baseUrl}}/ai/chat/message/send-stream
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"conversationId": "1781604279872581799",
"content": "今天是周几?",
"useSearch": true
}
### 获得指定对话的消息列表
GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581799
Authorization: {{token}}
tenant-id: {{adminTenantId}}
### 删除消息
DELETE {{baseUrl}}/ai/chat/message/delete?id=50

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@ -49,6 +50,9 @@ public class AiChatMessageRespVO {
@Schema(description = "知识库段落数组")
private List<KnowledgeSegment> segments;
@Schema(description = "联网搜索的网页内容数组")
private List<AiWebSearchResponse.WebPage> webSearchPages;
@Schema(description = "附件 URL 数组", example = "https://www.iocoder.cn/1.png")
private List<String> attachmentUrls;

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@ -38,6 +39,9 @@ public class AiChatMessageSendRespVO {
@Schema(description = "知识库段落数组")
private List<AiChatMessageRespVO.KnowledgeSegment> segments;
@Schema(description = "联网搜索的网页内容数组")
private List<AiWebSearchResponse.WebPage> webSearchPages;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createTime;

View File

@ -6,11 +6,16 @@ import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.messages.MessageType;
import java.util.List;
@ -106,6 +111,12 @@ public class AiChatMessageDO extends BaseDO {
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> segmentIds;
/**
* 联网搜索的网页内容数组
*/
@TableField(typeHandler = JacksonTypeHandler.class)
private List<AiWebSearchResponse.WebPage> webSearchPages;
/**
* 附件 URL 数组
*/

View File

@ -23,6 +23,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchRequest;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
@ -44,6 +47,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -69,6 +73,11 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
/**
* 联网搜索的结束数
*/
private static final Integer WEB_SEARCH_COUNT = 10;
// TODO @芋艿后续优化下对话的 Prompt 整体结构
/**
@ -78,6 +87,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
"%s\n\n" + // 多个 <Reference></Reference> 的拼接
"回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。";
private static final String WEB_SEARCH_USER_MESSAGE_TEMPLATE = "使用 <WebSearch></WebSearch> 标记中的内容作为本次对话的参考:\n\n" +
"%s\n\n" + // 多个 <WebSearch></WebSearch> 的拼接
"回答要求:\n- 避免提及你是从 <WebSearch></WebSearch> 获取的知识。";
/**
* 附件转 ${@link UserMessage} 的内容模版
*/
@ -102,6 +115,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource
private AiToolService toolService;
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
@Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项可以关闭 AiWebSearchClient 的功能所以这里只能不强制注入
private AiWebSearchClient webSearchClient;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
// 1.1 校验对话存在
@ -115,30 +132,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model = modalService.validateModel(conversation.getModelId());
ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找
// 2.1 知识库召
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
sendReqVO.getContent(), conversation);
// 2.2 联网搜索
AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
.setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
null, sendReqVO.getAttachmentUrls());
null, sendReqVO.getAttachmentUrls(), null);
// 3.1 插入 assistant 接收消息
// 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
knowledgeSegments, null);
knowledgeSegments, null, webSearchResponse);
// 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
// 4.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
ChatResponse chatResponse = chatModel.call(prompt);
// 3.3 更新响应内容
// 4.3 更新响应内容
String newContent = AiUtils.getChatResponseContent(chatResponse);
String newReasoningContent = AiUtils.getChatResponseReasoningContent(chatResponse);
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId())
.setContent(newContent).setReasoningContent(newReasoningContent));
// 3.4 响应结果
// 4.4 响应结果
Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId));
List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
@ -149,7 +171,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return new AiChatMessageSendRespVO()
.setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
.setContent(newContent).setSegments(segments));
.setContent(newContent).setSegments(segments)
.setWebSearchPages(webSearchResponse != null ? webSearchResponse.getLists() : null));
}
@Override
@ -166,30 +189,36 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model = modalService.validateModel(conversation.getModelId());
StreamingChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找回
// 2.1 知识库找回
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
sendReqVO.getContent(), conversation);
// 2.2 联网搜索
AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
.setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
null, sendReqVO.getAttachmentUrls());
null, sendReqVO.getAttachmentUrls(), null);
// 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
knowledgeSegments, null);
knowledgeSegments, null, webSearchResponse);
// 4.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 4.3 流式返回
StringBuffer contentBuffer = new StringBuffer();
StringBuffer reasoningContentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
// 处理知识库的返回只有首次才有
// 仅首次返回知识库联网搜索
List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
List<AiWebSearchResponse.WebPage> webSearchPages = null;
if (StrUtil.isEmpty(contentBuffer)) {
Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() ->
knowledgeDocumentService.getKnowledgeDocumentMap(
@ -198,6 +227,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
segment.setDocumentName(document != null ? document.getName() : null);
});
if (webSearchResponse != null) {
webSearchPages = webSearchResponse.getLists();
}
}
// 响应结果
String newContent = AiUtils.getChatResponseContent(chunk);
@ -213,7 +245,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
.setContent(StrUtil.nullToDefault(newContent, "")) // 避免 null 情况
.setReasoningContent(StrUtil.nullToDefault(newReasoningContent, "")) // 避免 null 情况
.setSegments(segments))); // 知识库返回
.setSegments(segments).setWebSearchPages(webSearchPages))); // 知识库 + 联网搜索
}).doOnComplete(() -> {
// 忽略租户因为 Flux 异步无法透传租户
TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
@ -239,7 +271,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return Collections.emptyList();
}
// 2. 遍历
// 2. 遍历
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = new ArrayList<>();
for (Long knowledgeId : role.getKnowledgeIds()) {
knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
@ -250,6 +282,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
AiWebSearchResponse webSearchResponse,
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
List<Message> chatMessages = new ArrayList<>();
// 1.1 System Context 角色设定
@ -265,6 +298,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
if (attachmentUserMessage != null) {
chatMessages.add(attachmentUserMessage);
}
// TODO @芋艿历史的知识库历史的搜索要不要拼接
});
// 1.3 当前 user message 新发送消息
@ -278,7 +312,20 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
}
// 1.5 附件通过 UserMessage 实现
// 1.5 联网搜索通过 UserMessage 实现
if (webSearchResponse != null && CollUtil.isNotEmpty(webSearchResponse.getLists())) {
String webSearch = webSearchResponse.getLists().stream()
.map(page -> {
String summary = StrUtil.isNotEmpty(page.getSummary()) ?
"\nSummary: " + page.getSummary() : "";
return "<WebSearch title=\"" + page.getTitle() + "\" url=\"" + page.getUrl() + "\">"
+ StrUtil.blankToDefault(page.getSummary(), page.getSnippet()) + "</WebSearch>";
})
.collect(Collectors.joining("\n\n"));
chatMessages.add(new UserMessage(String.format(WEB_SEARCH_USER_MESSAGE_TEMPLATE, webSearch)));
}
// 1.6 附件通过 UserMessage 实现
if (CollUtil.isNotEmpty(sendReqVO.getAttachmentUrls())) {
UserMessage attachmentUserMessage = buildAttachmentUserMessage(sendReqVO.getAttachmentUrls());
if (attachmentUserMessage != null) {
@ -383,12 +430,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
List<String> attachmentUrls) {
List<String> attachmentUrls,
AiWebSearchResponse webSearchResponse) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content).setUseContext(useContext)
.setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId))
.setAttachmentUrls(attachmentUrls);
if (webSearchResponse != null) {
message.setWebSearchPages(webSearchResponse.getLists());
}
message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message);
return message;