Support multiple large language models

yuruo
2024-07-03 17:43:09 +08:00
parent 8c9f8caf30
commit a95fa36ffa
10 changed files with 111 additions and 100 deletions

View File

@@ -38,7 +38,7 @@ initData()
function initData() {
const initData = findOne('select * from config')
if (initData) return
const llm = {model: "gpt-4-turbo", apiKey: "", baseURL: "https://api.openai.com/v1"}
const llm = {model: "gpt-4-turbo", api_key: "", base_url: "https://api.openai.com/v1"}
db().exec(`insert into config (id, content) values(1,'{"shortCut":"Alt+d","llm": ${JSON.stringify(llm)}}')`)
}
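The key rename here (apiKey → api_key, baseURL → base_url) lines the stored llm object up with the keyword arguments that litellm's completion() expects, which is what lets the new server route splat the config straight into the call. A minimal sketch of that idea, with placeholder values (the real ones come from the config row seeded above):

```python
# Minimal sketch: the snake_case keys stored in the config row can be passed
# straight to litellm as keyword arguments. Values here are placeholders.
from litellm import completion

llm = {"model": "gpt-4-turbo", "api_key": "<your key>", "base_url": "https://api.openai.com/v1"}
res = completion(
    messages=[{"role": "user", "content": "hello"}],
    **llm,  # expands to model=..., api_key=..., base_url=...
)
print(res.choices[0].message.content)
```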

View File

@@ -1,13 +1,16 @@
import { ProChat } from '@ant-design/pro-chat';
import { ProChat, ProChatInstance } from '@ant-design/pro-chat';
import useChat from '@renderer/hooks/useChat';
import { useStore } from '@renderer/store/useStore';
import { useTheme } from 'antd-style';
import { useRef } from 'react';
export default function Chat(props: {id: number, revalidator: () => void}) {
const {id, revalidator} = props;
const {getResponse} = useChat()
const theme = useTheme();
const chatMessages = useStore(state=>state.chatMessages)
const setMessages = useStore(state=>state.setChatMessage)
const proChatRef = useRef<ProChatInstance>();
return (
<ProChat
chats={chatMessages}
@@ -15,6 +18,7 @@ export default function Chat(props: {id: number, revalidator: () => void}) {
console.log('chat', chat)
setMessages(chat)
}}
chatRef={proChatRef}
style={{ background: theme.colorBgLayout }}
// assistantMeta={{ avatar: '', title: '智子', backgroundColor: '#67dedd' }}
helloMessage={
@@ -22,8 +26,8 @@ export default function Chat(props: {id: number, revalidator: () => void}) {
}
request={async (messages) => {
const response = await getResponse(messages, id, revalidator)
return response// supports both streaming and non-streaming responses
const response = await getResponse(messages, id, revalidator, proChatRef.current)
return new Response(response.content)// supports both streaming and non-streaming responses
}}
/>
)

View File

@@ -1,49 +1,46 @@
import { useStore } from "@renderer/store/useStore";
import { requireAlignmentPrompt, programmerPrompt} from "./prompt";
import { requireAlignmentPrompt, programmerPrompt } from "./prompt";
import useOpenai from "./useOpenai";
import { ProChatInstance } from "@ant-design/pro-chat";
export default ()=>{
export default () => {
const setIsCodeLoading = useStore(state => state.setIsCodeLoading)
const getResponse=(chat_messages: Array<any>, id:number, revalidator: () => void)=>{
const messages = chat_messages.map((m) => {
return {
role: m.role,
content: m.content
}
})
const getResponse = async (chat_messages: Array<any>, id: number, revalidator: () => void, proChatRef: ProChatInstance|undefined) => {
// prepend the system message
messages.unshift({
role: 'system',
content: requireAlignmentPrompt()
});
const response = useOpenai(requireAlignmentPrompt(), messages, (allContent)=>{
const programmerCallBack = (allContent: string) => {
allContent = allContent.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim()
window.api.sql('update contents set content = @content where id = @id',
'update',
{content: allContent, id})
// refresh the data
revalidator()
// turn off the loading state
setIsCodeLoading(false)
}
if (allContent.includes("【自动化方案】")) {
setIsCodeLoading(true)
useOpenai(programmerPrompt(), [{
role: "user",
content: allContent
}], (allContent)=>{
programmerCallBack(allContent)
const response = useOpenai(requireAlignmentPrompt(), chat_messages)
response.then(async (res) => {
if (res.content.includes("【自动化方案】")) {
const chat_id = Date.now().toString()
proChatRef!.pushChat({
id: chat_id,
createAt: new Date(),
updateAt: new Date(),
role: "assistant",
content: "根据自动化方案生成代码中,请稍等..."
})
setIsCodeLoading(true)
const programmerResponse = await useOpenai(programmerPrompt(), [{
role: "user",
content: res.content
}])
const code = programmerResponse.content.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim()
proChatRef!.setMessageContent(
chat_id,
"根据测试方案生成的python代码如下\n```python\n" + code + "\n```"
)
window.api.sql('update contents set content = @content where id = @id',
'update',
{ content: code, id})
// refresh the data
revalidator()
// turn off the loading state
setIsCodeLoading(false)
}
})
} else {
console.log("Response does not contain '【自动化方案】'");
}
})
return response
}
return {getResponse}
}
return response
}
return { getResponse }
}

View File

@@ -1,9 +1,6 @@
import { createOpenAI } from "@ai-sdk/openai"
import { streamText } from "ai"
export default async (systemPrompt: string, chatMessages: Array<Record<string, any>>, callback?: (allContent: string) => void) => {
const configType = (await window.api.getConfig()) as ConfigType
const config = JSON.parse(configType.content) as ConfigDataType
import { localServerBaseUrl } from "@renderer/config"
export default async (systemPrompt: string, chatMessages: Array<Record<string, any>>) => {
const messages = chatMessages.map((m) => {
return {
role: m.role,
@@ -15,46 +12,13 @@ export default async (systemPrompt: string, chatMessages: Array<Record<string, a
role: 'system',
content: systemPrompt
});
const openai = createOpenAI({
apiKey: config.llm.apiKey,
baseURL: config.llm.baseURL,
compatibility: 'compatible'
});
const stream = await streamText({
model: openai(config.llm.model),
messages: [...messages],
});
// get the stream reader
const reader = stream.textStream.getReader();
const encoder = new TextEncoder();
const readableStream = new ReadableStream({
async start(controller) {
let res = ""
function push() {
reader
.read()
.then(({ done, value }) => {
if (done) {
controller.close();
if (callback) {
callback(res);
}
return;
}
res += value;
controller.enqueue(encoder.encode(value));
push();
})
.catch((err) => {
console.error('Error while reading data from the stream', err);
controller.error(err);
});
}
push();
const response = await fetch(localServerBaseUrl + "/llm", {
method: "POST",
headers: {
"Content-Type": "application/json"
},
});
return new Response(readableStream)
body: JSON.stringify({messages, isStream: false })
})
return response.json()
}

View File

@@ -22,7 +22,7 @@ function Home(): JSX.Element {
}, [])
window.api.getConfig().then((res)=>{
const config = JSON.parse(res.content) as ConfigDataType
if (config.llm.apiKey=="") {
if (config.llm.api_key=="") {
setError("没有检测到大模型配置信息,请“点击配置”进行配置。如有疑问可查看文档:https://s0soyusc93k.feishu.cn/wiki/JhhIwAUXJiBHG9kmt3YcXisWnec")
}
})

11
app/types.d.ts vendored
View File

@@ -27,14 +27,7 @@ type ConfigDataType = {
shortCut: string
llm: {
model: string
apiKey: string
baseURL: string
api_key: string
base_url: string
}
}
type LLM = {
model: string
apiKey: string
baseURL: string
}

View File

@@ -8,6 +8,8 @@ def create_app():
app.register_blueprint(home_bp)
from route.shutdown import home_bp
app.register_blueprint(home_bp)
from route.llm import home_bp
app.register_blueprint(home_bp)
return app

Binary file not shown.

31
server/route/llm.py Normal file
View File

@@ -0,0 +1,31 @@
from flask import Blueprint, Response
from flask import request
from litellm import completion
from utils.sql_util import get_config
import json
home_bp = Blueprint('llm', __name__)
@home_bp.route('/llm', methods=["POST"])
def llm():
config = get_config()
messages = request.get_json()["messages"]
isStream = request.get_json().get("isStream", False)
if isStream:
def generate():
response = completion(
messages=messages,
stream=True,
**json.loads(config)["llm"]
)
for part in response:
yield part.choices[0].delta.content or ""
return Response(generate(), mimetype='text/event-stream')
else:
res = completion(
messages=messages,
**json.loads(config)["llm"]
)
return {
"content": res.choices[0].message.content
}
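For reference, a hypothetical client for this route, exercising both branches; the server address is an assumption, since localServerBaseUrl is not part of this diff:

```python
# Hypothetical caller of the new /llm route. The address is assumed; the
# renderer reaches it through localServerBaseUrl, which this diff does not show.
import requests

base_url = "http://127.0.0.1:5000"  # assumption
messages = [{"role": "user", "content": "hello"}]

# Non-streaming branch, as used by the rewritten useOpenai hook.
res = requests.post(f"{base_url}/llm", json={"messages": messages, "isStream": False})
print(res.json()["content"])

# Streaming branch: the route yields delta chunks as a text/event-stream.
with requests.post(f"{base_url}/llm", json={"messages": messages, "isStream": True}, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")
```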

20
server/utils/sql_util.py Normal file
View File

@@ -0,0 +1,20 @@
import os
import sqlite3
def find_all(sql):
home_directory = os.path.expanduser('~')
conn = sqlite3.connect(f'{home_directory}/autoMate.db')
sql_res = conn.execute(sql)
res = sql_res.fetchall()
conn.close()
return res
def get_config():
res = find_all(f"SELECT * FROM config WHERE id = 1")
return res[0][1]
if __name__ == "__main__":
print(get_config())
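find_all returns raw tuples, so get_config's res[0][1] is the content column of the row seeded by initData. A short sketch of how the /llm route turns that JSON string into completion kwargs:

```python
# Sketch: parse the JSON string stored in config.content (seeded by initData)
# and pull out the llm section that /llm unpacks into completion(**...).
import json
from utils.sql_util import get_config

config = json.loads(get_config())
# expected shape: {"shortCut": "Alt+d", "llm": {"model": ..., "api_key": ..., "base_url": ...}}
llm_kwargs = config["llm"]
print(llm_kwargs["model"], llm_kwargs["base_url"])
```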