mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 13:07:17 +08:00
修复cuda兼容问题
This commit is contained in:
@@ -6,9 +6,7 @@ from xbrain.core.chat import run
|
||||
import platform
|
||||
import re
|
||||
class TaskRunAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
print("TaskRunAgent initialized without a task")
|
||||
def __call__(self,task_plan: str, screen_info):
|
||||
def __call__(self, task_plan: str, screen_info):
|
||||
self.OUTPUT_DIR = "./tmp/outputs"
|
||||
device = self.get_device()
|
||||
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import time
|
||||
from pydantic import BaseModel
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
class UIElement(BaseModel):
|
||||
element_id: int
|
||||
@@ -19,36 +20,125 @@ class UIElement(BaseModel):
|
||||
class VisionAgent:
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
    """Initialize the vision agent.

    Loads the YOLO detector, the caption processor and caption model, and
    the EasyOCR reader, then picks a device-dependent caption batch size.
    If the caption model cannot be loaded directly onto the selected
    device, it is loaded on CPU first and moved; if that also fails the
    agent falls back to running on CPU.

    Args:
        yolo_model_path: Path to the YOLO model weights.
        caption_model_path: Name/path of the captioning model
            (defaults to Florence-2).
    """
    # Select the device and report the detected dtype.
    self.device, self.dtype = self._get_optimal_device_and_dtype()
    print(f"使用设备: {self.device}, 数据类型: {self.dtype}")

    # Load the YOLO model.
    self.yolo_model = YOLO(yolo_model_path)

    # Load the caption processor.
    # NOTE(review): the processor always comes from "microsoft/Florence-2-base",
    # independent of caption_model_path - confirm this is intentional.
    self.caption_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base",
        trust_remote_code=True
    )

    # CUDA loads in float16; MPS and CPU load in float32.
    # NOTE(review): this ignores the float32 choice that
    # _get_optimal_device_and_dtype makes for older GPUs - confirm which
    # of the two policies is wanted.
    load_dtype = torch.float16 if self.device.type == 'cuda' else torch.float32

    try:
        print(f"正在加载图像描述模型: {caption_model_path}")
        self.caption_model = AutoModelForCausalLM.from_pretrained(
            caption_model_path,
            torch_dtype=load_dtype,
            trust_remote_code=True
        ).to(self.device)
        print("图像描述模型加载成功")
    except Exception as e:
        print(f"加载图像描述模型失败: {e}")
        print("尝试使用备用方法加载...")

        try:
            # Fallback: load on CPU in float32 first, then convert and move.
            model = AutoModelForCausalLM.from_pretrained(
                caption_model_path,
                torch_dtype=torch.float32,
                trust_remote_code=True
            )

            if self.device.type == 'cuda':
                try:
                    model = model.to(dtype=torch.float16)
                except Exception:
                    # Keep the float32 weights if the conversion fails.
                    print("转换为float16失败,使用float32")

            self.caption_model = model.to(self.device)
            print("使用备用方法加载成功")
        except Exception as e2:
            print(f"备用加载方法也失败: {e2}")
            print("回退到CPU模式")
            self.device = torch.device("cpu")
            self.caption_model = AutoModelForCausalLM.from_pretrained(
                caption_model_path,
                torch_dtype=torch.float32,
                trust_remote_code=True
            ).to(self.device)

    # Keep self.dtype in step with the dtype the model actually ended up in,
    # whichever of the load paths above succeeded.
    self.dtype = next(self.caption_model.parameters()).dtype

    # Prompt passed to the caption processor for every crop.
    self.prompt = "<CAPTION>"

    # Caption batch size per device type (CPU gets the smallest).
    batch_sizes = {'cuda': 128, 'mps': 32}
    self.batch_size = batch_sizes.get(self.device.type, 16)

    self.elements: List[UIElement] = []
    self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
