mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 13:07:17 +08:00
更新搜索引擎
This commit is contained in:
47
main.py
47
main.py
@@ -1,5 +1,11 @@
|
||||
from langchain.agents import create_openai_functions_agent, AgentExecutor
|
||||
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, MessagesPlaceholder, PromptTemplate, \
|
||||
HumanMessagePromptTemplate
|
||||
|
||||
from agent.manager_agent import ManagerAgent
|
||||
from agent.plan_agent import PlanAgent
|
||||
from tools.search_engine_tool import SearchEngineTool
|
||||
from utils.llm_util import LLMUtil
|
||||
from work_principle.okr_principle import OKR_Object
|
||||
import logging
|
||||
|
||||
@@ -11,21 +17,38 @@ class AutoMate:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def rule_define(self):
|
||||
# 与用户对齐任务
|
||||
while True:
|
||||
o_kr = OKR_Object(
|
||||
"因为想要增加编程效率,对比一下copilot和curson谁更好用,比较提示词数量、安装易用性,给出不少于100字的文章")
|
||||
ManagerAgent().optimization_Object(o_kr)
|
||||
r = input(f"最终对齐的任务是:{o_kr.raw_user_task},一切都OK对吧?y/n\n")
|
||||
if r == "y":
|
||||
break
|
||||
# def rule_define(self):
|
||||
# # 与用户对齐任务
|
||||
# while True:
|
||||
# o_kr = OKR_Object(
|
||||
# "因为想要增加编程效率,对比一下copilot和curson谁更好用,比较提示词数量、安装易用性,给出不少于100字的文章")
|
||||
# ManagerAgent().optimization_Object(o_kr)
|
||||
# r = input(f"最终对齐的任务是:{o_kr.raw_user_task},一切都OK对吧?y/n\n")
|
||||
# if r == "y":
|
||||
# break
|
||||
#
|
||||
# # 让计划拆解者拆解任务
|
||||
# PlanAgent().aligning(o_kr)
|
||||
|
||||
# 让计划拆解者拆解任务
|
||||
PlanAgent().aligning(o_kr)
|
||||
def run(self):
|
||||
rompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate(
|
||||
prompt=PromptTemplate(input_variables=[], template='你是一个工作助手')),
|
||||
MessagesPlaceholder(variable_name='chat_history', optional=True),
|
||||
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')),
|
||||
MessagesPlaceholder(variable_name='agent_scratchpad')
|
||||
]
|
||||
)
|
||||
model = LLMUtil().llm()
|
||||
tools = [SearchEngineTool()]
|
||||
agent = create_openai_functions_agent(model, tools, rompt)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
|
||||
r = input("请输入你的问题:\n")
|
||||
agent_executor.invoke({"input": r})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
automator = AutoMate()
|
||||
automator.rule_define()
|
||||
automator.run()
|
||||
# print(automator.call_chatgpt_api("Hello"))
|
||||
|
||||
16
tests/test_baidu_api.py
Normal file
16
tests/test_baidu_api.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from utils.selenium_util import SeleniumUtil
|
||||
|
||||
|
||||
class TestBaiduApi:
|
||||
def test_api(self):
|
||||
selenium = SeleniumUtil()
|
||||
selenium.get_url("https://www.baidu.com/s?wd=搜索引擎 api 汇总")
|
||||
result_elements = selenium.get_xpath_elements("//*[@class='result c-container xpath-log new-pmd']")
|
||||
for result_element in result_elements:
|
||||
title = result_element.find_element(By.XPATH, ".//h3").text
|
||||
url = result_element.find_element(By.XPATH, ".//h3/a").get_attribute("href")
|
||||
short_description = result_element.find_element(By.XPATH, ".//*/span[@class='content-right_8Zs40']").text
|
||||
print(result_element)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import unittest
|
||||
from idlelib.searchengine import SearchEngine
|
||||
|
||||
from langchain.agents import create_openai_functions_agent, AgentExecutor
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -8,17 +9,18 @@ from langchain_core.prompts import SystemMessagePromptTemplate, PromptTemplate,
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from tools.web_browser_tool import WebBrowserTool
|
||||
from tools.search_engine_tool import SearchEngineTool
|
||||
from utils.llm_util import LLMUtil
|
||||
|
||||
|
||||
class TestWebBrowser:
|
||||
def test_web_browser(self):
|
||||
model = LLMUtil().llm()
|
||||
tools = [WebBrowserTool()]
|
||||
tools = [SearchEngine()]
|
||||
model_with_functions = model.bind_functions(tools)
|
||||
s = model_with_functions.invoke([HumanMessage(
|
||||
content="帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1")])
|
||||
content="帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A"
|
||||
"%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1")])
|
||||
print(s.additional_kwargs["function_call"])
|
||||
|
||||
def test_agent(self):
|
||||
@@ -32,8 +34,7 @@ class TestWebBrowser:
|
||||
]
|
||||
)
|
||||
model = LLMUtil().llm()
|
||||
tools = [WebBrowserTool()]
|
||||
tools = [SearchEngineTool()]
|
||||
agent = create_openai_functions_agent(model, tools, prompt)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
|
||||
# agent_executor.invoke({"input": "你好!你是谁"})
|
||||
agent_executor.invoke({"input": "帮我查询一下这个网页的内容 https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A%22news_9510051560337988929%22%7D&n_type=-1&p_from=-1"})
|
||||
agent_executor.invoke({"input": "李一舟为什么能成功"})
|
||||
|
||||
36
tools/search_engine_tool.py
Normal file
36
tools/search_engine_tool.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, Type, Any
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from selenium.webdriver.common.by import By
|
||||
from tools.tool_base import ToolBase
|
||||
from utils.selenium_util import SeleniumUtil
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
key: str = Field(description="要查询的关键词")
|
||||
|
||||
|
||||
# 利用搜索引擎搜索关键词
|
||||
class SearchEngineTool(ToolBase):
|
||||
name = "web_browser"
|
||||
description = "利用搜索引擎搜索关键词,得到结果列表"
|
||||
args_schema: Type[BaseModel] = SearchInput
|
||||
|
||||
def _run(self, key: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> list[dict[str, Any]]:
|
||||
"""Use the tool."""
|
||||
selenium = SeleniumUtil()
|
||||
selenium.get_url(f"https://www.baidu.com/s?wd={key}")
|
||||
result_elements = selenium.get_xpath_elements("//*[@class='result c-container xpath-log new-pmd']")
|
||||
search_result = []
|
||||
for result_element in result_elements:
|
||||
title = result_element.find_element(By.XPATH, ".//h3").text
|
||||
url = result_element.find_element(By.XPATH, ".//h3/a").get_attribute("href")
|
||||
short_description = result_element.find_element(By.XPATH, ".//*/span[@class='content-right_8Zs40']").text
|
||||
search_result.append({"title": title, "url": url, "short_description": short_description})
|
||||
return search_result
|
||||
|
||||
async def _arun(
|
||||
self, key: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("custom_search does not support async")
|
||||
@@ -8,35 +8,27 @@ from selenium.webdriver.chrome.service import Service as ChromeService
|
||||
from tools.tool_base import ToolBase
|
||||
|
||||
|
||||
class WebBrowserUrl(ToolBase):
|
||||
class WebBrowserUrl():
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "web_browser"
|
||||
self.description = "利用selenium对指定URL进行访问"
|
||||
self.request_param = '字典,如{"usrl": ""}'
|
||||
self.return_content = ('{"driver": "selenium的webdriver,driver用于继续在此见面上进行操作,可作为web_element工具的入参", "content": '
|
||||
'"网页xml结构"')
|
||||
|
||||
def run(self, param=None):
|
||||
def run(self, url):
|
||||
# Load browser configuration from YAML file
|
||||
driver = None
|
||||
# Check if webdriver is available
|
||||
if not os.path.exists("webdriver.exe"):
|
||||
# Download webdriver based on browser type
|
||||
browser_type = self.config.BROWSER.get("browser_type")
|
||||
if browser_type == "chrome":
|
||||
options = webdriver.ChromeOptions()
|
||||
options.add_argument("--headless") # Enable headless mode
|
||||
webdriver_manager = ChromeDriverManager()
|
||||
driver = webdriver.Chrome(service=ChromeService(webdriver_manager.install()), options=options)
|
||||
elif browser_type == "edge":
|
||||
options = webdriver.EdgeOptions()
|
||||
options.add_argument("--headless") # Enable headless mode
|
||||
webdriver_manager = EdgeChromiumDriverManager()
|
||||
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
|
||||
else:
|
||||
return
|
||||
# Download webdriver based on browser type
|
||||
browser_type = self.config.BROWSER.get("browser_type")
|
||||
if browser_type == "chrome":
|
||||
options = webdriver.ChromeOptions()
|
||||
options.add_argument("--headless") # Enable headless mode
|
||||
webdriver_manager = ChromeDriverManager()
|
||||
driver = webdriver.Chrome(service=ChromeService(webdriver_manager.install()), options=options)
|
||||
elif browser_type == "edge":
|
||||
options = webdriver.EdgeOptions()
|
||||
options.add_argument("--headless") # Enable headless mode
|
||||
webdriver_manager = EdgeChromiumDriverManager()
|
||||
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
|
||||
else:
|
||||
return
|
||||
driver.implicitly_wait(10)
|
||||
driver.get(param["url"])
|
||||
driver.quit()
|
||||
return driver.page_source
|
||||
return driver
|
||||
|
||||
@@ -1,33 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Type, Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from selenium.webdriver.chrome.service import Service as ChromeService
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
from webdriver_manager.microsoft import EdgeChromiumDriverManager
|
||||
from selenium.webdriver.edge.service import Service as EdgeService
|
||||
from selenium.webdriver.chrome.service import Service as ChromeService
|
||||
|
||||
from tools.tool_base import ToolBase
|
||||
from utils.config import Config
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
url: str = Field(description="要查询的网址")
|
||||
|
||||
|
||||
class WebBrowserTool(ToolBase):
|
||||
name = "web_browser"
|
||||
description = "利用浏览器访问url,得到网页源码"
|
||||
args_schema: Type[BaseModel] = SearchInput
|
||||
|
||||
def _run(self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
|
||||
"""Use the tool."""
|
||||
class SeleniumUtil:
|
||||
def __init__(self):
|
||||
config = Config()
|
||||
# Download webdriver based on browser type
|
||||
browser_type = config.BROWSER.get("browser_type")
|
||||
if browser_type == "chrome":
|
||||
options = webdriver.ChromeOptions()
|
||||
@@ -40,16 +24,27 @@ class WebBrowserTool(ToolBase):
|
||||
webdriver_manager = EdgeChromiumDriverManager()
|
||||
driver = webdriver.Edge(service=EdgeService(webdriver_manager.install()), options=options)
|
||||
else:
|
||||
return ""
|
||||
return
|
||||
driver.implicitly_wait(10)
|
||||
driver.get(url)
|
||||
driver.quit()
|
||||
print(f"res {driver.page_source}")
|
||||
self.driver = driver
|
||||
|
||||
return driver.page_source
|
||||
def get_url(self, url):
|
||||
self.driver.get(url)
|
||||
|
||||
async def _arun(
|
||||
self, url: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("custom_search does not support async")
|
||||
def click(self, xpath):
|
||||
self.driver.find_element(By.XPATH, xpath).click()
|
||||
|
||||
def send(self, xpath, text):
|
||||
self.driver.find_element(By.XPATH, xpath).send_keys(text)
|
||||
|
||||
def quit(self):
|
||||
self.driver.quit()
|
||||
|
||||
def get_text(self, xpath):
|
||||
return self.driver.find_element(xpath).text
|
||||
|
||||
def get_attribute(self, xpath, name):
|
||||
return self.driver.find_element(xpath).get_attribute(name)
|
||||
|
||||
def get_xpath_elements(self, xpath):
|
||||
return self.driver.find_elements(By.XPATH, xpath)
|
||||
Reference in New Issue
Block a user