From a6f34bb103b713e84b84b33568e9937a146e0d5d Mon Sep 17 00:00:00 2001 From: Dan Li Date: Thu, 6 Mar 2025 13:01:35 +0300 Subject: [PATCH] add mac support --- main.py | 28 +++++++++++++++++++++------- server.py | 2 +- util/download_weights.py | 5 ----- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 2d54c62..080bef7 100644 --- a/main.py +++ b/main.py @@ -2,26 +2,40 @@ import argparse import subprocess import signal import sys +import platform from gradio_ui import app from util import download_weights import time import torch + def run(): try: print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True + print("MPS is_available: ", torch.backends.mps.is_available()) print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1 print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称 except Exception: print("显卡驱动不适配,请根据readme安装合适版本的 torch!") # 启动 server.py 子进程,并捕获其输出 - server_process = subprocess.Popen( - ["python", "./server.py"], - stdout=subprocess.PIPE, # 捕获标准输出 - stderr=subprocess.PIPE, - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, - text=True - ) + # Windows: + if platform.system() == 'Windows': + server_process = subprocess.Popen( + ["python", "./server.py"], + stdout=subprocess.PIPE, # 捕获标准输出 + stderr=subprocess.PIPE, + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, + text=True + ) + else: + server_process = subprocess.Popen( + ["python", "./server.py"], + stdout=subprocess.PIPE, # 捕获标准输出 + stderr=subprocess.PIPE, + start_new_session=True, + text=True + ) + try: # 下载权重文件 diff --git a/server.py b/server.py index 9173ed9..23de09c 100644 --- a/server.py +++ b/server.py @@ -17,7 +17,7 @@ def parse_arguments(): parser = argparse.ArgumentParser(description='autoMate API') parser.add_argument('--som_model_path', type=str, default='./weights/icon_detect/model.pt', help='Path to the som model') parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model') - parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption_florence', help='Path to the caption model') + parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption', help='Path to the caption model') parser.add_argument('--device', type=str, default='cpu', help='Device to run the model') parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API') diff --git a/util/download_weights.py b/util/download_weights.py index ad3797e..6a9eeaf 100644 --- a/util/download_weights.py +++ b/util/download_weights.py @@ -51,11 +51,6 @@ def download(): print(f"下载失败: {file},正在重试...") continue - # 重命名目录 - old_path = weights_dir / "icon_caption" - new_path = weights_dir / "icon_caption_florence" - if old_path.exists() and not new_path.exists(): - old_path.rename(new_path) print("下载完成") if __name__ == "__main__":