mirror of
https://github.com/yuruotong1/autoMate.git
synced 2025-12-26 05:16:21 +08:00
Merge remote-tracking branch 'origin/master' into dev
This commit is contained in:
commit
fc7bc3e85c
28
main.py
28
main.py
@ -2,26 +2,40 @@ import argparse
|
||||
import subprocess
|
||||
import signal
|
||||
import sys
|
||||
import platform
|
||||
from gradio_ui import app
|
||||
from util import download_weights
|
||||
import time
|
||||
import torch
|
||||
|
||||
def run():
|
||||
try:
|
||||
print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True
|
||||
print("MPS is_available: ", torch.backends.mps.is_available())
|
||||
print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1
|
||||
print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称
|
||||
except Exception:
|
||||
print("显卡驱动不适配,请根据readme安装合适版本的 torch!")
|
||||
|
||||
# 启动 server.py 子进程,并捕获其输出
|
||||
server_process = subprocess.Popen(
|
||||
["python", "./server.py"],
|
||||
stdout=subprocess.PIPE, # 捕获标准输出
|
||||
stderr=subprocess.PIPE,
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||
text=True
|
||||
)
|
||||
# Windows:
|
||||
if platform.system() == 'Windows':
|
||||
server_process = subprocess.Popen(
|
||||
["python", "./server.py"],
|
||||
stdout=subprocess.PIPE, # 捕获标准输出
|
||||
stderr=subprocess.PIPE,
|
||||
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||
text=True
|
||||
)
|
||||
else:
|
||||
server_process = subprocess.Popen(
|
||||
["python", "./server.py"],
|
||||
stdout=subprocess.PIPE, # 捕获标准输出
|
||||
stderr=subprocess.PIPE,
|
||||
start_new_session=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 下载权重文件
|
||||
|
||||
@ -17,7 +17,7 @@ def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='autoMate API')
|
||||
parser.add_argument('--som_model_path', type=str, default='./weights/icon_detect/model.pt', help='Path to the som model')
|
||||
parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
|
||||
parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption_florence', help='Path to the caption model')
|
||||
parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption', help='Path to the caption model')
|
||||
parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
|
||||
parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
|
||||
|
||||
@ -51,11 +51,6 @@ def download():
|
||||
print(f"下载失败: {file},正在重试...")
|
||||
continue
|
||||
|
||||
# 重命名目录
|
||||
old_path = weights_dir / "icon_caption"
|
||||
new_path = weights_dir / "icon_caption_florence"
|
||||
if old_path.exists() and not new_path.exists():
|
||||
old_path.rename(new_path)
|
||||
print("下载完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user