优化调用

This commit is contained in:
yuruo
2025-03-11 19:09:40 +08:00
parent 49cf1dfb6f
commit 30b97e53b1
6 changed files with 33 additions and 79 deletions

View File

@@ -32,8 +32,7 @@ class TaskRunAgent(BaseAgent):
def __call__(self, task):
res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
vlm_response_json = self.extract_data(res, "json")
return response_message, vlm_response_json
return response_message
def extract_data(self, input_string, data_type):
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks

View File

@@ -14,6 +14,7 @@ from anthropic import APIResponse
from anthropic.types import TextBlock
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
from anthropic.types.tool_use_block import ToolUseBlock
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.loop import (
sampling_loop_sync,
)
@@ -156,7 +157,7 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
def process_input(user_input, state):
def process_input(user_input, state, vision_agent):
# Reset the stop flag
if state["stop"]:
state["stop"] = False
@@ -185,7 +186,8 @@ def process_input(user_input, state):
only_n_most_recent_images=state["only_n_most_recent_images"],
max_tokens=8000,
omniparser_url=args.omniparser_server_url,
base_url = state["base_url"]
base_url = state["base_url"],
vision_agent = vision_agent
):
if loop_msg is None or state.get("stop"):
yield state['chatbot_messages']
@@ -314,7 +316,9 @@ def run():
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
submit_button.click(process_input, [chat_input, state], chatbot)
vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt",
caption_model_path="./weights/icon_caption")
submit_button.click(process_input, [chat_input, state, vision_agent], chatbot)
stop_button.click(stop_app, [state], None)
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
demo.launch(server_name="0.0.0.0", server_port=7888)

View File

@@ -1,13 +1,11 @@
"""
Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools.
"""
import base64
from collections.abc import Callable
from enum import StrEnum
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.tools.screen_capture import get_screenshot
from anthropic import APIResponse
from anthropic.types import (
TextBlock,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaMessage,
@@ -16,10 +14,12 @@ from anthropic.types.beta import (
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
from gradio_ui.agent.task_run_agent import TaskRunAgent
from gradio_ui.tools import ToolResult
from gradio_ui.agent.llm_utils.utils import encode_image
from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
from gradio_ui.agent.vlm_agent import VLMAgent
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
from pathlib import Path
OUTPUT_DIR = "./tmp/outputs"
def sampling_loop_sync(
*,
@@ -32,13 +32,14 @@ def sampling_loop_sync(
only_n_most_recent_images: int | None = 0,
max_tokens: int = 4096,
omniparser_url: str,
base_url: str
base_url: str,
vision_agent: VisionAgent
):
"""
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
"""
print('in sampling_loop_sync, model:', model)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
# omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
task_plan_agent = TaskPlanAgent()
# actor = VLMAgent(
# model=model,
@@ -57,14 +58,22 @@ def sampling_loop_sync(
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
plan = task_plan_agent(user_task = messages[-1]["content"][0])
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
task_run_agent = TaskRunAgent()
while True:
parsed_screen = omniparser_client()
parsed_screen = parse_screen(vision_agent)
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
task_run_agent(plan, parsed_screen)
tools_use_needed = task_run_agent(plan, parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content:
return messages
return messages
def parse_screen(vision_agent: VisionAgent):
    """Capture the current screen and run it through the vision agent.

    Takes a screenshot, base64-encodes the image file, and hands the
    encoded image to *vision_agent*.

    Args:
        vision_agent: callable agent that accepts a base64-encoded image
            string and returns the parsed screen representation.
            # NOTE(review): the exact structure of the returned
            # `parsed_screen` is defined by VisionAgent — not visible here.

    Returns:
        Whatever `vision_agent(image_base64)` produces for the screenshot.
    """
    # get_screenshot() returns a (image, path) pair; only the path is used.
    _, screenshot_path = get_screenshot()
    # Path may be a pathlib.Path — normalize to str for encode_image.
    screenshot_path = str(screenshot_path)
    image_base64 = encode_image(screenshot_path)
    parsed_screen = vision_agent(image_base64)
    return parsed_screen

64
main.py
View File

@@ -1,11 +1,6 @@
import subprocess
from threading import Thread
import time
import requests
from gradio_ui import app
from util import download_weights
import torch
import socket
def run():
try:
@@ -16,63 +11,10 @@ def run():
except Exception:
print("显卡驱动不适配请根据readme安装合适版本的 torch")
server_process = subprocess.Popen(
["python", "./omniserver.py"],
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE,
text=True
)
stdout_thread = Thread(
target=stream_reader,
args=(server_process.stdout, "SERVER-OUT")
)
stderr_thread = Thread(
target=stream_reader,
args=(server_process.stderr, "SERVER-ERR")
)
stdout_thread.daemon = True
stderr_thread.daemon = True
stdout_thread.start()
stderr_thread.start()
# 下载权重文件
download_weights.download()
app.run()
try:
# 下载权重文件
download_weights.download()
print("启动Omniserver服务中因为加载模型真的超级慢请耐心等待")
while True:
try:
res = requests.get("http://127.0.0.1:8000/probe/", timeout=5)
if res.status_code == 200:
print("Omniparser服务启动成功...")
break
except (requests.ConnectionError, requests.Timeout):
pass
if server_process.poll() is not None:
raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}")
print("等待服务启动...")
time.sleep(10)
app.run()
finally:
if server_process.poll() is None: # 如果进程还在运行
server_process.terminate() # 发送终止信号
server_process.wait(timeout=8) # 等待进程结束
def stream_reader(pipe, prefix):
    """Echo every line read from *pipe* to stdout, tagged with ``[prefix]``.

    Intended to run on a daemon thread draining a subprocess pipe so the
    child's output appears interleaved with ours in real time.
    """
    tag = f"[{prefix}]"
    for chunk in pipe:
        # Lines already carry their newline; flush so output is not buffered.
        print(tag, chunk, end="", flush=True)
def is_port_occupied(port):
    """Return True when a TCP connection to ``localhost:port`` succeeds.

    A successful ``connect_ex`` (return code 0) means something is already
    listening on that port.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        result = probe.connect_ex(('localhost', port))
    finally:
        probe.close()
    return result == 0
if __name__ == '__main__':
    # Refuse to start if port 8000 is already in use (the omniparser
    # server is expected to bind it); ask the user to free the port first.
    if is_port_occupied(8000):
        print("8000端口被占用请先关闭占用该端口的进程")
        exit()
    run()

View File

@@ -12,7 +12,7 @@ import uvicorn
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
# from util.omniparser import Omniparser
from util.vision_agent import VisionAgent
from gradio_ui.agent.vision_agent import VisionAgent
def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API')