diff --git a/app/src/renderer/src/components/Chat/index.tsx b/app/src/renderer/src/components/Chat/index.tsx index a05dd4c..cd65501 100644 --- a/app/src/renderer/src/components/Chat/index.tsx +++ b/app/src/renderer/src/components/Chat/index.tsx @@ -26,7 +26,6 @@ export default function Chat(props: {id: number, revalidator: () => void, search { - console.log('chat', chat) setMessages(chat) }} chatItemRenderConfig={{ @@ -62,29 +61,20 @@ export default function Chat(props: {id: number, revalidator: () => void, search } request={async (messages) => { - // const response = await getResponse(messages) - // if (response.isExistCode === 0) { - // setTimeout(() => { - // proChatRef.current?.sendMessage({ - // type: 'text', - // content: response.content, - // role: 'coder', - // originData: response - // }); - // }, 1000); // 延时1秒推送消息 - - // } - - // return new Response(response.content)// 支持流式和非流式 - setTimeout(() => { - proChatRef.current?.pushChat({ + const response = await getResponse(messages) + console.log(response) + if (response.isExistCode === 0) { + setTimeout(() => { + proChatRef.current?.pushChat({ type: 'text', - content: "hello", + content: response.content, role: 'coder', - originData: "dd" + originData: response }); + }, 1000); // 延时1秒推送消息 - return new Response("hello") + } + return new Response(response.content) }} /> ) diff --git a/app/src/renderer/src/hooks/useChat.ts b/app/src/renderer/src/hooks/useChat.ts index 2ed7d5c..5bf1b06 100644 --- a/app/src/renderer/src/hooks/useChat.ts +++ b/app/src/renderer/src/hooks/useChat.ts @@ -3,21 +3,23 @@ import { localServerBaseUrl } from "@renderer/config"; export default () => { const getResponse = async (chat_messages: Array) => { + + const messages = chat_messages + .filter((m) => ['assistant', 'user'].includes(m.role)) // 过滤出 role 为 assistant 和 user 的消息 + .map((m) => { + return { + role: m.role, + content: m.content + } + }) - const messages = chat_messages.map((m) => { - return { - role: m.role, - content: m.content - } - }) - - const response = await fetch(localServerBaseUrl + "/llm", { - method: "POST", - headers: { - "Content-Type": "application/json" - }, - body: JSON.stringify({messages, isStream: false }) - }) + const response = await fetch(localServerBaseUrl + "/llm", { + method: "POST", + headers: { + "Content-Type": "application/json" + }, + body: JSON.stringify({messages, isStream: false }) + }) return response.json() diff --git a/server/route/llm.py b/server/route/llm.py index 0e955d4..d924f46 100644 --- a/server/route/llm.py +++ b/server/route/llm.py @@ -4,43 +4,37 @@ from utils.sql_util import get_config from agent.prompt import code_prompt import json import re + + + home_bp = Blueprint('llm', __name__) @home_bp.route('/llm', methods=["POST"]) def llm(): data = request.get_json() messages = data["messages"] - isStream = data.get("isStream", False) if data.get("llm_config"): config = json.loads(data.get("llm_config")) else: config = json.loads(get_config())["llm"] messages = [{"role": "system", "content": code_prompt.substitute()}] + messages - # 暂时没有strem - if isStream: - def generate(): - response = completion(messages=messages, stream=True, **config) - for part in response: - yield part.choices[0].delta.content or "" - return Response(generate(), mimetype='text/event-stream') - else: - try: - res = completion(messages=messages, **config).choices[0].message.content - return {"content": res, "isExistCode": contains_code(res), "status": 0} - except Exception as e: - return {"content": str(e), "status": 1} + try: + res = completion(messages=messages, **config).choices[0].message.content + return {"content": res, "code": extract_code_blocks(res), "status": 0} + except Exception as e: + return {"content": str(e), "status": 1} -def contains_code(text): - markdown_patterns = [ - r'```.*?```', - r'```[\s\S]*?```' + +def extract_code_blocks(text): + pattern_match = [ + r'.*?```python([\s\S]*?)```.*', + r'.*?```([\s\S]*?)```.*' ] - for pattern in markdown_patterns: - if re.search(pattern, text, re.MULTILINE): - return 0 - return 1 - - -if __name__ == "__main__": - print(contains_code("为了帮助你打开并读取位于桌面上的 `a.txt` 文件,以下是相应的Python代码。请确保根据你的系统环境(如Windows或Mac OS),调整文件路径。\\n\\n```python\\n# 打开并读取桌面上的 a.txt 文件\\ntry:\\n with open('/Users/your_username/Desktop/a.txt', 'r') as file: # 请根据你的系统路径修改文件路径\\n content = file.read() # 读取文件\\n print(content) # 显示文件内容\\nexcept FileNotFoundError:\\n print(\\\"文件没有找到,请确保文件路径正确。\\\")\\nexcept Exception as e:\\n print(\\\"读取文件时发生错误:\\\", e)\\n```\\n\\n请将 `/Users/your_username/Desktop/a.txt` 中的 `your_username` 替换为你的用户名称。如果你是Windows用户,路径可能类似于 `C:\\\\\\\\Users\\\\\\\\your_username\\\\\\\\Desktop\\\\\\\\a.txt`。")) \ No newline at end of file + for pattern in pattern_match: + pattern = re.compile(pattern, re.MULTILINE).findall(text) + if pattern: + return pattern[0] + else: + continue + return "" \ No newline at end of file diff --git a/server/route/test.py b/server/route/test.py new file mode 100644 index 0000000..55fb0a7 --- /dev/null +++ b/server/route/test.py @@ -0,0 +1,15 @@ +import re +def extract_code_blocks(text): + pattern_match = [ + r'.*?```python([\s\S]*?)```.*', + r'.*?```([\s\S]*?)```.*' + ] + for pattern in pattern_match: + pattern = re.compile(pattern, re.MULTILINE).findall(text) + if pattern: + return pattern[0] + else: + continue +if __name__ == "__main__": + text = "为了帮助你打开并读取位于桌面上的 `a.txt` 文件,以下是相应的Python代码。请确保根据你的系统环境(如Windows或Mac OS),调整文件路径。\\n\\n```python\\n# 打开并读取桌面上的 a.txt 文件\\ntry:\\n with open('/Users/your_username/Desktop/a.txt', 'r') as file: # 请根据你的系统路径修改文件路径\\n content = file.read() # 读取文件\\n print(content) # 显示文件内容\\nexcept FileNotFoundError:\\n print(\\\"文件没有找到,请确保文件路径正确。\\\")\\nexcept Exception as e:\\n print(\\\"读取文件时发生错误:\\\", e)\\n```\\n\\n请将 `/Users/your_username/Desktop/a.txt` 中的 `your_username` 替换为你的用户名称。如果你是Windows用户,路径可能类似于 `C:\\\\\\\\Users\\\\\\\\your_username\\\\\\\\Desktop\\\\\\\\a.txt`。" + print(extract_code_blocks(text))