发现代码时,增加两个按键;

server端完成代码提取。
This commit is contained in:
yuruo
2024-07-12 11:31:08 +08:00
parent 0c66c44982
commit c7bbfe98c5
4 changed files with 61 additions and 60 deletions

View File

@@ -26,7 +26,6 @@ export default function Chat(props: {id: number, revalidator: () => void, search
<ProChat
chats={chatMessages}
onChatsChange={(chat)=>{
console.log('chat', chat)
setMessages(chat)
}}
chatItemRenderConfig={{
@@ -62,29 +61,20 @@ export default function Chat(props: {id: number, revalidator: () => void, search
}
request={async (messages) => {
// const response = await getResponse(messages)
// if (response.isExistCode === 0) {
// setTimeout(() => {
// proChatRef.current?.sendMessage({
// type: 'text',
// content: response.content,
// role: 'coder',
// originData: response
// });
// }, 1000); // 延时1秒推送消息
// }
// return new Response(response.content)// 支持流式和非流式
setTimeout(() => {
proChatRef.current?.pushChat({
const response = await getResponse(messages)
console.log(response)
if (response.isExistCode === 0) {
setTimeout(() => {
proChatRef.current?.pushChat({
type: 'text',
content: "hello",
content: response.content,
role: 'coder',
originData: "dd"
originData: response
});
}, 1000); // 延时1秒推送消息
return new Response("hello")
}
return new Response(response.content)
}}
/>
)

View File

@@ -3,21 +3,23 @@ import { localServerBaseUrl } from "@renderer/config";
export default () => {
const getResponse = async (chat_messages: Array<any>) => {
const messages = chat_messages
.filter((m) => ['assistant', 'user'].includes(m.role)) // 过滤出 role 为 assistant 和 user 的消息
.map((m) => {
return {
role: m.role,
content: m.content
}
})
const messages = chat_messages.map((m) => {
return {
role: m.role,
content: m.content
}
})
const response = await fetch(localServerBaseUrl + "/llm", {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({messages, isStream: false })
})
const response = await fetch(localServerBaseUrl + "/llm", {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({messages, isStream: false })
})
return response.json()

View File

@@ -4,43 +4,37 @@ from utils.sql_util import get_config
from agent.prompt import code_prompt
import json
import re
home_bp = Blueprint('llm', __name__)
@home_bp.route('/llm', methods=["POST"])
def llm():
data = request.get_json()
messages = data["messages"]
isStream = data.get("isStream", False)
if data.get("llm_config"):
config = json.loads(data.get("llm_config"))
else:
config = json.loads(get_config())["llm"]
messages = [{"role": "system", "content": code_prompt.substitute()}] + messages
# 暂时没有strem
if isStream:
def generate():
response = completion(messages=messages, stream=True, **config)
for part in response:
yield part.choices[0].delta.content or ""
return Response(generate(), mimetype='text/event-stream')
else:
try:
res = completion(messages=messages, **config).choices[0].message.content
return {"content": res, "isExistCode": contains_code(res), "status": 0}
except Exception as e:
return {"content": str(e), "status": 1}
try:
res = completion(messages=messages, **config).choices[0].message.content
return {"content": res, "code": extract_code_blocks(res), "status": 0}
except Exception as e:
return {"content": str(e), "status": 1}
def contains_code(text):
markdown_patterns = [
r'```.*?```',
r'```[\s\S]*?```'
def extract_code_blocks(text):
pattern_match = [
r'.*?```python([\s\S]*?)```.*',
r'.*?```([\s\S]*?)```.*'
]
for pattern in markdown_patterns:
if re.search(pattern, text, re.MULTILINE):
return 0
return 1
if __name__ == "__main__":
print(contains_code("为了帮助你打开并读取位于桌面上的 `a.txt` 文件以下是相应的Python代码。请确保根据你的系统环境如Windows或Mac OS调整文件路径。\\n\\n```python\\n# 打开并读取桌面上的 a.txt 文件\\ntry:\\n with open('/Users/your_username/Desktop/a.txt', 'r') as file: # 请根据你的系统路径修改文件路径\\n content = file.read() # 读取文件\\n print(content) # 显示文件内容\\nexcept FileNotFoundError:\\n print(\\\"文件没有找到,请确保文件路径正确。\\\")\\nexcept Exception as e:\\n print(\\\"读取文件时发生错误:\\\", e)\\n```\\n\\n请将 `/Users/your_username/Desktop/a.txt` 中的 `your_username` 替换为你的用户名称。如果你是Windows用户路径可能类似于 `C:\\\\\\\\Users\\\\\\\\your_username\\\\\\\\Desktop\\\\\\\\a.txt`。"))
for pattern in pattern_match:
pattern = re.compile(pattern, re.MULTILINE).findall(text)
if pattern:
return pattern[0]
else:
continue
return ""

15
server/route/test.py Normal file
View File

@@ -0,0 +1,15 @@
import re
def extract_code_blocks(text):
pattern_match = [
r'.*?```python([\s\S]*?)```.*',
r'.*?```([\s\S]*?)```.*'
]
for pattern in pattern_match:
pattern = re.compile(pattern, re.MULTILINE).findall(text)
if pattern:
return pattern[0]
else:
continue
if __name__ == "__main__":
text = "为了帮助你打开并读取位于桌面上的 `a.txt` 文件以下是相应的Python代码。请确保根据你的系统环境如Windows或Mac OS调整文件路径。\\n\\n```python\\n# 打开并读取桌面上的 a.txt 文件\\ntry:\\n with open('/Users/your_username/Desktop/a.txt', 'r') as file: # 请根据你的系统路径修改文件路径\\n content = file.read() # 读取文件\\n print(content) # 显示文件内容\\nexcept FileNotFoundError:\\n print(\\\"文件没有找到,请确保文件路径正确。\\\")\\nexcept Exception as e:\\n print(\\\"读取文件时发生错误:\\\", e)\\n```\\n\\n请将 `/Users/your_username/Desktop/a.txt` 中的 `your_username` 替换为你的用户名称。如果你是Windows用户路径可能类似于 `C:\\\\\\\\Users\\\\\\\\your_username\\\\\\\\Desktop\\\\\\\\a.txt`。"
print(extract_code_blocks(text))