mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 13:07:17 +08:00
兼容多个大模型
This commit is contained in:
@@ -38,7 +38,7 @@ initData()
|
||||
function initData() {
|
||||
const initData = findOne('select * from config')
|
||||
if (initData) return
|
||||
const llm = {model: "gpt-4-turbo", apiKey: "", baseURL: "https://api.openai.com/v1"}
|
||||
const llm = {model: "gpt-4-turbo", api_key: "", base_url: "https://api.openai.com/v1"}
|
||||
db().exec(`insert into config (id, content) values(1,'{"shortCut":"Alt+d","llm": ${JSON.stringify(llm)}}')`)
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import { ProChat } from '@ant-design/pro-chat';
|
||||
import { ProChat, ProChatInstance } from '@ant-design/pro-chat';
|
||||
import useChat from '@renderer/hooks/useChat';
|
||||
import { useStore } from '@renderer/store/useStore';
|
||||
import { useTheme } from 'antd-style';
|
||||
import { useRef } from 'react';
|
||||
export default function Chat(props: {id: number, revalidator: () => void}) {
|
||||
const {id, revalidator} = props;
|
||||
const {getResponse} = useChat()
|
||||
const theme = useTheme();
|
||||
const chatMessages = useStore(state=>state.chatMessages)
|
||||
const setMessages = useStore(state=>state.setChatMessage)
|
||||
const proChatRef = useRef<ProChatInstance>();
|
||||
|
||||
return (
|
||||
<ProChat
|
||||
chats={chatMessages}
|
||||
@@ -15,6 +18,7 @@ export default function Chat(props: {id: number, revalidator: () => void}) {
|
||||
console.log('chat', chat)
|
||||
setMessages(chat)
|
||||
}}
|
||||
chatRef={proChatRef}
|
||||
style={{ background: theme.colorBgLayout }}
|
||||
// assistantMeta={{ avatar: '', title: '智子', backgroundColor: '#67dedd' }}
|
||||
helloMessage={
|
||||
@@ -22,8 +26,8 @@ export default function Chat(props: {id: number, revalidator: () => void}) {
|
||||
}
|
||||
|
||||
request={async (messages) => {
|
||||
const response = await getResponse(messages, id, revalidator)
|
||||
return response// 支持流式和非流式
|
||||
const response = await getResponse(messages, id, revalidator, proChatRef.current)
|
||||
return new Response(response.content)// 支持流式和非流式
|
||||
}}
|
||||
/>
|
||||
)
|
||||
|
||||
@@ -1,49 +1,46 @@
|
||||
import { useStore } from "@renderer/store/useStore";
|
||||
import { requireAlignmentPrompt, programmerPrompt} from "./prompt";
|
||||
import { requireAlignmentPrompt, programmerPrompt } from "./prompt";
|
||||
import useOpenai from "./useOpenai";
|
||||
import { ProChatInstance } from "@ant-design/pro-chat";
|
||||
|
||||
export default ()=>{
|
||||
export default () => {
|
||||
const setIsCodeLoading = useStore(state => state.setIsCodeLoading)
|
||||
|
||||
const getResponse=(chat_messages: Array<any>, id:number, revalidator: () => void)=>{
|
||||
const messages = chat_messages.map((m) => {
|
||||
return {
|
||||
role: m.role,
|
||||
content: m.content
|
||||
}
|
||||
})
|
||||
const getResponse = async (chat_messages: Array<any>, id: number, revalidator: () => void, proChatRef: ProChatInstance|undefined) => {
|
||||
|
||||
// 添加 system 消息
|
||||
messages.unshift({
|
||||
role: 'system',
|
||||
content: requireAlignmentPrompt()
|
||||
});
|
||||
const response = useOpenai(requireAlignmentPrompt(), messages, (allContent)=>{
|
||||
const programmerCallBack = (allContent: string) => {
|
||||
allContent = allContent.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim()
|
||||
window.api.sql('update contents set content = @content where id = @id',
|
||||
'update',
|
||||
{content: allContent, id})
|
||||
// 更新数据
|
||||
revalidator()
|
||||
// 关闭loading
|
||||
setIsCodeLoading(false)
|
||||
}
|
||||
if (allContent.includes("【自动化方案】")) {
|
||||
setIsCodeLoading(true)
|
||||
useOpenai(programmerPrompt(), [{
|
||||
role: "user",
|
||||
content: allContent
|
||||
}], (allContent)=>{
|
||||
programmerCallBack(allContent)
|
||||
const response = useOpenai(requireAlignmentPrompt(), chat_messages)
|
||||
response.then(async (res) => {
|
||||
if (res.content.includes("【自动化方案】")) {
|
||||
const chat_id = Date.now().toString()
|
||||
proChatRef!.pushChat({
|
||||
id: chat_id,
|
||||
createAt: new Date(),
|
||||
updateAt: new Date(),
|
||||
role: "assistant",
|
||||
content: "根据自动化方案生成代码中,请稍等..."
|
||||
})
|
||||
setIsCodeLoading(true)
|
||||
const programmerResponse = await useOpenai(programmerPrompt(), [{
|
||||
role: "user",
|
||||
content: res.content
|
||||
}])
|
||||
const code = programmerResponse.content.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim()
|
||||
proChatRef!.setMessageContent(
|
||||
chat_id,
|
||||
"根据测试方案生成的python代码如下\n```python\n" + code + "\n```"
|
||||
)
|
||||
window.api.sql('update contents set content = @content where id = @id',
|
||||
'update',
|
||||
{ content: code, id})
|
||||
// 更新数据
|
||||
revalidator()
|
||||
// 关闭loading
|
||||
setIsCodeLoading(false)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
console.log("Response does not contain '【自动化方案】'");
|
||||
}
|
||||
})
|
||||
return response
|
||||
}
|
||||
return {getResponse}
|
||||
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
}
|
||||
return { getResponse }
|
||||
}
|
||||
@@ -1,9 +1,6 @@
|
||||
import { createOpenAI } from "@ai-sdk/openai"
|
||||
import { streamText } from "ai"
|
||||
|
||||
export default async (systemPrompt: string, chatMessages: Array<Record<string, any>>, callback?: (allContent: string) => void) => {
|
||||
const configType = (await window.api.getConfig()) as ConfigType
|
||||
const config = JSON.parse(configType.content) as ConfigDataType
|
||||
import { localServerBaseUrl } from "@renderer/config"
|
||||
export default async (systemPrompt: string, chatMessages: Array<Record<string, any>>) => {
|
||||
const messages = chatMessages.map((m) => {
|
||||
return {
|
||||
role: m.role,
|
||||
@@ -15,46 +12,13 @@ export default async (systemPrompt: string, chatMessages: Array<Record<string, a
|
||||
role: 'system',
|
||||
content: systemPrompt
|
||||
});
|
||||
const openai = createOpenAI({
|
||||
apiKey: config.llm.apiKey,
|
||||
baseURL: config.llm.baseURL,
|
||||
compatibility: 'compatible'
|
||||
});
|
||||
|
||||
|
||||
const stream = await streamText({
|
||||
model: openai(config.llm.model),
|
||||
messages: [...messages],
|
||||
});
|
||||
// 获取 reader
|
||||
const reader = stream.textStream.getReader();
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
const readableStream = new ReadableStream({
|
||||
async start(controller) {
|
||||
let res = ""
|
||||
function push() {
|
||||
reader
|
||||
.read()
|
||||
.then(({ done, value }) => {
|
||||
if (done) {
|
||||
controller.close();
|
||||
if (callback) {
|
||||
callback(res);
|
||||
}
|
||||
return;
|
||||
}
|
||||
res += value;
|
||||
controller.enqueue(encoder.encode(value));
|
||||
push();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('读取流中的数据时发生错误', err);
|
||||
controller.error(err);
|
||||
});
|
||||
}
|
||||
push();
|
||||
const response = await fetch(localServerBaseUrl + "/llm", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
});
|
||||
return new Response(readableStream)
|
||||
body: JSON.stringify({messages, isStream: false })
|
||||
})
|
||||
return response.json()
|
||||
}
|
||||
@@ -22,7 +22,7 @@ function Home(): JSX.Element {
|
||||
}, [])
|
||||
window.api.getConfig().then((res)=>{
|
||||
const config = JSON.parse(res.content) as ConfigDataType
|
||||
if (config.llm.apiKey=="") {
|
||||
if (config.llm.api_key=="") {
|
||||
setError("没有检测到大模型配置信息,请“点击配置”进行配置。如有疑问可查看文档:https://s0soyusc93k.feishu.cn/wiki/JhhIwAUXJiBHG9kmt3YcXisWnec")
|
||||
}
|
||||
})
|
||||
|
||||
11
app/types.d.ts
vendored
11
app/types.d.ts
vendored
@@ -27,14 +27,7 @@ type ConfigDataType = {
|
||||
shortCut: string
|
||||
llm: {
|
||||
model: string
|
||||
apiKey: string
|
||||
baseURL: string
|
||||
api_key: string
|
||||
base_url: string
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
type LLM = {
|
||||
model: string
|
||||
apiKey: string
|
||||
baseURL: string
|
||||
}
|
||||
@@ -8,6 +8,8 @@ def create_app():
|
||||
app.register_blueprint(home_bp)
|
||||
from route.shutdown import home_bp
|
||||
app.register_blueprint(home_bp)
|
||||
from route.llm import home_bp
|
||||
app.register_blueprint(home_bp)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
Binary file not shown.
31
server/route/llm.py
Normal file
31
server/route/llm.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from flask import Blueprint, Response
|
||||
from flask import request
|
||||
from litellm import completion
|
||||
from utils.sql_util import get_config
|
||||
import json
|
||||
home_bp = Blueprint('llm', __name__)
|
||||
|
||||
@home_bp.route('/llm', methods=["POST"])
|
||||
def llm():
|
||||
config = get_config()
|
||||
messages = request.get_json()["messages"]
|
||||
isStream = request.get_json().get("isStream", False)
|
||||
if isStream:
|
||||
def generate():
|
||||
response = completion(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**json.loads(config)["llm"]
|
||||
)
|
||||
for part in response:
|
||||
yield part.choices[0].delta.content or ""
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
else:
|
||||
res = completion(
|
||||
messages=messages,
|
||||
**json.loads(config)["llm"]
|
||||
)
|
||||
return {
|
||||
"content": res.choices[0].message.content
|
||||
}
|
||||
|
||||
20
server/utils/sql_util.py
Normal file
20
server/utils/sql_util.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
|
||||
def find_all(sql):
|
||||
home_directory = os.path.expanduser('~')
|
||||
conn = sqlite3.connect(f'{home_directory}/autoMate.db')
|
||||
sql_res = conn.execute(sql)
|
||||
res = sql_res.fetchall()
|
||||
conn.close()
|
||||
return res
|
||||
|
||||
|
||||
def get_config():
|
||||
res = find_all(f"SELECT * FROM config WHERE id = 1")
|
||||
return res[0][1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_config())
|
||||
Reference in New Issue
Block a user