更新搜索引擎

This commit is contained in:
yuruo
2024-02-21 10:05:48 +08:00
parent 27e0b7e31c
commit 434224aa28
6 changed files with 136 additions and 73 deletions

47
main.py
View File

@@ -1,5 +1,11 @@
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, MessagesPlaceholder, PromptTemplate, \
HumanMessagePromptTemplate
from agent.manager_agent import ManagerAgent
from agent.plan_agent import PlanAgent
from tools.search_engine_tool import SearchEngineTool
from utils.llm_util import LLMUtil
from work_principle.okr_principle import OKR_Object
import logging
@@ -11,21 +17,38 @@ class AutoMate:
def __init__(self):
pass
def rule_define(self):
# 与用户对齐任务
while True:
o_kr = OKR_Object(
"因为想要增加编程效率对比一下copilot和curson谁更好用比较提示词数量、安装易用性给出不少于100字的文章")
ManagerAgent().optimization_Object(o_kr)
r = input(f"最终对齐的任务是:{o_kr.raw_user_task}一切都OK对吧y/n\n")
if r == "y":
break
# def rule_define(self):
# # 与用户对齐任务
# while True:
# o_kr = OKR_Object(
# "因为想要增加编程效率对比一下copilot和curson谁更好用比较提示词数量、安装易用性给出不少于100字的文章")
# ManagerAgent().optimization_Object(o_kr)
# r = input(f"最终对齐的任务是:{o_kr.raw_user_task}一切都OK对吧y/n\n")
# if r == "y":
# break
#
# # 让计划拆解者拆解任务
# PlanAgent().aligning(o_kr)
# 让计划拆解者拆解任务
PlanAgent().aligning(o_kr)
def run(self):
rompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate(
prompt=PromptTemplate(input_variables=[], template='你是一个工作助手')),
MessagesPlaceholder(variable_name='chat_history', optional=True),
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')),
MessagesPlaceholder(variable_name='agent_scratchpad')
]
)
model = LLMUtil().llm()
tools = [SearchEngineTool()]
agent = create_openai_functions_agent(model, tools, rompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
r = input("请输入你的问题:\n")
agent_executor.invoke({"input": r})
if __name__ == "__main__":
automator = AutoMate()
automator.rule_define()
automator.run()
# print(automator.call_chatgpt_api("Hello"))

16
tests/test_baidu_api.py Normal file
View File

@@ -0,0 +1,16 @@
from selenium.webdriver.common.by import By
from utils.selenium_util import SeleniumUtil
class TestBaiduApi:
def test_api(self):
selenium = SeleniumUtil()
selenium.get_url("https://www.baidu.com/s?wd=搜索引擎 api 汇总")
result_elements = selenium.get_xpath_elements("//*[@class='result c-container xpath-log new-pmd']")
for result_element in result_elements:
title = result_element.find_element(By.XPATH, ".//h3").text
url = result_element.find_element(By.XPATH, ".//h3/a").get_attribute("href")
short_description = result_element.find_element(By.XPATH, ".//*/span[@class='content-right_8Zs40']").text
print(result_element)

View File

@@ -1,5 +1,6 @@
import logging
import unittest
from idlelib.searchengine import SearchEngine
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_core.messages import HumanMessage
@@ -8,17 +9,18 @@ from langchain_core.prompts import SystemMessagePromptTemplate, PromptTemplate,
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from tools.web_browser_tool import WebBrowserTool
from tools.search_engine_tool import SearchEngineTool
from utils.llm_util import LLMUtil
class TestWebBrowser:
def test_web_browser(self):
model = LLMUtil().llm()
tools = [WebBrowserTool()]
tools = [SearchEngine()]
model_with_functions = model.bind_functions(tools)
s = model_with_functions.invoke([HumanMessage(
content="帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1")])
content="帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A"
"%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1")])
print(s.additional_kwargs["function_call"])
def test_agent(self):
@@ -32,8 +34,7 @@ class TestWebBrowser:
]
)
model = LLMUtil().llm()
tools = [WebBrowserTool()]
tools = [SearchEngineTool()]
agent = create_openai_functions_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
# agent_executor.invoke({"input": "你好!你是谁"})
agent_executor.invoke({"input": "帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1"})
agent_executor.invoke({"input": "李一舟为什么能成功"})

View File