|
||||
|
||||
def _get_available_device(self) -> str:
|
||||
def _get_optimal_device_and_dtype(self):
|
||||
"""确定最佳设备和数据类型"""
|
||||
if torch.cuda.is_available():
|
||||
return 'cuda'
|
||||
if torch.backends.mps.is_available():
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
def _get_dtype(self)-> torch.dtype:
|
||||
if torch.cuda.is_available():
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
device = torch.device("cuda")
|
||||
# 检查GPU是否适合使用float16
|
||||
capability = torch.cuda.get_device_capability()
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
print(f"检测到CUDA设备: {gpu_name}, 计算能力: {capability}")
|
||||
|
||||
# 只在较新的GPU上使用float16
|
||||
if capability[0] >= 7: # Volta及以上架构
|
||||
dtype = torch.float16
|
||||
print("使用float16精度")
|
||||
else:
|
||||
dtype = torch.float32
|
||||
print("GPU计算能力较低,使用float32精度")
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
dtype = torch.float32 # MPS上使用float32更安全
|
||||
print("检测到MPS设备(Apple Silicon),使用float32精度")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
print("未检测到GPU,使用CPU和float32精度")
|
||||
|
||||
return device, dtype
|
||||
|
||||
def _reset_state(self):
|
||||
"""Clear previous analysis results"""
|
||||
@@ -73,12 +163,11 @@ class VisionAgent:
|
||||
ocr_time = (end-start) * 10 ** 3
|
||||
print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")
|
||||
start = time.time()
|
||||
element_captions = self._get_caption(element_crops)
|
||||
element_captions = self._get_caption(element_crops, 5)
|
||||
end = time.time()
|
||||
caption_time = (end-start) * 10 ** 3
|
||||
print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")
|
||||
for idx in range(len(element_crops)):
|
||||
print(idx, boxes[idx], element_texts[idx], element_captions[idx])
|
||||
new_element = UIElement(element_id=idx,
|
||||
coordinates=boxes[idx],
|
||||
text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
|
||||
@@ -101,42 +190,128 @@ class VisionAgent:
|
||||
# print(texts)
|
||||
return texts
|
||||
|
||||
|
||||
def _get_caption(self, images: np.ndarray, batch_size: int = 1) -> list[str]:
|
||||
"""Run captioning in batched mode. TODO: adjust batch size"""
|
||||
prompt = "<CAPTION>"
|
||||
def _get_caption(self, element_crops, batch_size=None):
|
||||
"""获取图像元素的描述"""
|
||||
if not element_crops:
|
||||
return []
|
||||
|
||||
# 如果未指定批处理大小,使用实例的默认值
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
# 调整图像尺寸为64x64
|
||||
resized_crops = []
|
||||
for img in element_crops:
|
||||
# 转换为numpy数组,调整大小,再转回PIL
|
||||
img_np = np.array(img)
|
||||
resized_np = cv2.resize(img_np, (64, 64))
|
||||
resized_crops.append(Image.fromarray(resized_np))
|
||||
|
||||
generated_texts = []
|
||||
resized_images = []
|
||||
for image in images:
|
||||
resized_image = cv2.resize(image, (64, 64))
|
||||
resized_images.append(resized_image)
|
||||
|
||||
for i in range(0, len(resized_images), batch_size):
|
||||
batch_images = resized_images[i:i+batch_size]
|
||||
inputs = self.caption_processor(
|
||||
images=batch_images,
|
||||
text=[prompt] * len(batch_images),
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
generated_ids = self.caption_model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=10,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
early_stopping=False,
|
||||
)
|
||||
|
||||
generated_text = self.caption_processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=True
|
||||
)
|
||||
generated_texts.extend([gen.strip() for gen in generated_text])
|
||||
|
||||
device = self.device
|
||||
|
||||
# 分批处理
|
||||
for i in range(0, len(resized_crops), batch_size):
|
||||
batch = resized_crops[i:i+batch_size]
|
||||
try:
|
||||
# 根据设备类型选择数据类型
|
||||
if device.type == 'cuda':
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt",
|
||||
do_resize=False # 避免处理器改变图像尺寸
|
||||
).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
# MPS和CPU使用float32
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt"
|
||||
).to(device=device)
|
||||
|
||||
# 针对Florence-2的特殊处理
|
||||
with torch.no_grad():
|
||||
if 'florence' in self.caption_model.config.model_type:
|
||||
generated_ids = self.caption_model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=20,
|
||||
num_beams=1,
|
||||
do_sample=False
|
||||
)
|
||||
else:
|
||||
generated_ids = self.caption_model.generate(
|
||||
**inputs,
|
||||
max_length=50,
|
||||
num_beams=3,
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# 解码生成的ID
|
||||
texts = self.caption_processor.batch_decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
texts = [text.strip() for text in texts]
|
||||
generated_texts.extend(texts)
|
||||
|
||||
# 清理缓存
|
||||
if device.type == 'cuda' and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except RuntimeError as e:
|
||||
print(f"批次处理失败: {e}")
|
||||
|
||||
# 如果是CUDA错误,尝试在CPU上处理
|
||||
if "CUDA" in str(e) or "cuda" in str(e):
|
||||
print("尝试在CPU上处理此批次...")
|
||||
try:
|
||||
# 临时将模型移至CPU
|
||||
self.caption_model = self.caption_model.to("cpu")
|
||||
|
||||
# 在CPU上处理
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
if 'florence' in self.caption_model.config.model_type:
|
||||
generated_ids = self.caption_model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=20,
|
||||
num_beams=1,
|
||||
do_sample=False
|
||||
)
|
||||
else:
|
||||
generated_ids = self.caption_model.generate(
|
||||
**inputs,
|
||||
max_length=50,
|
||||
num_beams=3,
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
texts = self.caption_processor.batch_decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
texts = [text.strip() for text in texts]
|
||||
generated_texts.extend(texts)
|
||||
|
||||
# 处理完成后将模型移回原设备
|
||||
self.caption_model = self.caption_model.to(self.device)
|
||||
except Exception as cpu_e:
|
||||
print(f"CPU处理也失败: {cpu_e}")
|
||||
generated_texts.extend(["[描述生成失败]"] * len(batch))
|
||||
else:
|
||||
# 非CUDA错误,直接添加占位符
|
||||
generated_texts.extend(["[描述生成失败]"] * len(batch))
|
||||
|
||||
return generated_texts
|
||||
|
||||
|
||||
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
|
||||
"""Run object detection pipeline"""
|
||||
results = self.yolo_model(image)[0]
|
||||
@@ -203,12 +378,14 @@ class VisionAgent:
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
def __call__(self, image_path: str) -> List[UIElement]:
    """Analyze the image stored at ``image_path``.

    Args:
        image_path: Filesystem path of the image, read with ``cv2.imread``.

    Returns:
        The result of ``self.analyze_image`` for the loaded image.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for the path.
    """
    image = cv2.imread(image_path)
    if image is None:
        # Include the path so the failing screenshot can be identified.
        raise FileNotFoundError(f"Vision agent: 图片读取失败: {image_path}")

    return self.analyze_image(image)
|
||||
|
||||
|
||||
|
||||
@@ -42,12 +42,23 @@ class Sender(StrEnum):
|
||||
BOT = "assistant"
|
||||
TOOL = "tool"
|
||||
def setup_state(state):
|
||||
# 如果存在config,则从config中加载数据
|
||||
config = Config()
|
||||
if config.OPENAI_API_KEY:
|
||||
state["api_key"] = config.OPENAI_API_KEY
|
||||
else:
|
||||
state["api_key"] = ""
|
||||
if config.OPENAI_BASE_URL:
|
||||
state["base_url"] = config.OPENAI_BASE_URL
|
||||
else:
|
||||
state["base_url"] = "https://api.openai-next.com/v1"
|
||||
if config.OPENAI_MODEL:
|
||||
state["model"] = config.OPENAI_MODEL
|
||||
else:
|
||||
state["model"] = "gpt-4o"
|
||||
|
||||
if "messages" not in state:
|
||||
state["messages"] = []
|
||||
if "model" not in state:
|
||||
state["model"] = "gpt-4o"
|
||||
if "api_key" not in state:
|
||||
state["api_key"] = ""
|
||||
if "auth_validated" not in state:
|
||||
state["auth_validated"] = False
|
||||
if "responses" not in state:
|
||||
@@ -60,8 +71,6 @@ def setup_state(state):
|
||||
state['chatbot_messages'] = []
|
||||
if 'stop' not in state:
|
||||
state['stop'] = False
|
||||
if 'base_url' not in state:
|
||||
state['base_url'] = "https://api.openai-next.com/v1"
|
||||
|
||||
async def main(state):
|
||||
"""Render loop for Gradio"""
|
||||
@@ -157,7 +166,7 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
|
||||
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
|
||||
|
||||
|
||||
def process_input(user_input, state, vision_agent):
|
||||
def process_input(user_input, state, vision_agent_state):
|
||||
# Reset the stop flag
|
||||
if state["stop"]:
|
||||
state["stop"] = False
|
||||
@@ -176,6 +185,7 @@ def process_input(user_input, state, vision_agent):
|
||||
state['chatbot_messages'].append((user_input, None)) # 确保格式正确
|
||||
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
|
||||
# Run sampling_loop_sync with the chatbot_output_callback
|
||||
agent = vision_agent_state["agent"]
|
||||
for loop_msg in sampling_loop_sync(
|
||||
model=state["model"],
|
||||
messages=state["messages"],
|
||||
@@ -187,7 +197,7 @@ def process_input(user_input, state, vision_agent):
|
||||
max_tokens=8000,
|
||||
omniparser_url=args.omniparser_server_url,
|
||||
base_url = state["base_url"],
|
||||
vision_agent = vision_agent
|
||||
vision_agent = agent
|
||||
):
|
||||
if loop_msg is None or state.get("stop"):
|
||||
yield state['chatbot_messages']
|
||||
@@ -248,14 +258,14 @@ def run():
|
||||
with gr.Column():
|
||||
model = gr.Textbox(
|
||||
label="Model",
|
||||
value="gpt-4o",
|
||||
value=state.value["model"],
|
||||
placeholder="输入模型名称",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
base_url = gr.Textbox(
|
||||
label="Base URL",
|
||||
value="https://api.openai-next.com/v1",
|
||||
value=state.value["base_url"],
|
||||
placeholder="输入基础 URL",
|
||||
interactive=True
|
||||
)
|
||||
@@ -272,7 +282,7 @@ def run():
|
||||
api_key = gr.Textbox(
|
||||
label="API Key",
|
||||
type="password",
|
||||
value=state.value.get("api_key", ""),
|
||||
value=state.value["api_key"],
|
||||
placeholder="Paste your API key here",
|
||||
interactive=True,
|
||||
)
|
||||
@@ -291,13 +301,8 @@ def run():
|
||||
autoscroll=True,
|
||||
height=580 )
|
||||
|
||||
def update_model(model, state):
    """Store the value of the Model textbox in the session state.

    Returns nothing: the ``model.change`` event is wired with
    ``outputs=None``, so no component update is produced.
    """
    state["model"] = model
|
||||
|
||||
def update_api_key(api_key_value, state):
|
||||
state["api_key"] = api_key_value
|
||||
@@ -313,12 +318,13 @@ def run():
|
||||
state['chatbot_messages'] = []
|
||||
return state['chatbot_messages']
|
||||
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=None)
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt",
|
||||
caption_model_path="./weights/icon_caption")
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent], chatbot)
|
||||
caption_model_path="./weights/icon_caption")
|
||||
vision_agent_state = gr.State({"agent": vision_agent})
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot)
|
||||
stop_button.click(stop_app, [state], None)
|
||||
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7888)
|
||||
|
||||
@@ -54,10 +54,7 @@ def sampling_loop_sync(
|
||||
output_callback=output_callback,
|
||||
tool_output_callback=tool_output_callback,
|
||||
)
|
||||
|
||||
tool_result_content = None
|
||||
|
||||
print(f"Start the message loop. User messages: {messages}")
|
||||
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
|
||||
task_run_agent = TaskRunAgent()
|
||||
|
||||
@@ -65,7 +62,7 @@ def sampling_loop_sync(
|
||||
while True:
|
||||
parsed_screen = parse_screen(vision_agent)
|
||||
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
||||
tools_use_needed = task_run_agent(plan, parsed_screen)
|
||||
tools_use_needed = task_run_agent(task_plan=plan, screen_info=parsed_screen)
|
||||
for message, tool_result_content in executor(tools_use_needed, messages):
|
||||
yield message
|
||||
if not tool_result_content:
|
||||
@@ -73,7 +70,5 @@ def sampling_loop_sync(
|
||||
|
||||
def parse_screen(vision_agent: VisionAgent):
    """Take a screenshot and run the vision agent on it.

    Args:
        vision_agent: Callable agent that accepts an image file path.

    Returns:
        Whatever ``vision_agent`` returns for the screenshot.
    """
    # get_screenshot() returns a pair; only the second item (the path) is used.
    _, screenshot_path = get_screenshot()
    # The agent is given the path as a string, not base64 image data.
    return vision_agent(str(screenshot_path))
|
||||
|
||||
Reference in New Issue
Block a user