Merge remote-tracking branch 'origin/master' into dev

This commit is contained in:
yuruo 2025-03-06 18:21:14 +08:00
commit fc7bc3e85c
3 changed files with 22 additions and 13 deletions

28
main.py
View File

@ -2,26 +2,40 @@ import argparse
import subprocess
import signal
import sys
import platform
from gradio_ui import app
from util import download_weights
import time
import torch
def run():
try:
print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True
print("MPS is_available: ", torch.backends.mps.is_available())
print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1
print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称
except Exception:
print("显卡驱动不适配请根据readme安装合适版本的 torch")
# 启动 server.py 子进程,并捕获其输出
server_process = subprocess.Popen(
["python", "./server.py"],
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
text=True
)
# Windows:
if platform.system() == 'Windows':
server_process = subprocess.Popen(
["python", "./server.py"],
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
text=True
)
else:
server_process = subprocess.Popen(
["python", "./server.py"],
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE,
start_new_session=True,
text=True
)
try:
# 下载权重文件

View File

@ -17,7 +17,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(description='autoMate API')
parser.add_argument('--som_model_path', type=str, default='./weights/icon_detect/model.pt', help='Path to the som model')
parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption_florence', help='Path to the caption model')
parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption', help='Path to the caption model')
parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')

View File

@ -51,11 +51,6 @@ def download():
print(f"下载失败: {file},正在重试...")
continue
# 重命名目录
old_path = weights_dir / "icon_caption"
new_path = weights_dir / "icon_caption_florence"
if old_path.exists() and not new_path.exists():
old_path.rename(new_path)
print("下载完成")
if __name__ == "__main__":