Mirror of https://github.com/yuruotong1/autoMate.git
Synced 2026-03-22 13:07:17 +08:00
Commit message: 更新readme (Update README)
This commit is contained in:
@@ -20,64 +20,58 @@ class UIElement(BaseModel):
|
||||
class VisionAgent:
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
|
||||
"""
|
||||
初始化视觉代理
|
||||
Initialize the vision agent
|
||||
|
||||
参数:
|
||||
yolo_model_path: YOLO模型路径
|
||||
caption_model_path: 图像描述模型路径,默认为Florence-2
|
||||
Parameters:
|
||||
yolo_model_path: Path to YOLO model
|
||||
caption_model_path: Path to image caption model, default is Florence-2
|
||||
"""
|
||||
# 确定可用的设备和最佳数据类型
|
||||
self.device, self.dtype = self._get_optimal_device_and_dtype()
|
||||
print(f"使用设备: {self.device}, 数据类型: {self.dtype}")
|
||||
|
||||
# 加载YOLO模型
|
||||
# determine the available device and the best dtype
|
||||
self.device, self.dtype = self._get_optimal_device_and_dtype()
|
||||
# load the YOLO model
|
||||
self.yolo_model = YOLO(yolo_model_path)
|
||||
|
||||
# 加载图像描述模型和处理器
|
||||
# load the image caption model and processor
|
||||
self.caption_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# 根据设备类型加载模型
|
||||
# load the model according to the device type
|
||||
try:
|
||||
print(f"正在加载图像描述模型: {caption_model_path}")
|
||||
if self.device.type == 'cuda':
|
||||
# CUDA设备使用float16
|
||||
# CUDA device uses float16
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
elif self.device.type == 'mps':
|
||||
# MPS设备使用float32(MPS对float16支持有限)
|
||||
# MPS device uses float32 (MPS has limited support for float16)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
else:
|
||||
# CPU使用float32
|
||||
# CPU uses float32
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
|
||||
print("图像描述模型加载成功")
|
||||
except Exception as e:
|
||||
print(f"加载图像描述模型失败: {e}")
|
||||
|
||||
# 设置提示词
|
||||
raise e
|
||||
self.prompt = "<CAPTION>"
|
||||
|
||||
# 设置批处理大小
|
||||
# set the batch size
|
||||
if self.device.type == 'cuda':
|
||||
self.batch_size = 128 # CUDA设备使用较大批处理大小
|
||||
self.batch_size = 128
|
||||
elif self.device.type == 'mps':
|
||||
self.batch_size = 32 # MPS设备使用中等批处理大小
|
||||
self.batch_size = 32
|
||||
else:
|
||||
self.batch_size = 16 # CPU使用较小批处理大小
|
||||
self.batch_size = 16
|
||||
|
||||
self.elements: List[UIElement] = []
|
||||
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
|
||||
@@ -87,33 +81,26 @@ class VisionAgent:
|
||||
# image = self.load_image(image_source)
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise FileNotFoundError(f"Vision agent: 图片读取失败")
|
||||
raise FileNotFoundError(f"Vision agent: Failed to read image")
|
||||
return self.analyze_image(image)
|
||||
|
||||
def _get_optimal_device_and_dtype(self):
|
||||
"""确定最佳设备和数据类型"""
|
||||
"""determine the optimal device and dtype"""
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# 检查GPU是否适合使用float16
|
||||
# check if the GPU is suitable for using float16
|
||||
capability = torch.cuda.get_device_capability()
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
print(f"检测到CUDA设备: {gpu_name}, 计算能力: {capability}")
|
||||
|
||||
# 只在较新的GPU上使用float16
|
||||
if capability[0] >= 7: # Volta及以上架构
|
||||
# only use float16 on newer GPUs
|
||||
if capability[0] >= 7:
|
||||
dtype = torch.float16
|
||||
print("使用float16精度")
|
||||
else:
|
||||
dtype = torch.float32
|
||||
print("GPU计算能力较低,使用float32精度")
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
dtype = torch.float32 # MPS上使用float32更安全
|
||||
print("检测到MPS设备(Apple Silicon),使用float32精度")
|
||||
dtype = torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
print("未检测到GPU,使用CPU和float32精度")
|
||||
|
||||
return device, dtype
|
||||
|
||||
@@ -168,18 +155,18 @@ class VisionAgent:
|
||||
return texts
|
||||
|
||||
def _get_caption(self, element_crops, batch_size=None):
|
||||
"""获取图像元素的描述"""
|
||||
"""get the caption of the element crops"""
|
||||
if not element_crops:
|
||||
return []
|
||||
|
||||
# 如果未指定批处理大小,使用实例的默认值
|
||||
# if batch_size is not specified, use the instance's default value
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
# 调整图像尺寸为64x64
|
||||
# resize the image to 64x64
|
||||
resized_crops = []
|
||||
for img in element_crops:
|
||||
# 转换为numpy数组,调整大小,再转回PIL
|
||||
# convert to numpy array, resize, then convert back to PIL
|
||||
img_np = np.array(img)
|
||||
resized_np = cv2.resize(img_np, (64, 64))
|
||||
resized_crops.append(Image.fromarray(resized_np))
|
||||
@@ -187,27 +174,27 @@ class VisionAgent:
|
||||
generated_texts = []
|
||||
device = self.device
|
||||
|
||||
# 分批处理
|
||||
# process in batches
|
||||
for i in range(0, len(resized_crops), batch_size):
|
||||
batch = resized_crops[i:i+batch_size]
|
||||
try:
|
||||
# 根据设备类型选择数据类型
|
||||
# select the dtype according to the device type
|
||||
if device.type == 'cuda':
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt",
|
||||
do_resize=False # 避免处理器改变图像尺寸
|
||||
do_resize=False
|
||||
).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
# MPS和CPU使用float32
|
||||
# MPS and CPU use float32
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt"
|
||||
).to(device=device)
|
||||
|
||||
# 针对Florence-2的特殊处理
|
||||
# special treatment for Florence-2
|
||||
with torch.no_grad():
|
||||
if 'florence' in self.caption_model.config.model_type:
|
||||
generated_ids = self.caption_model.generate(
|
||||
@@ -225,7 +212,7 @@ class VisionAgent:
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# 解码生成的ID
|
||||
# decode the generated IDs
|
||||
texts = self.caption_processor.batch_decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True
|
||||
@@ -233,12 +220,12 @@ class VisionAgent:
|
||||
texts = [text.strip() for text in texts]
|
||||
generated_texts.extend(texts)
|
||||
|
||||
# 清理缓存
|
||||
# clean the cache
|
||||
if device.type == 'cuda' and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except RuntimeError as e:
|
||||
print(f"批次处理失败: {e}")
|
||||
raise e
|
||||
return generated_texts
|
||||
|
||||
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
|
||||
@@ -283,27 +270,27 @@ class VisionAgent:
|
||||
|
||||
def load_image(self, image_source: str) -> np.ndarray:
|
||||
try:
|
||||
# 处理可能存在的Data URL前缀(如 "data:image/png;base64,")
|
||||
# Handle potential Data URL prefix (like "data:image/png;base64,")
|
||||
if ',' in image_source:
|
||||
_, payload = image_source.split(',', 1)
|
||||
else:
|
||||
payload = image_source
|
||||
|
||||
# Base64解码 -> bytes -> numpy数组
|
||||
# Base64 decode -> bytes -> numpy array
|
||||
image_bytes = base64.b64decode(payload)
|
||||
np_array = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
|
||||
# OpenCV解码图像
|
||||
# OpenCV decode image
|
||||
image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("解码图片失败:无效的图片数据")
|
||||
raise ValueError("Failed to decode image: Invalid image data")
|
||||
|
||||
return self.analyze_image(image)
|
||||
|
||||
except (base64.binascii.Error, ValueError) as e:
|
||||
# 生成更清晰的错误信息
|
||||
error_msg = f"输入既不是有效的文件路径,也不是有效的Base64图片数据"
|
||||
# Generate clearer error message
|
||||
error_msg = f"Input is neither a valid file path nor valid Base64 image data"
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ from gradio_ui.loop import (
|
||||
from gradio_ui.tools import ToolResult
|
||||
import base64
|
||||
from xbrain.utils.config import Config
|
||||
|
||||
from util.download_weights import MODEL_DIR
|
||||
CONFIG_DIR = Path("~/.anthropic").expanduser()
|
||||
API_KEY_FILE = CONFIG_DIR / "api_key"
|
||||
|
||||
@@ -51,7 +53,7 @@ def setup_state(state):
|
||||
if config.OPENAI_BASE_URL:
|
||||
state["base_url"] = config.OPENAI_BASE_URL
|
||||
else:
|
||||
state["base_url"] = "https://api.openai-next.com/v1"
|
||||
state["base_url"] = "https://api.openai.com/v1"
|
||||
if config.OPENAI_MODEL:
|
||||
state["model"] = config.OPENAI_MODEL
|
||||
else:
|
||||
@@ -317,8 +319,8 @@ def run():
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=None)
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt",
|
||||
caption_model_path="./weights/icon_caption")
|
||||
vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"),
|
||||
caption_model_path=os.path.join(MODEL_DIR, "icon_caption"))
|
||||
vision_agent_state = gr.State({"agent": vision_agent})
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot)
|
||||
stop_button.click(stop_app, [state], None)
|
||||
|
||||
main.py (10 lines changed)
@@ -4,14 +4,14 @@ import torch
|
||||
|
||||
def run():
|
||||
try:
|
||||
print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True
|
||||
print("cuda is_available: ", torch.cuda.is_available())
|
||||
print("MPS is_available: ", torch.backends.mps.is_available())
|
||||
print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1
|
||||
print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称
|
||||
print("cuda device_count", torch.cuda.device_count())
|
||||
print("cuda device_name", torch.cuda.get_device_name(0))
|
||||
except Exception:
|
||||
print("显卡驱动不适配,请根据readme安装合适版本的 torch!")
|
||||
print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!")
|
||||
|
||||
# 下载权重文件
|
||||
# download the weight files
|
||||
download_weights.download()
|
||||
app.run()
|
||||
|
||||
|
||||
@@ -10,4 +10,5 @@ pyautogui==0.9.54
|
||||
anthropic[bedrock,vertex]>=0.37.1
|
||||
pyxbrain==1.1.31
|
||||
timm
|
||||
einops==0.8.0
|
||||
einops==0.8.0
|
||||
modelscope
|
||||
@@ -1,12 +1,14 @@
|
||||
import subprocess
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from modelscope import snapshot_download
|
||||
__WEIGHTS_DIR = Path("weights")
|
||||
MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
|
||||
def download():
|
||||
# 创建权重目录
|
||||
weights_dir = Path("weights")
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
# Create weights directory
|
||||
|
||||
__WEIGHTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 需要下载的文件列表
|
||||
# List of files to download
|
||||
files = [
|
||||
"icon_detect/train_args.yaml",
|
||||
"icon_detect/model.pt",
|
||||
@@ -16,42 +18,24 @@ def download():
|
||||
"icon_caption/model.safetensors"
|
||||
]
|
||||
|
||||
# 检查并下载缺失的文件
|
||||
# Check and download missing files
|
||||
missing_files = []
|
||||
for file in files:
|
||||
file_path = weights_dir / file
|
||||
if not file_path.exists():
|
||||
file_path = os.path.join(MODEL_DIR, file)
|
||||
if not os.path.exists(file_path):
|
||||
missing_files.append(file)
|
||||
break
|
||||
|
||||
if not missing_files:
|
||||
print("已经检测到模型文件!")
|
||||
print("Model files already detected!")
|
||||
return
|
||||
|
||||
print(f"未检测到模型文件,需要下载 {len(missing_files)} 个文件")
|
||||
# 下载缺失的文件
|
||||
max_retries = 3 # 最大重试次数
|
||||
for file in missing_files:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"正在下载: {file} (尝试 {attempt + 1}/{max_retries})")
|
||||
cmd = [
|
||||
"huggingface-cli",
|
||||
"download",
|
||||
"microsoft/OmniParser-v2.0",
|
||||
file,
|
||||
"--local-dir",
|
||||
"weights"
|
||||
]
|
||||
subprocess.run(cmd, check=True)
|
||||
break # 下载成功,跳出重试循环
|
||||
except subprocess.CalledProcessError as e:
|
||||
if attempt == max_retries - 1: # 最后一次尝试
|
||||
print(f"下载失败: {file},已达到最大重试次数")
|
||||
raise # 重新抛出异常
|
||||
print(f"下载失败: {file},正在重试...")
|
||||
continue
|
||||
|
||||
print("下载完成")
|
||||
snapshot_download(
|
||||
'AI-ModelScope/OmniParser-v2.0',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
print("Download complete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
download()
|
||||
Reference in New Issue
Block a user