From 82802ffebc6bb04a9549c5484967fd1519820b13 Mon Sep 17 00:00:00 2001 From: yuruo Date: Thu, 13 Mar 2025 17:17:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gradio_ui/agent/vision_agent.py | 95 ++++++++++++++------------------- gradio_ui/app.py | 8 +-- main.py | 10 ++-- requirements.txt | 3 +- util/download_weights.py | 54 +++++++------------ 5 files changed, 72 insertions(+), 98 deletions(-) diff --git a/gradio_ui/agent/vision_agent.py b/gradio_ui/agent/vision_agent.py index 8ec3e23..eaded60 100644 --- a/gradio_ui/agent/vision_agent.py +++ b/gradio_ui/agent/vision_agent.py @@ -20,64 +20,58 @@ class UIElement(BaseModel): class VisionAgent: def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'): """ - 初始化视觉代理 + Initialize the vision agent - 参数: - yolo_model_path: YOLO模型路径 - caption_model_path: 图像描述模型路径,默认为Florence-2 + Parameters: + yolo_model_path: Path to YOLO model + caption_model_path: Path to image caption model, default is Florence-2 """ - # 确定可用的设备和最佳数据类型 - self.device, self.dtype = self._get_optimal_device_and_dtype() - print(f"使用设备: {self.device}, 数据类型: {self.dtype}") - - # 加载YOLO模型 + # determine the available device and the best dtype + self.device, self.dtype = self._get_optimal_device_and_dtype() + # load the YOLO model self.yolo_model = YOLO(yolo_model_path) - # 加载图像描述模型和处理器 + # load the image caption model and processor self.caption_processor = AutoProcessor.from_pretrained( "microsoft/Florence-2-base", trust_remote_code=True ) - # 根据设备类型加载模型 + # load the model according to the device type try: - print(f"正在加载图像描述模型: {caption_model_path}") if self.device.type == 'cuda': - # CUDA设备使用float16 + # CUDA device uses float16 self.caption_model = AutoModelForCausalLM.from_pretrained( caption_model_path, torch_dtype=torch.float16, trust_remote_code=True ).to(self.device) elif self.device.type == 'mps': - # MPS设备使用float32(MPS对float16支持有限) + # MPS device uses float32 (MPS has limited support for float16) self.caption_model = AutoModelForCausalLM.from_pretrained( caption_model_path, torch_dtype=torch.float32, trust_remote_code=True ).to(self.device) else: - # CPU使用float32 + # CPU uses float32 self.caption_model = AutoModelForCausalLM.from_pretrained( caption_model_path, torch_dtype=torch.float32, trust_remote_code=True ).to(self.device) - print("图像描述模型加载成功") except Exception as e: - print(f"加载图像描述模型失败: {e}") - - # 设置提示词 + raise e self.prompt = "" - # 设置批处理大小 + # set the batch size if self.device.type == 'cuda': - self.batch_size = 128 # CUDA设备使用较大批处理大小 + self.batch_size = 128 elif self.device.type == 'mps': - self.batch_size = 32 # MPS设备使用中等批处理大小 + self.batch_size = 32 else: - self.batch_size = 16 # CPU使用较小批处理大小 + self.batch_size = 16 self.elements: List[UIElement] = [] self.ocr_reader = easyocr.Reader(['en', 'ch_sim']) @@ -87,33 +81,26 @@ class VisionAgent: # image = self.load_image(image_source) image = cv2.imread(image_path) if image is None: - raise FileNotFoundError(f"Vision agent: 图片读取失败") + raise FileNotFoundError(f"Vision agent: Failed to read image") return self.analyze_image(image) def _get_optimal_device_and_dtype(self): - """确定最佳设备和数据类型""" + """determine the optimal device and dtype""" if torch.cuda.is_available(): device = torch.device("cuda") - # 检查GPU是否适合使用float16 + # check if the GPU is suitable for using float16 capability = torch.cuda.get_device_capability() - gpu_name = torch.cuda.get_device_name() - print(f"检测到CUDA设备: {gpu_name}, 计算能力: {capability}") - - # 只在较新的GPU上使用float16 - if capability[0] >= 7: # Volta及以上架构 + # only use float16 on newer GPUs + if capability[0] >= 7: dtype = torch.float16 - print("使用float16精度") else: dtype = torch.float32 - print("GPU计算能力较低,使用float32精度") elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device("mps") - dtype = torch.float32 # MPS上使用float32更安全 - print("检测到MPS设备(Apple Silicon),使用float32精度") + dtype = torch.float32 else: device = torch.device("cpu") dtype = torch.float32 - print("未检测到GPU,使用CPU和float32精度") return device, dtype @@ -168,18 +155,18 @@ class VisionAgent: return texts def _get_caption(self, element_crops, batch_size=None): - """获取图像元素的描述""" + """get the caption of the element crops""" if not element_crops: return [] - # 如果未指定批处理大小,使用实例的默认值 + # if batch_size is not specified, use the instance's default value if batch_size is None: batch_size = self.batch_size - # 调整图像尺寸为64x64 + # resize the image to 64x64 resized_crops = [] for img in element_crops: - # 转换为numpy数组,调整大小,再转回PIL + # convert to numpy array, resize, then convert back to PIL img_np = np.array(img) resized_np = cv2.resize(img_np, (64, 64)) resized_crops.append(Image.fromarray(resized_np)) @@ -187,27 +174,27 @@ class VisionAgent: generated_texts = [] device = self.device - # 分批处理 + # process in batches for i in range(0, len(resized_crops), batch_size): batch = resized_crops[i:i+batch_size] try: - # 根据设备类型选择数据类型 + # select the dtype according to the device type if device.type == 'cuda': inputs = self.caption_processor( images=batch, text=[self.prompt] * len(batch), return_tensors="pt", - do_resize=False # 避免处理器改变图像尺寸 + do_resize=False ).to(device=device, dtype=torch.float16) else: - # MPS和CPU使用float32 + # MPS and CPU use float32 inputs = self.caption_processor( images=batch, text=[self.prompt] * len(batch), return_tensors="pt" ).to(device=device) - # 针对Florence-2的特殊处理 + # special treatment for Florence-2 with torch.no_grad(): if 'florence' in self.caption_model.config.model_type: generated_ids = self.caption_model.generate( @@ -225,7 +212,7 @@ class VisionAgent: early_stopping=True ) - # 解码生成的ID + # decode the generated IDs texts = self.caption_processor.batch_decode( generated_ids, skip_special_tokens=True @@ -233,12 +220,12 @@ class VisionAgent: texts = [text.strip() for text in texts] generated_texts.extend(texts) - # 清理缓存 + # clean the cache if device.type == 'cuda' and torch.cuda.is_available(): torch.cuda.empty_cache() except RuntimeError as e: - print(f"批次处理失败: {e}") + raise e return generated_texts def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]: @@ -283,27 +270,27 @@ class VisionAgent: def load_image(self, image_source: str) -> np.ndarray: try: - # 处理可能存在的Data URL前缀(如 "data:image/png;base64,") + # Handle potential Data URL prefix (like "data:image/png;base64,") if ',' in image_source: _, payload = image_source.split(',', 1) else: payload = image_source - # Base64解码 -> bytes -> numpy数组 + # Base64 decode -> bytes -> numpy array image_bytes = base64.b64decode(payload) np_array = np.frombuffer(image_bytes, dtype=np.uint8) - # OpenCV解码图像 + # OpenCV decode image image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) if image is None: - raise ValueError("解码图片失败:无效的图片数据") + raise ValueError("Failed to decode image: Invalid image data") return self.analyze_image(image) except (base64.binascii.Error, ValueError) as e: - # 生成更清晰的错误信息 - error_msg = f"输入既不是有效的文件路径,也不是有效的Base64图片数据" + # Generate clearer error message + error_msg = f"Input is neither a valid file path nor valid Base64 image data" raise ValueError(error_msg) from e diff --git a/gradio_ui/app.py b/gradio_ui/app.py index dc6d0e6..cf00a50 100644 --- a/gradio_ui/app.py +++ b/gradio_ui/app.py @@ -21,6 +21,8 @@ from gradio_ui.loop import ( from gradio_ui.tools import ToolResult import base64 from xbrain.utils.config import Config + +from util.download_weights import MODEL_DIR CONFIG_DIR = Path("~/.anthropic").expanduser() API_KEY_FILE = CONFIG_DIR / "api_key" @@ -51,7 +53,7 @@ def setup_state(state): if config.OPENAI_BASE_URL: state["base_url"] = config.OPENAI_BASE_URL else: - state["base_url"] = "https://api.openai-next.com/v1" + state["base_url"] = "https://api.openai.com/v1" if config.OPENAI_MODEL: state["model"] = config.OPENAI_MODEL else: @@ -317,8 +319,8 @@ def run(): model.change(fn=update_model, inputs=[model, state], outputs=None) api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None) chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot]) - vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt", - caption_model_path="./weights/icon_caption") + vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"), + caption_model_path=os.path.join(MODEL_DIR, "icon_caption")) vision_agent_state = gr.State({"agent": vision_agent}) submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot) stop_button.click(stop_app, [state], None) diff --git a/main.py b/main.py index 73bea7f..d550c2c 100644 --- a/main.py +++ b/main.py @@ -4,14 +4,14 @@ import torch def run(): try: - print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True + print("cuda is_available: ", torch.cuda.is_available()) print("MPS is_available: ", torch.backends.mps.is_available()) - print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1 - print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称 + print("cuda device_count", torch.cuda.device_count()) + print("cuda device_name", torch.cuda.get_device_name(0)) except Exception: - print("显卡驱动不适配,请根据readme安装合适版本的 torch!") + print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!") - # 下载权重文件 + # download the weight files download_weights.download() app.run() diff --git a/requirements.txt b/requirements.txt index 1518408..af48c6a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ pyautogui==0.9.54 anthropic[bedrock,vertex]>=0.37.1 pyxbrain==1.1.31 timm -einops==0.8.0 \ No newline at end of file +einops==0.8.0 +modelscope \ No newline at end of file diff --git a/util/download_weights.py b/util/download_weights.py index 6a9eeaf..fda22c6 100644 --- a/util/download_weights.py +++ b/util/download_weights.py @@ -1,12 +1,14 @@ -import subprocess +import os from pathlib import Path - +from modelscope import snapshot_download +__WEIGHTS_DIR = Path("weights") +MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0") def download(): - # 创建权重目录 - weights_dir = Path("weights") - weights_dir.mkdir(exist_ok=True) + # Create weights directory + + __WEIGHTS_DIR.mkdir(exist_ok=True) - # 需要下载的文件列表 + # List of files to download files = [ "icon_detect/train_args.yaml", "icon_detect/model.pt", @@ -16,42 +18,24 @@ def download(): "icon_caption/model.safetensors" ] - # 检查并下载缺失的文件 + # Check and download missing files missing_files = [] for file in files: - file_path = weights_dir / file - if not file_path.exists(): + file_path = os.path.join(MODEL_DIR, file) + if not os.path.exists(file_path): missing_files.append(file) + break if not missing_files: - print("已经检测到模型文件!") + print("Model files already detected!") return - print(f"未检测到模型文件,需要下载 {len(missing_files)} 个文件") - # 下载缺失的文件 - max_retries = 3 # 最大重试次数 - for file in missing_files: - for attempt in range(max_retries): - try: - print(f"正在下载: {file} (尝试 {attempt + 1}/{max_retries})") - cmd = [ - "huggingface-cli", - "download", - "microsoft/OmniParser-v2.0", - file, - "--local-dir", - "weights" - ] - subprocess.run(cmd, check=True) - break # 下载成功,跳出重试循环 - except subprocess.CalledProcessError as e: - if attempt == max_retries - 1: # 最后一次尝试 - print(f"下载失败: {file},已达到最大重试次数") - raise # 重新抛出异常 - print(f"下载失败: {file},正在重试...") - continue - - print("下载完成") + snapshot_download( + 'AI-ModelScope/OmniParser-v2.0', + cache_dir='weights' + ) + + print("Download complete") if __name__ == "__main__": download() \ No newline at end of file