This commit is contained in:
zhaoyuhang 2023-06-29 23:52:48 +08:00
parent 3218b37ffc
commit fb90e98a4e
11 changed files with 542 additions and 5 deletions

View File

@ -36,6 +36,11 @@ public class RedisKey {
*/
public static final String USER_SUMMARY_STRING = "userSummary:uid_%d";
/**
* 用户AI聊天次数
*/
public static final String USER_CHAT_NUM = "userAIChatNum:uid_%d";
public static String getKey(String key, Object... objects) {
return BASE_KEY + String.format(key, objects);
}

View File

@ -0,0 +1,16 @@
package com.abin.mallchat.common.common.utils;
import java.util.Calendar;
import java.util.Date;
public class DateUtils extends org.apache.commons.lang3.time.DateUtils {
public static Long getEndTimeByToday() {
Calendar instance = Calendar.getInstance();
Date now = new Date();
instance.setTime(now);
instance.set(Calendar.HOUR_OF_DAY, 23);
instance.set(Calendar.MINUTE, 59);
instance.set(Calendar.SECOND, 59);
return instance.getTime().getTime() - now.getTime();
}
}

View File

@ -12,7 +12,7 @@ mybatis-plus:
spring:
profiles:
#运行的环境
active: my-prod
active: test
application:
name: mallchat
datasource:
@ -37,8 +37,8 @@ spring:
database: 0
# 连接超时时间
timeout: 1800000
# 设置密码
password: ${mallchat.redis.password}
# # 设置密码
# password: ${mallchat.redis.password}
lettuce:
pool:
# 最大阻塞等待时间,负数表示没有限制
@ -62,4 +62,9 @@ wx:
- appId: ${mallchat.wx.appId} # 第一个公众号的appid
secret: ${mallchat.wx.secret} # 公众号的appsecret
token: ${mallchat.wx.token} # 接口配置里的Token值
aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值
aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值
openai:
use-openai: true
ai-user-id: xxxxx
key: xxxxxxx
proxy-url: https://xxxxxxx

View File

@ -38,6 +38,7 @@ import com.abin.mallchat.custom.chat.service.strategy.mark.MsgMarkFactory;
import com.abin.mallchat.custom.chat.service.strategy.msg.AbstractMsgHandler;
import com.abin.mallchat.custom.chat.service.strategy.msg.MsgHandlerFactory;
import com.abin.mallchat.custom.chat.service.strategy.msg.RecallMsgHandler;
import com.abin.mallchat.custom.openai.event.OpenAIEvent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
@ -95,6 +96,7 @@ public class ChatServiceImpl implements ChatService {
msgHandler.saveMsg(insert, request);
//发布消息发送事件
applicationEventPublisher.publishEvent(new MessageSendEvent(this, insert.getId()));
applicationEventPublisher.publishEvent(new OpenAIEvent(this, insert.getId()));
return insert.getId();
}

View File

@ -66,7 +66,7 @@ public class TextMsgHandler extends AbstractMsgHandler {
AssertUtil.equal(replyMsg.getRoomId(), request.getRoomId(), "只能回复相同会话内的消息");
}
if (CollectionUtil.isNotEmpty(body.getAtUidList())) {
AssertUtil.isTrue(body.getAtUidList().size() > 10, "一次别艾特这么多人");
AssertUtil.isFalse(body.getAtUidList().size() > 10, "一次别艾特这么多人");
List<Long> atUidList = body.getAtUidList();
Map<Long, User> batch = userInfoCache.getBatch(atUidList);
AssertUtil.equal(atUidList.size(), batch.values().size(), "@用户不存在");

View File

@ -0,0 +1,92 @@
package com.abin.mallchat.custom.openai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
@AllArgsConstructor
@Getter
public enum OpenAIModelEnums {
// chat
GPT_35_TURBO("gpt-3.5-turbo", 3, 40000),
GPT_35_TURBO_0301("gpt-3.5-turbo-0301", 3, 40000),
GPT_35_TURBO_0613("gpt-3.5-turbo-0613", 3, 40000),
GPT_35_TURBO_16K("gpt-3.5-turbo-16k", 3, 40000),
GPT_35_TURBO_16K_0613("gpt-3.5-turbo-16k-0613", 3, 40000),
// text
ADA("ada", 60, 150000),
ADA_CODE_SEARCH_CODE("ada-code-search-code", 60, 150000),
ADA_CODE_SEARCH_TEXT("ada-code-search-text", 60, 150000),
ADA_SEARCH_DOCUMENT("ada-search-document", 60, 150000),
ADA_SEARCH_QUERY("ada-search-query", 60, 150000),
ADA_SIMILARITY("ada-similarity", 60, 150000),
BABBAGE("babbage", 60, 150000),
BABBAGE_CODE_SEARCH_CODE("babbage-code-search-code", 60, 150000),
BABBAGE_CODE_SEARCH_TEXT("babbage-code-search-text", 60, 150000),
BABBAGE_SEARCH_DOCUMENT("babbage-search-document", 60, 150000),
BABBAGE_SEARCH_QUERY("babbage-search-query", 60, 150000),
BABBAGE_SIMILARITY("babbage-similarity", 60, 150000),
CODE_DAVINCI_EDIT_001("code-davinci-edit-001", 20, 150000),
CODE_SEARCH_ADA_CODE_001("code-search-ada-code-001", 60, 150000),
CODE_SEARCH_ADA_TEXT_001("code-search-ada-text-001", 60, 150000),
CODE_SEARCH_BABBAGE_CODE_001("code-search-babbage-code-001", 60, 150000),
CODE_SEARCH_BABBAGE_TEXT_001("code-search-babbage-text-001", 60, 150000),
CURIE("curie", 60, 150000),
CURIE_INSTRUCT_BETA("curie-instruct-beta", 60, 150000),
CURIE_SEARCH_DOCUMENT("curie-search-document", 60, 150000),
CURIE_SEARCH_QUERY("curie-search-query", 60, 150000),
CURIE_SIMILARITY("curie-similarity", 60, 150000),
DAVINCI("davinci", 60, 150000),
DAVINCI_INSTRUCT_BETA("davinci-instruct-beta", 60, 150000),
DAVINCI_SEARCH_DOCUMENT("davinci-search-document", 60, 150000),
DAVINCI_SEARCH_QUERY("davinci-search-query", 60, 150000),
DAVINCI_SIMILARITY("davinci-similarity", 60, 150000),
TEXT_ADA_001("text-ada-001", 60, 150000),
TEXT_BABBAGE_001("text-babbage-001", 60, 150000),
TEXT_CURIE_001("text-curie-001", 60, 150000),
TEXT_DAVINCI_001("text-davinci-001", 60, 150000),
TEXT_DAVINCI_002("text-davinci-002", 60, 150000),
TEXT_DAVINCI_003("text-davinci-003", 60, 150000),
TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", 20, 150000),
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", 60, 150000),
TEXT_SEARCH_ADA_DOC_001("text-search-ada-doc-001", 60, 150000),
TEXT_SEARCH_ADA_QUERY_001("text-search-ada-query-001", 60, 150000),
TEXT_SEARCH_BABBAGE_DOC_001("text-search-babbage-doc-001", 60, 150000),
TEXT_SEARCH_BABBAGE_QUERY_001("text-search-babbage-query-001", 60, 150000),
TEXT_SEARCH_CURIE_DOC_001("text-search-curie-doc-001", 60, 150000),
TEXT_SEARCH_CURIE_QUERY_001("text-search-curie-query-001", 60, 150000),
TEXT_SEARCH_DAVINCI_DOC_001("text-search-davinci-doc-001", 60, 150000),
TEXT_SEARCH_DAVINCI_QUERY_001("text-search-davinci-query-001", 60, 150000),
TEXT_SIMILARITY_ADA_001("text-similarity-ada-001", 60, 150000),
TEXT_SIMILARITY_BABBAGE_001("text-similarity-babbage-001", 60, 150000),
TEXT_SIMILARITY_CURIE_001("text-similarity-curie-001", 60, 150000),
TEXT_SIMILARITY_DAVINCI_001("text-similarity-davinci-001", 60, 150000);
/**
* 名字
*/
private final String name;
/**
* 每分钟请求数
*/
private final Integer RPM;
/**
* 每分钟令牌数
*/
private final Integer TPM;
private static final Map<String, OpenAIModelEnums> cache;
static {
cache = Arrays.stream(OpenAIModelEnums.values()).collect(Collectors.toMap(OpenAIModelEnums::getName, Function.identity()));
}
public static OpenAIModelEnums of(String name) {
return cache.get(name);
}
}

View File

@ -0,0 +1,14 @@
package com.abin.mallchat.custom.openai.event;
import lombok.Getter;
import org.springframework.context.ApplicationEvent;
@Getter
public class OpenAIEvent extends ApplicationEvent {
private Long msgId;
public OpenAIEvent(Object source, Long msgId) {
super(source);
this.msgId = msgId;
}
}

View File

@ -0,0 +1,63 @@
package com.abin.mallchat.custom.openai.event.listener;
import com.abin.mallchat.common.chat.dao.MessageDao;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.custom.openai.event.OpenAIEvent;
import com.abin.mallchat.custom.openai.service.IOpenAIService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionalEventListener;
import static com.abin.mallchat.custom.openai.service.impl.OpenAIServiceImpl.MALL_CHAT_AI_NAME;
/**
* 是否AI回复监听器
*
* @author zhaoyuhang
* @date 2023/06/29
*/
@Slf4j
@Component
public class OpenAIListener {
@Autowired
private IOpenAIService openAIService;
@Autowired
private MessageDao messageDao;
@TransactionalEventListener(classes = OpenAIEvent.class, fallbackExecution = true)
public void notifyAllOnline(@NotNull OpenAIEvent event) {
Message message = messageDao.getById(event.getMsgId());
if (ATedAI(message)) {
openAIService.chat(message);
}
}
/**
* @return boolean
* @了AI
*/
private boolean ATedAI(Message message) {
/* 前端传@信息后取消注释 */
// MessageExtra extra = message.getExtra();
// if (extra == null) {
// return false;
// }
// if (CollectionUtils.isEmpty(extra.getAtUidList())) {
// return false;
// }
// if (!extra.getAtUidList().contains(OpenAIServiceImpl.AI_USER_ID)) {
// return false;
// }
if (StringUtils.isBlank(message.getContent())) {
return false;
}
return StringUtils.contains(message.getContent(), "@" + MALL_CHAT_AI_NAME)
&& StringUtils.isNotBlank(message.getContent().replace(MALL_CHAT_AI_NAME, "").trim());
}
}

View File

@ -0,0 +1,11 @@
package com.abin.mallchat.custom.openai.service;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq;
public interface IOpenAIService {
void chat(ChatMessageReq chatMessageReq, Long uid);
void chat(Message message);
}

View File

@ -0,0 +1,171 @@
package com.abin.mallchat.custom.openai.service.impl;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.thread.NamedThreadFactory;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum;
import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.exception.BusinessException;
import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler;
import com.abin.mallchat.common.common.utils.DateUtils;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq;
import com.abin.mallchat.custom.chat.domain.vo.request.msg.TextMsgReq;
import com.abin.mallchat.custom.chat.service.ChatService;
import com.abin.mallchat.custom.openai.enums.OpenAIModelEnums;
import com.abin.mallchat.custom.openai.service.IOpenAIService;
import com.abin.mallchat.custom.openai.utils.OpenAIUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
import com.abin.mallchat.custom.user.service.UserService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Description;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@Slf4j
@Service
public class OpenAIServiceImpl implements IOpenAIService, DisposableBean, InitializingBean {
private static ExecutorService EXECUTOR;
@Value("${openai.use-openai:false}")
private boolean USE_OPENAI;
@Value("${openai.ai-user-id}")
public Long AI_USER_ID;
@Value("${openai.model.name:text-davinci-003}")
private String modelName;
@Value("${openai.key}")
private String key;
@Value("${openai.proxy-url:}")
private String proxyUrl;
@Value("${openai.limit:5}")
private Integer limit;
@Autowired
private UserService userService;
@Lazy
@Autowired
private ChatService chatService;
public static String MALL_CHAT_AI_NAME;
/**
* 聊天
*
* @param chatMessageReq 提示词
* @param uid 用户id
*/
@Deprecated
@Override
public void chat(ChatMessageReq chatMessageReq, Long uid) {
TextMsgReq body = BeanUtil.toBean(chatMessageReq.getBody(), TextMsgReq.class);
String content = body.getContent().replace(MALL_CHAT_AI_NAME, "").trim();
EXECUTOR.execute(() -> {
Long chatNum;
if ((chatNum = userChatNumInrc(uid)) > limit) {
answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", chatMessageReq.getRoomId(), uid);
} else {
chat(content, chatMessageReq.getRoomId(), uid);
}
});
}
@Override
public void chat(Message message) {
String content = message.getContent().replace(MALL_CHAT_AI_NAME, "").trim();
Long roomId = message.getRoomId();
Long uid = message.getFromUid();
EXECUTOR.execute(() -> {
Long chatNum;
if ((chatNum = userChatNumInrc(uid)) > limit) {
answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", roomId, uid);
} else {
chat(content, roomId, uid);
}
});
}
private Long userChatNumInrc(Long uid) {
//todo:白名单
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);
}
private void chat(String content, Long roomId, Long uid) {
HttpResponse response = OpenAIUtils.create(key)
.proxyUrl(proxyUrl)
.model(modelName)
.prompt(content)
.send();
String text = OpenAIUtils.parseText(response);
answerMsg(text, roomId, uid);
}
private void answerMsg(String text, Long roomId, Long uid) {
ChatMessageReq answerReq = new ChatMessageReq();
answerReq.setRoomId(roomId);
answerReq.setMsgType(MessageTypeEnum.TEXT.getType());
UserInfoResp userInfo = userService.getUserInfo(uid);
TextMsgReq textMsgReq = new TextMsgReq();
textMsgReq.setContent("@" + userInfo.getName() + " " + text);
textMsgReq.setAtUidList(Collections.singletonList(uid));
answerReq.setBody(textMsgReq);
chatService.sendMsg(answerReq, AI_USER_ID);
}
@Override
public void afterPropertiesSet() {
if (!USE_OPENAI) {
return;
}
if (StringUtils.isNotBlank(proxyUrl) && !HttpUtil.isHttp(proxyUrl) && !HttpUtil.isHttps(proxyUrl)) {
throw new BusinessException("openai.proxy-url 配置错误");
}
OpenAIModelEnums modelEnum = OpenAIModelEnums.of(modelName);
if (modelEnum == null) {
throw new BusinessException("openai.model.name 配置错误");
}
Integer rpm = modelEnum.getRPM();
EXECUTOR = new ThreadPoolExecutor(10, 10,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(rpm),
new NamedThreadFactory("openAI-chat-gpt",
null,
false,
new GlobalUncaughtExceptionHandler()),
(r, executor) -> {
throw new BusinessException("别问的太快了,我的脑子不够用了");
});
UserInfoResp userInfo = userService.getUserInfo(AI_USER_ID);
if (userInfo == null) {
throw new BusinessException("openai.ai-user-id 配置错误");
}
MALL_CHAT_AI_NAME = userInfo.getName();
}
@Override
public void destroy() throws Exception {
EXECUTOR.shutdown();
if (!EXECUTOR.awaitTermination(30, TimeUnit.SECONDS)) { //最多等30秒处理不完就拉倒
if (log.isErrorEnabled()) {
log.error("Timed out while waiting for executor [{}] to terminate", EXECUTOR);
}
}
}
}

View File

@ -0,0 +1,158 @@
package com.abin.mallchat.custom.openai.utils;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.abin.mallchat.common.common.exception.BusinessException;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
public class OpenAIUtils {
private static final String URL = "https://api.openai.com/v1/completions";
private String model = "text-davinci-003";
private final Map<String, String> headers;
/**
* 超时30秒
*/
private Integer timeout = 30 * 1000;
/**
* 参数用于指定生成文本的最大长度
* 它表示生成的文本中最多包含多少个 token一个 token 可以是一个单词一个标点符号或一个空格
*/
private int maxTokens = 2048;
/**
* 用于控制生成文本的多样性
* 较高的温度会导致更多的随机性和多样性但可能会降低生成文本的质量默认值为 1建议在 0.7 1.3 之间调整
*/
private Object temperature = 1;
/**
* 用于控制生成文本的多样性
* 它会根据概率选择最高的几个单词而不是选择概率最高的单词默认值为 1建议在 0.7 0.9 之间调整
*/
private Object topP = 0.9;
/**
* 用于控制生成文本中重复单词的数量
* 较高的惩罚值会导致更少的重复单词但可能会降低生成文本的流畅性默认值为 0建议在 0 2 之间调整
*/
private Object frequencyPenalty = 0.0;
/**
* 用于控制生成文本中出现特定单词的数量
* 较高的惩罚值会导致更少的特定单词但可能会降低生成文本的流畅性默认值为 0建议在 0 2 之间调整
*/
private Object presencePenalty = 0.6;
/**
* 提示词
*/
private String prompt;
private String proxyUrl;
public OpenAIUtils(String key) {
HashMap<String, String> _headers_ = new HashMap<>();
_headers_.put("Content-Type", "application/json");
if (StringUtils.isBlank(key)) {
throw new BusinessException("openAi key is blank");
}
_headers_.put("Authorization", "Bearer " + key);
this.headers = _headers_;
}
public static OpenAIUtils create(String key) {
return new OpenAIUtils(key);
}
public static String parseText(HttpResponse response) {
return parseText(response.body());
}
public static String parseText(String body) {
JSONObject jsonObj = new JSONObject(body);
JSONArray choicesArr = jsonObj.getJSONArray("choices");
JSONObject choiceObj = choicesArr.getJSONObject(0);
return choiceObj.getStr("text");
}
public OpenAIUtils model(String model) {
this.model = model;
return this;
}
public OpenAIUtils timeout(int timeout) {
this.timeout = timeout;
return this;
}
public OpenAIUtils maxTokens(int maxTokens) {
this.maxTokens = maxTokens;
return this;
}
public OpenAIUtils temperature(int temperature) {
this.temperature = temperature;
return this;
}
public OpenAIUtils topP(int topP) {
this.topP = topP;
return this;
}
public OpenAIUtils frequencyPenalty(int frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
return this;
}
public OpenAIUtils presencePenalty(int presencePenalty) {
this.presencePenalty = presencePenalty;
return this;
}
public OpenAIUtils prompt(String prompt) {
this.prompt = prompt;
return this;
}
public OpenAIUtils proxyUrl(String proxyUrl) {
this.proxyUrl = proxyUrl;
return this;
}
public HttpResponse send() {
JSONObject param = new JSONObject();
param.set("model", model);
param.set("prompt", prompt);
param.set("max_tokens", maxTokens);
param.set("temperature", temperature);
param.set("top_p", topP);
param.set("frequency_penalty", frequencyPenalty);
param.set("presence_penalty", presencePenalty);
return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL)
.addHeaders(headers)
.body(param.toString())
.timeout(timeout)
.execute();
}
public static void main(String[] args) {
HttpResponse send = OpenAIUtils.create("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")
.timeout(30 * 1000)
.prompt("Spring的启动流程是什么")
.send();
System.out.println("send = " + send);
// JSON 数据
// JSON 数据
JSONObject jsonObj = new JSONObject(send.body());
JSONArray choicesArr = jsonObj.getJSONArray("choices");
JSONObject choiceObj = choicesArr.getJSONObject(0);
String text = choiceObj.getStr("text");
System.out.println("text = " + text);
}
}