mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 04:57:18 +08:00
优化调用
This commit is contained in:
@@ -32,8 +32,7 @@ class TaskRunAgent(BaseAgent):
|
||||
def __call__(self, task):
|
||||
res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
|
||||
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
|
||||
vlm_response_json = self.extract_data(res, "json")
|
||||
return response_message, vlm_response_json
|
||||
return response_message
|
||||
|
||||
def extract_data(self, input_string, data_type):
|
||||
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks
|
||||
|
||||
@@ -14,6 +14,7 @@ from anthropic import APIResponse
|
||||
from anthropic.types import TextBlock
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
from gradio_ui.loop import (
|
||||
sampling_loop_sync,
|
||||
)
|
||||
@@ -156,7 +157,7 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
|
||||
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
|
||||
|
||||
|
||||
def process_input(user_input, state):
|
||||
def process_input(user_input, state, vision_agent):
|
||||
# Reset the stop flag
|
||||
if state["stop"]:
|
||||
state["stop"] = False
|
||||
@@ -185,7 +186,8 @@ def process_input(user_input, state):
|
||||
only_n_most_recent_images=state["only_n_most_recent_images"],
|
||||
max_tokens=8000,
|
||||
omniparser_url=args.omniparser_server_url,
|
||||
base_url = state["base_url"]
|
||||
base_url = state["base_url"],
|
||||
vision_agent = vision_agent
|
||||
):
|
||||
if loop_msg is None or state.get("stop"):
|
||||
yield state['chatbot_messages']
|
||||
@@ -314,7 +316,9 @@ def run():
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
submit_button.click(process_input, [chat_input, state], chatbot)
|
||||
vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt",
|
||||
caption_model_path="./weights/icon_caption")
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent], chatbot)
|
||||
stop_button.click(stop_app, [state], None)
|
||||
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7888)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
"""
|
||||
Agentic sampling loop that calls the Anthropic API and a local implementation of Anthropic-defined computer use tools.
|
||||
"""
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum
|
||||
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
from gradio_ui.tools.screen_capture import get_screenshot
|
||||
from anthropic import APIResponse
|
||||
from anthropic.types import (
|
||||
TextBlock,
|
||||
)
|
||||
from anthropic.types.beta import (
|
||||
BetaContentBlock,
|
||||
BetaMessage,
|
||||
@@ -16,10 +14,12 @@ from anthropic.types.beta import (
|
||||
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
|
||||
from gradio_ui.agent.task_run_agent import TaskRunAgent
|
||||
from gradio_ui.tools import ToolResult
|
||||
|
||||
from gradio_ui.agent.llm_utils.utils import encode_image
|
||||
from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
|
||||
from gradio_ui.agent.vlm_agent import VLMAgent
|
||||
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
|
||||
from pathlib import Path
|
||||
OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
def sampling_loop_sync(
|
||||
*,
|
||||
@@ -32,13 +32,14 @@ def sampling_loop_sync(
|
||||
only_n_most_recent_images: int | None = 0,
|
||||
max_tokens: int = 4096,
|
||||
omniparser_url: str,
|
||||
base_url: str
|
||||
base_url: str,
|
||||
vision_agent: VisionAgent
|
||||
):
|
||||
"""
|
||||
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
|
||||
"""
|
||||
print('in sampling_loop_sync, model:', model)
|
||||
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
|
||||
# omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
|
||||
task_plan_agent = TaskPlanAgent()
|
||||
# actor = VLMAgent(
|
||||
# model=model,
|
||||
@@ -57,14 +58,22 @@ def sampling_loop_sync(
|
||||
tool_result_content = None
|
||||
|
||||
print(f"Start the message loop. User messages: {messages}")
|
||||
plan = task_plan_agent(user_task = messages[-1]["content"][0])
|
||||
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
|
||||
task_run_agent = TaskRunAgent()
|
||||
|
||||
|
||||
while True:
|
||||
parsed_screen = omniparser_client()
|
||||
parsed_screen = parse_screen(vision_agent)
|
||||
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
||||
task_run_agent(plan, parsed_screen)
|
||||
tools_use_needed = task_run_agent(plan, parsed_screen)
|
||||
for message, tool_result_content in executor(tools_use_needed, messages):
|
||||
yield message
|
||||
if not tool_result_content:
|
||||
return messages
|
||||
return messages
|
||||
|
||||
def parse_screen(vision_agent: VisionAgent):
    """Take a screenshot and return the vision agent's result for it.

    Args:
        vision_agent: Callable agent that is invoked with the output of
            ``encode_image`` for the freshly captured screenshot.

    Returns:
        Whatever ``vision_agent`` returns for the encoded screenshot.
    """
    # get_screenshot() yields a pair; only the second item (the path of the
    # captured image) is needed here.
    _, screenshot_path = get_screenshot()
    # encode_image is given a plain string path rather than a path object.
    image_base64 = encode_image(str(screenshot_path))
    return vision_agent(image_base64)
|
||||
|
||||
64
main.py
64
main.py
@@ -1,11 +1,6 @@
|
||||
import subprocess
|
||||
from threading import Thread
|
||||
import time
|
||||
import requests
|
||||
from gradio_ui import app
|
||||
from util import download_weights
|
||||
import torch
|
||||
import socket
|
||||
|
||||
def run():
|
||||
try:
|
||||
@@ -16,63 +11,10 @@ def run():
|
||||
except Exception:
|
||||
print("显卡驱动不适配,请根据readme安装合适版本的 torch!")
|
||||
|
||||
|
||||
server_process = subprocess.Popen(
|
||||
["python", "./omniserver.py"],
|
||||
stdout=subprocess.PIPE, # 捕获标准输出
|
||||
stderr=subprocess.PIPE,
|
||||
text=True
|
||||
)
|
||||
|
||||
stdout_thread = Thread(
|
||||
target=stream_reader,
|
||||
args=(server_process.stdout, "SERVER-OUT")
|
||||
)
|
||||
|
||||
stderr_thread = Thread(
|
||||
target=stream_reader,
|
||||
args=(server_process.stderr, "SERVER-ERR")
|
||||
)
|
||||
stdout_thread.daemon = True
|
||||
stderr_thread.daemon = True
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
# 下载权重文件
|
||||
download_weights.download()
|
||||
app.run()
|
||||
|
||||
|
||||
try:
|
||||
# 下载权重文件
|
||||
download_weights.download()
|
||||
print("启动Omniserver服务中,因为加载模型真的超级慢,请耐心等待!")
|
||||
while True:
|
||||
try:
|
||||
res = requests.get("http://127.0.0.1:8000/probe/", timeout=5)
|
||||
if res.status_code == 200:
|
||||
print("Omniparser服务启动成功...")
|
||||
break
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
pass
|
||||
if server_process.poll() is not None:
|
||||
raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}")
|
||||
print("等待服务启动...")
|
||||
time.sleep(10)
|
||||
|
||||
app.run()
|
||||
finally:
|
||||
if server_process.poll() is None: # 如果进程还在运行
|
||||
server_process.terminate() # 发送终止信号
|
||||
server_process.wait(timeout=8) # 等待进程结束
|
||||
|
||||
def stream_reader(pipe, prefix):
    """Echo every line read from *pipe* to stdout, tagged with *prefix*.

    Args:
        pipe: Iterable of text lines, e.g. the stdout/stderr stream of a
            subprocess opened in text mode.
        prefix: Label printed in square brackets before each line.
    """
    for line in pipe:
        # Each line already carries its own newline, so print's default
        # terminator is suppressed; flush so output shows up immediately.
        print(f"[{prefix}]", line, end="", flush=True)
|
||||
|
||||
def is_port_occupied(port):
    """Return True when a TCP connection to ``localhost:port`` succeeds.

    Args:
        port: TCP port number to probe on the local machine.

    Returns:
        ``True`` if something accepted the connection, ``False`` otherwise.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # connect_ex reports failure through its return code instead of
        # raising; 0 means the connection was established.
        return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 检测8000端口是否被占用
|
||||
if is_port_occupied(8000):
|
||||
print("8000端口被占用,请先关闭占用该端口的进程")
|
||||
exit()
|
||||
run()
|
||||
@@ -12,7 +12,7 @@ import uvicorn
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(root_dir)
|
||||
# from util.omniparser import Omniparser
|
||||
from util.vision_agent import VisionAgent
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Omniparser API')
|
||||
|
||||
Reference in New Issue
Block a user