diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py index f9d30ba..2dc3d57 100644 --- a/gradio_ui/agent/task_run_agent.py +++ b/gradio_ui/agent/task_run_agent.py @@ -32,8 +32,7 @@ class TaskRunAgent(BaseAgent): def __call__(self, task): res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse) response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0)) - vlm_response_json = self.extract_data(res, "json") - return response_message, vlm_response_json + return response_message def extract_data(self, input_string, data_type): # Regular expression to extract content starting from '```python' until the end if there are no closing backticks diff --git a/util/vision_agent.py b/gradio_ui/agent/vision_agent.py similarity index 100% rename from util/vision_agent.py rename to gradio_ui/agent/vision_agent.py diff --git a/gradio_ui/app.py b/gradio_ui/app.py index bceba5e..a232271 100644 --- a/gradio_ui/app.py +++ b/gradio_ui/app.py @@ -14,6 +14,7 @@ from anthropic import APIResponse from anthropic.types import TextBlock from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock from anthropic.types.tool_use_block import ToolUseBlock +from gradio_ui.agent.vision_agent import VisionAgent from gradio_ui.loop import ( sampling_loop_sync, ) @@ -156,7 +157,7 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b # print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)") -def process_input(user_input, state): +def process_input(user_input, state, vision_agent): # Reset the stop flag if state["stop"]: state["stop"] = False @@ -185,7 +186,8 @@ def process_input(user_input, state): only_n_most_recent_images=state["only_n_most_recent_images"], max_tokens=8000, omniparser_url=args.omniparser_server_url, - base_url = state["base_url"] + base_url = state["base_url"], + vision_agent = vision_agent ): if loop_msg is None or state.get("stop"): yield state['chatbot_messages'] @@ -314,7 +316,9 @@ def run(): model.change(fn=update_model, inputs=[model, state], outputs=[api_key]) api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None) chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot]) - submit_button.click(process_input, [chat_input, state], chatbot) + vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt", + caption_model_path="./weights/icon_caption") + submit_button.click(process_input, [chat_input, state, vision_agent], chatbot) stop_button.click(stop_app, [state], None) base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None) demo.launch(server_name="0.0.0.0", server_port=7888) diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py index 57d51f2..e06ff2a 100644 --- a/gradio_ui/loop.py +++ b/gradio_ui/loop.py @@ -1,13 +1,11 @@ """ Agentic sampling loop that calls the Anthropic API and local implenmentation of anthropic-defined computer use tools. """ +import base64 from collections.abc import Callable -from enum import StrEnum - +from gradio_ui.agent.vision_agent import VisionAgent +from gradio_ui.tools.screen_capture import get_screenshot from anthropic import APIResponse -from anthropic.types import ( - TextBlock, -) from anthropic.types.beta import ( BetaContentBlock, BetaMessage, @@ -16,10 +14,12 @@ from anthropic.types.beta import ( from gradio_ui.agent.task_plan_agent import TaskPlanAgent from gradio_ui.agent.task_run_agent import TaskRunAgent from gradio_ui.tools import ToolResult - +from gradio_ui.agent.llm_utils.utils import encode_image from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient from gradio_ui.agent.vlm_agent import VLMAgent from gradio_ui.executor.anthropic_executor import AnthropicExecutor +from pathlib import Path +OUTPUT_DIR = "./tmp/outputs" def sampling_loop_sync( *, @@ -32,13 +32,14 @@ def sampling_loop_sync( only_n_most_recent_images: int | None = 0, max_tokens: int = 4096, omniparser_url: str, - base_url: str + base_url: str, + vision_agent: VisionAgent ): """ Synchronous agentic sampling loop for the assistant/tool interaction of computer use. """ print('in sampling_loop_sync, model:', model) - omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/") + # omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/") task_plan_agent = TaskPlanAgent() # actor = VLMAgent( # model=model, @@ -57,14 +58,22 @@ def sampling_loop_sync( tool_result_content = None print(f"Start the message loop. User messages: {messages}") - plan = task_plan_agent(user_task = messages[-1]["content"][0]) + plan = task_plan_agent(user_task = messages[-1]["content"][0].text) task_run_agent = TaskRunAgent() + while True: - parsed_screen = omniparser_client() + parsed_screen = parse_screen(vision_agent) # tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen) - task_run_agent(plan, parsed_screen) + tools_use_needed = task_run_agent(plan, parsed_screen) for message, tool_result_content in executor(tools_use_needed, messages): yield message if not tool_result_content: - return messages \ No newline at end of file + return messages + +def parse_screen(vision_agent: VisionAgent): + _, screenshot_path = get_screenshot() + screenshot_path = str(screenshot_path) + image_base64 = encode_image(screenshot_path) + parsed_screen = vision_agent(image_base64) + return parsed_screen diff --git a/main.py b/main.py index 845a0ca..73bea7f 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,6 @@ -import subprocess -from threading import Thread -import time -import requests from gradio_ui import app from util import download_weights import torch -import socket def run(): try: @@ -16,63 +11,10 @@ def run(): except Exception: print("显卡驱动不适配,请根据readme安装合适版本的 torch!") - - server_process = subprocess.Popen( - ["python", "./omniserver.py"], - stdout=subprocess.PIPE, # 捕获标准输出 - stderr=subprocess.PIPE, - text=True - ) - - stdout_thread = Thread( - target=stream_reader, - args=(server_process.stdout, "SERVER-OUT") - ) - - stderr_thread = Thread( - target=stream_reader, - args=(server_process.stderr, "SERVER-ERR") - ) - stdout_thread.daemon = True - stderr_thread.daemon = True - stdout_thread.start() - stderr_thread.start() + # 下载权重文件 + download_weights.download() + app.run() - try: - # 下载权重文件 - download_weights.download() - print("启动Omniserver服务中,因为加载模型真的超级慢,请耐心等待!") - while True: - try: - res = requests.get("http://127.0.0.1:8000/probe/", timeout=5) - if res.status_code == 200: - print("Omniparser服务启动成功...") - break - except (requests.ConnectionError, requests.Timeout): - pass - if server_process.poll() is not None: - raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}") - print("等待服务启动...") - time.sleep(10) - - app.run() - finally: - if server_process.poll() is None: # 如果进程还在运行 - server_process.terminate() # 发送终止信号 - server_process.wait(timeout=8) # 等待进程结束 - -def stream_reader(pipe, prefix): - for line in pipe: - print(f"[{prefix}]", line, end="", flush=True) - -def is_port_occupied(port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 - if __name__ == '__main__': - # 检测8000端口是否被占用 - if is_port_occupied(8000): - print("8000端口被占用,请先关闭占用该端口的进程") - exit() run() \ No newline at end of file diff --git a/omniserver.py b/omniserver.py index 8a16e54..f4cb916 100644 --- a/omniserver.py +++ b/omniserver.py @@ -12,7 +12,7 @@ import uvicorn root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(root_dir) # from util.omniparser import Omniparser -from util.vision_agent import VisionAgent +from gradio_ui.agent.vision_agent import VisionAgent def parse_arguments(): parser = argparse.ArgumentParser(description='Omniparser API')