@@ -0,0 +1,36 @@
from typing import Optional, Type, Any
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
from langchain.pydantic_v1 import BaseModel, Field
from selenium.webdriver.common.by import By
from tools.tool_base import ToolBase
from utils.selenium_util import SeleniumUtil
class SearchInput(BaseModel):
key: str = Field(description="要查询的关键词")
# 利用搜索引擎搜索关键词
class SearchEngineTool(ToolBase):
name = "web_browser"
description = "利用搜索引擎搜索关键词,得到结果列表"
args_schema: Type[BaseModel] = SearchInput
def _run(self, key: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> list[dict[str, Any]]:
"""Use the tool."""
selenium = SeleniumUtil()
selenium.get_url(f"https://www.baidu.com/s?wd={key}")
result_elements = selenium.get_xpath_elements("//*[@class='result c-container xpath-log new-pmd']")
search_result = []
for result_element in result_elements:
title = result_element.find_element(By.XPATH, ".//h3").text
url = result_element.find_element(By.XPATH, ".//h3/a").get_attribute("href")
short_description = result_element.find_element(By.XPATH, ".//*/span[@class='content-right_8Zs40']").text
search_result.append({"title": title, "url": url, "short_description": short_description})
return search_result
async def _arun(
self, key: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")

View File

@@ -8,35 +8,27 @@ from selenium.webdriver.chrome.service import Service as ChromeService
from tools.tool_base import ToolBase
class WebBrowserUrl(ToolBase):
class WebBrowserUrl():
def __init__(self):
super().__init__()
self.name = "web_browser"
self.description = "利用selenium对指定URL进行访问"
self.request_param = '字典,如{"usrl": ""}'
self.return_content = ('{"driver": "selenium的webdriverdriver用于继续在此见面上进行操作可作为web_element工具的入参", "content": '
'"网页xml结构"')
def run(self, param=None):
def run(self, url):
# Load browser configuration from YAML file
driver = None
# Check if webdriver is available
if not os.path.exists("webdriver.exe"):
# Download webdriver based on browser type
browser_type = self.config.BROWSER.get("browser_type")
if browser_type == "chrome":
options = webdriver.ChromeOptions()
options.add_argument("--headless") # Enable headless mode
webdriver_manager = ChromeDriverManager()
driver = webdriver.Chrome(service=ChromeService(webdriver_manager.install()), options=options)
elif browser_type == "edge":
options = webdriver.EdgeOptions()
options.add_argument("--headless") # Enable headless mode
webdriver_manager = EdgeChromiumDriverManager()
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
else:
return
# Download webdriver based on browser type
browser_type = self.config.BROWSER.get("browser_type")
if browser_type == "chrome":
options = webdriver.ChromeOptions()
options.add_argument("--headless") # Enable headless mode
webdriver_manager = ChromeDriverManager()
driver = webdriver.Chrome(service=ChromeService(webdriver_manager.install()), options=options)
elif browser_type == "edge":
options = webdriver.EdgeOptions()
options.add_argument("--headless") # Enable headless mode
webdriver_manager = EdgeChromiumDriverManager()
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
else:
return
driver.implicitly_wait(10)
driver.get(param["url"])
driver.quit()
return driver.page_source
return driver

View File

@@ -1,33 +1,17 @@
import logging
import os
from typing import Optional, Type, Any
from langchain_core.tools import BaseTool
from selenium.webdriver.chrome.service import Service as ChromeService
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
from langchain.pydantic_v1 import BaseModel, Field
from selenium import webdriver
from selenium.webdriver.common.by import By
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager
from selenium.webdriver.edge.service import Service as EdgeService
from selenium.webdriver.chrome.service import Service as ChromeService
from tools.tool_base import ToolBase
from utils.config import Config
class SearchInput(BaseModel):
url: str = Field(description="要查询的网址")
class WebBrowserTool(ToolBase):
name = "web_browser"
description = "利用浏览器访问url得到网页源码"
args_schema: Type[BaseModel] = SearchInput
def _run(self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
"""Use the tool."""
class SeleniumUtil:
def __init__(self):
config = Config()
# Download webdriver based on browser type
browser_type = config.BROWSER.get("browser_type")
if browser_type == "chrome":
options = webdriver.ChromeOptions()
@@ -40,16 +24,27 @@ class WebBrowserTool(ToolBase):
webdriver_manager = EdgeChromiumDriverManager()
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
else:
return ""
return
driver.implicitly_wait(10)
driver.get(url)
driver.quit()
print(f"res {driver.page_source}")
self.driver = driver
return driver.page_source
def get_url(self, url):
self.driver.get(url)
async def _arun(
self, url: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")
def click(self, xpath):
self.driver.find_element(By.XPATH, xpath).click()
def send(self, xpath, text):
self.driver.find_element(By.XPATH, xpath).send_keys(text)
def quit(self):
self.driver.quit()
def get_text(self, xpath):
return self.driver.find_element(xpath).text
def get_attribute(self, xpath, name):
return self.driver.find_element(xpath).get_attribute(name)
def get_xpath_elements(self, xpath):
return self.driver.find_elements(By.XPATH, xpath)