diff --git a/app/src/main/db/tables.ts b/app/src/main/db/tables.ts index 7eb0fb6..feaae22 100644 --- a/app/src/main/db/tables.ts +++ b/app/src/main/db/tables.ts @@ -38,7 +38,7 @@ initData() function initData() { const initData = findOne('select * from config') if (initData) return - const llm = {model: "gpt-4-turbo", apiKey: "", baseURL: "https://api.openai.com/v1"} + const llm = {model: "gpt-4-turbo", api_key: "", base_url: "https://api.openai.com/v1"} db().exec(`insert into config (id, content) values(1,'{"shortCut":"Alt+d","llm": ${JSON.stringify(llm)}}')`) } diff --git a/app/src/renderer/src/components/Chat/index.tsx b/app/src/renderer/src/components/Chat/index.tsx index 96bc325..7e6f3d0 100644 --- a/app/src/renderer/src/components/Chat/index.tsx +++ b/app/src/renderer/src/components/Chat/index.tsx @@ -1,13 +1,16 @@ -import { ProChat } from '@ant-design/pro-chat'; +import { ProChat, ProChatInstance } from '@ant-design/pro-chat'; import useChat from '@renderer/hooks/useChat'; import { useStore } from '@renderer/store/useStore'; import { useTheme } from 'antd-style'; +import { useRef } from 'react'; export default function Chat(props: {id: number, revalidator: () => void}) { const {id, revalidator} = props; const {getResponse} = useChat() const theme = useTheme(); const chatMessages = useStore(state=>state.chatMessages) const setMessages = useStore(state=>state.setChatMessage) + const proChatRef = useRef(); + return ( void}) { console.log('chat', chat) setMessages(chat) }} + chatRef={proChatRef} style={{ background: theme.colorBgLayout }} // assistantMeta={{ avatar: '', title: '智子', backgroundColor: '#67dedd' }} helloMessage={ @@ -22,8 +26,8 @@ export default function Chat(props: {id: number, revalidator: () => void}) { } request={async (messages) => { - const response = await getResponse(messages, id, revalidator) - return response// 支持流式和非流式 + const response = await getResponse(messages, id, revalidator, proChatRef.current) + return new Response(response.content)// 支持流式和非流式 }} /> ) diff --git a/app/src/renderer/src/hooks/useChat.ts b/app/src/renderer/src/hooks/useChat.ts index d909154..91e00af 100644 --- a/app/src/renderer/src/hooks/useChat.ts +++ b/app/src/renderer/src/hooks/useChat.ts @@ -1,49 +1,46 @@ import { useStore } from "@renderer/store/useStore"; -import { requireAlignmentPrompt, programmerPrompt} from "./prompt"; +import { requireAlignmentPrompt, programmerPrompt } from "./prompt"; import useOpenai from "./useOpenai"; +import { ProChatInstance } from "@ant-design/pro-chat"; -export default ()=>{ +export default () => { const setIsCodeLoading = useStore(state => state.setIsCodeLoading) - const getResponse=(chat_messages: Array, id:number, revalidator: () => void)=>{ - const messages = chat_messages.map((m) => { - return { - role: m.role, - content: m.content - } - }) + const getResponse = async (chat_messages: Array, id: number, revalidator: () => void, proChatRef: ProChatInstance|undefined) => { - // 添加 system 消息 - messages.unshift({ - role: 'system', - content: requireAlignmentPrompt() - }); - const response = useOpenai(requireAlignmentPrompt(), messages, (allContent)=>{ - const programmerCallBack = (allContent: string) => { - allContent = allContent.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim() - window.api.sql('update contents set content = @content where id = @id', - 'update', - {content: allContent, id}) - // 更新数据 - revalidator() - // 关闭loading - setIsCodeLoading(false) - } - if (allContent.includes("【自动化方案】")) { - setIsCodeLoading(true) - useOpenai(programmerPrompt(), [{ - role: "user", - content: allContent - }], (allContent)=>{ - programmerCallBack(allContent) + const response = useOpenai(requireAlignmentPrompt(), chat_messages) + response.then(async (res) => { + if (res.content.includes("【自动化方案】")) { + const chat_id = Date.now().toString() + proChatRef!.pushChat({ + id: chat_id, + createAt: new Date(), + updateAt: new Date(), + role: "assistant", + content: "根据自动化方案生成代码中,请稍等..." + }) + setIsCodeLoading(true) + const programmerResponse = await useOpenai(programmerPrompt(), [{ + role: "user", + content: res.content + }]) + const code = programmerResponse.content.replace(/^```python/, "").replace(/^```/, "").replace(/```python$/, "").replace(/```$/, "").trim() + proChatRef!.setMessageContent( + chat_id, + "根据测试方案生成的python代码如下\n```python\n" + code + "\n```" + ) + window.api.sql('update contents set content = @content where id = @id', + 'update', + { content: code, id}) + // 更新数据 + revalidator() + // 关闭loading + setIsCodeLoading(false) + } }) - } else { - console.log("Response does not contain '【自动化方案】'"); - } - }) - return response - } - return {getResponse} - -} + return response + + } + return { getResponse } +} \ No newline at end of file diff --git a/app/src/renderer/src/hooks/useOpenai.ts b/app/src/renderer/src/hooks/useOpenai.ts index 3d92f0c..45d2ea7 100644 --- a/app/src/renderer/src/hooks/useOpenai.ts +++ b/app/src/renderer/src/hooks/useOpenai.ts @@ -1,9 +1,6 @@ -import { createOpenAI } from "@ai-sdk/openai" -import { streamText } from "ai" -export default async (systemPrompt: string, chatMessages: Array>, callback?: (allContent: string) => void) => { - const configType = (await window.api.getConfig()) as ConfigType - const config = JSON.parse(configType.content) as ConfigDataType +import { localServerBaseUrl } from "@renderer/config" +export default async (systemPrompt: string, chatMessages: Array>) => { const messages = chatMessages.map((m) => { return { role: m.role, @@ -15,46 +12,13 @@ export default async (systemPrompt: string, chatMessages: Array { - if (done) { - controller.close(); - if (callback) { - callback(res); - } - return; - } - res += value; - controller.enqueue(encoder.encode(value)); - push(); - }) - .catch((err) => { - console.error('读取流中的数据时发生错误', err); - controller.error(err); - }); - } - push(); + const response = await fetch(localServerBaseUrl + "/llm", { + method: "POST", + headers: { + "Content-Type": "application/json" }, - }); - return new Response(readableStream) + body: JSON.stringify({messages, isStream: false }) + }) + return response.json() } \ No newline at end of file diff --git a/app/src/renderer/src/layouts/Home/index.tsx b/app/src/renderer/src/layouts/Home/index.tsx index 5143fd2..3facaa7 100644 --- a/app/src/renderer/src/layouts/Home/index.tsx +++ b/app/src/renderer/src/layouts/Home/index.tsx @@ -22,7 +22,7 @@ function Home(): JSX.Element { }, []) window.api.getConfig().then((res)=>{ const config = JSON.parse(res.content) as ConfigDataType - if (config.llm.apiKey=="") { + if (config.llm.api_key=="") { setError("没有检测到大模型配置信息,请“点击配置”进行配置。如有疑问可查看文档:https://s0soyusc93k.feishu.cn/wiki/JhhIwAUXJiBHG9kmt3YcXisWnec") } }) diff --git a/app/types.d.ts b/app/types.d.ts index 64be7a7..672a7ff 100644 --- a/app/types.d.ts +++ b/app/types.d.ts @@ -27,14 +27,7 @@ type ConfigDataType = { shortCut: string llm: { model: string - apiKey: string - baseURL: string + api_key: string + base_url: string } } - - -type LLM = { - model: string - apiKey: string - baseURL: string -} \ No newline at end of file diff --git a/server/main.py b/server/main.py index 8a02cb2..72025a0 100644 --- a/server/main.py +++ b/server/main.py @@ -8,6 +8,8 @@ def create_app(): app.register_blueprint(home_bp) from route.shutdown import home_bp app.register_blueprint(home_bp) + from route.llm import home_bp + app.register_blueprint(home_bp) return app diff --git a/server/requirements.txt b/server/requirements.txt index 66c99d0..224ff93 100644 Binary files a/server/requirements.txt and b/server/requirements.txt differ diff --git a/server/route/llm.py b/server/route/llm.py new file mode 100644 index 0000000..cf3ceb0 --- /dev/null +++ b/server/route/llm.py @@ -0,0 +1,31 @@ +from flask import Blueprint, Response +from flask import request +from litellm import completion +from utils.sql_util import get_config +import json +home_bp = Blueprint('llm', __name__) + +@home_bp.route('/llm', methods=["POST"]) +def llm(): + config = get_config() + messages = request.get_json()["messages"] + isStream = request.get_json().get("isStream", False) + if isStream: + def generate(): + response = completion( + messages=messages, + stream=True, + **json.loads(config)["llm"] + ) + for part in response: + yield part.choices[0].delta.content or "" + return Response(generate(), mimetype='text/event-stream') + else: + res = completion( + messages=messages, + **json.loads(config)["llm"] + ) + return { + "content": res.choices[0].message.content + } + \ No newline at end of file diff --git a/server/utils/sql_util.py b/server/utils/sql_util.py new file mode 100644 index 0000000..a0219f1 --- /dev/null +++ b/server/utils/sql_util.py @@ -0,0 +1,20 @@ +import os +import sqlite3 + + +def find_all(sql): + home_directory = os.path.expanduser('~') + conn = sqlite3.connect(f'{home_directory}/autoMate.db') + sql_res = conn.execute(sql) + res = sql_res.fetchall() + conn.close() + return res + + +def get_config(): + res = find_all(f"SELECT * FROM config WHERE id = 1") + return res[0][1] + + +if __name__ == "__main__": + print(get_config()) \ No newline at end of file