修复cuda兼容问题

This commit is contained in:
yuruo
2025-03-11 21:31:39 +08:00
parent 30b97e53b1
commit b2d559f15a
4 changed files with 266 additions and 90 deletions

View File

@@ -6,9 +6,7 @@ from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
def __init__(self):
print("TaskRunAgent initialized without a task")
def __call__(self,task_plan: str, screen_info):
def __call__(self, task_plan: str, screen_info):
self.OUTPUT_DIR = "./tmp/outputs"
device = self.get_device()
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,

View File

@@ -9,6 +9,7 @@ import numpy as np
import time
from pydantic import BaseModel
import base64
from PIL import Image
class UIElement(BaseModel):
element_id: int
@@ -19,36 +20,125 @@ class UIElement(BaseModel):
class VisionAgent:
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
"""
Computer vision agent for UI analysis.
初始化视觉代理
Args:
yolo_model_path: Path to YOLO model weights
caption_model_path: Name/path to captioning model (defaults to Florence-2)
参数:
yolo_model_path: YOLO模型路径
caption_model_path: 图像描述模型路径,默认为Florence-2
"""
self.device = self._get_available_device()
self.dtype = self._get_dtype()
self.elements: List[UIElement] = []
# 确定可用的设备和最佳数据类型
self.device, self.dtype = self._get_optimal_device_and_dtype()
print(f"使用设备: {self.device}, 数据类型: {self.dtype}")
# 加载YOLO模型
self.yolo_model = YOLO(yolo_model_path)
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path, trust_remote_code=True
).to(self.device)
# 加载图像描述模型和处理器
self.caption_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base", trust_remote_code=True
"microsoft/Florence-2-base",
trust_remote_code=True
)
# 根据设备类型加载模型
try:
print(f"正在加载图像描述模型: {caption_model_path}")
if self.device.type == 'cuda':
# CUDA设备使用float16
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float16,
trust_remote_code=True
).to(self.device)
elif self.device.type == 'mps':
                # MPS设备使用float32（MPS对float16支持有限）
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
else:
# CPU使用float32
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
print("图像描述模型加载成功")
except Exception as e:
print(f"加载图像描述模型失败: {e}")
print("尝试使用备用方法加载...")
# 备用加载方法
try:
# 先加载到CPU再转移到目标设备
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
)
                # 如果是CUDA设备，尝试转换为float16
if self.device.type == 'cuda':
try:
self.caption_model = self.caption_model.to(dtype=torch.float16)
except:
print("转换为float16失败使用float32")
# 移动到目标设备
self.caption_model = self.caption_model.to(self.device)
print("使用备用方法加载成功")
except Exception as e2:
print(f"备用加载方法也失败: {e2}")
print("回退到CPU模式")
self.device = torch.device("cpu")
self.dtype = torch.float32
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
# 设置提示词
self.prompt = "<CAPTION>"
# 设置批处理大小
if self.device.type == 'cuda':
self.batch_size = 128 # CUDA设备使用较大批处理大小
elif self.device.type == 'mps':
self.batch_size = 32 # MPS设备使用中等批处理大小
else:
self.batch_size = 16 # CPU使用较小批处理大小
self.elements: List[UIElement] = []
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
def _get_available_device(self) -> str:
def _get_optimal_device_and_dtype(self):
"""确定最佳设备和数据类型"""
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
return 'mps'
return 'cpu'
def _get_dtype(self)-> torch.dtype:
if torch.cuda.is_available():
return torch.float16
return torch.float32
device = torch.device("cuda")
# 检查GPU是否适合使用float16
capability = torch.cuda.get_device_capability()
gpu_name = torch.cuda.get_device_name()
print(f"检测到CUDA设备: {gpu_name}, 计算能力: {capability}")
# 只在较新的GPU上使用float16
if capability[0] >= 7: # Volta及以上架构
dtype = torch.float16
print("使用float16精度")
else:
dtype = torch.float32
print("GPU计算能力较低使用float32精度")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = torch.device("mps")
dtype = torch.float32 # MPS上使用float32更安全
print("检测到MPS设备(Apple Silicon)使用float32精度")
else:
device = torch.device("cpu")
dtype = torch.float32
print("未检测到GPU使用CPU和float32精度")
return device, dtype
def _reset_state(self):
"""Clear previous analysis results"""
@@ -73,12 +163,11 @@ class VisionAgent:
ocr_time = (end-start) * 10 ** 3
print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")
start = time.time()
element_captions = self._get_caption(element_crops)
element_captions = self._get_caption(element_crops, 5)
end = time.time()
caption_time = (end-start) * 10 ** 3
print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")
for idx in range(len(element_crops)):
print(idx, boxes[idx], element_texts[idx], element_captions[idx])
new_element = UIElement(element_id=idx,
coordinates=boxes[idx],
text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
@@ -101,42 +190,128 @@ class VisionAgent:
# print(texts)
return texts
def _get_caption(self, images: np.ndarray, batch_size: int = 1) -> list[str]:
"""Run captioning in batched mode. TODO: adjust batch size"""
prompt = "<CAPTION>"
def _get_caption(self, element_crops, batch_size=None):
"""获取图像元素的描述"""
if not element_crops:
return []
# 如果未指定批处理大小,使用实例的默认值
if batch_size is None:
batch_size = self.batch_size
# 调整图像尺寸为64x64
resized_crops = []
for img in element_crops:
            # 转换为numpy数组，调整大小，再转回PIL
img_np = np.array(img)
resized_np = cv2.resize(img_np, (64, 64))
resized_crops.append(Image.fromarray(resized_np))
generated_texts = []
resized_images = []
for image in images:
resized_image = cv2.resize(image, (64, 64))
resized_images.append(resized_image)
for i in range(0, len(resized_images), batch_size):
batch_images = resized_images[i:i+batch_size]
inputs = self.caption_processor(
images=batch_images,
text=[prompt] * len(batch_images),
return_tensors="pt",
do_resize=True,
).to(device=self.device, dtype=self.dtype)
generated_ids = self.caption_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=10,
num_beams=1,
do_sample=False,
early_stopping=False,
)
generated_text = self.caption_processor.batch_decode(
generated_ids, skip_special_tokens=True
)
generated_texts.extend([gen.strip() for gen in generated_text])
device = self.device
# 分批处理
for i in range(0, len(resized_crops), batch_size):
batch = resized_crops[i:i+batch_size]
try:
# 根据设备类型选择数据类型
if device.type == 'cuda':
inputs = self.caption_processor(
images=batch,
text=[self.prompt] * len(batch),
return_tensors="pt",
do_resize=False # 避免处理器改变图像尺寸
).to(device=device, dtype=torch.float16)
else:
# MPS和CPU使用float32
inputs = self.caption_processor(
images=batch,
text=[self.prompt] * len(batch),
return_tensors="pt"
).to(device=device)
# 针对Florence-2的特殊处理
with torch.no_grad():
if 'florence' in self.caption_model.config.model_type:
generated_ids = self.caption_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=20,
num_beams=1,
do_sample=False
)
else:
generated_ids = self.caption_model.generate(
**inputs,
max_length=50,
num_beams=3,
early_stopping=True
)
# 解码生成的ID
texts = self.caption_processor.batch_decode(
generated_ids,
skip_special_tokens=True
)
texts = [text.strip() for text in texts]
generated_texts.extend(texts)
# 清理缓存
if device.type == 'cuda' and torch.cuda.is_available():
torch.cuda.empty_cache()
except RuntimeError as e:
print(f"批次处理失败: {e}")
                # 如果是CUDA错误，尝试在CPU上处理
if "CUDA" in str(e) or "cuda" in str(e):
print("尝试在CPU上处理此批次...")
try:
# 临时将模型移至CPU
self.caption_model = self.caption_model.to("cpu")
# 在CPU上处理
inputs = self.caption_processor(
images=batch,
text=[self.prompt] * len(batch),
return_tensors="pt"
)
with torch.no_grad():
if 'florence' in self.caption_model.config.model_type:
generated_ids = self.caption_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=20,
num_beams=1,
do_sample=False
)
else:
generated_ids = self.caption_model.generate(
**inputs,
max_length=50,
num_beams=3,
early_stopping=True
)
texts = self.caption_processor.batch_decode(
generated_ids,
skip_special_tokens=True
)
texts = [text.strip() for text in texts]
generated_texts.extend(texts)
# 处理完成后将模型移回原设备
self.caption_model = self.caption_model.to(self.device)
except Exception as cpu_e:
print(f"CPU处理也失败: {cpu_e}")
generated_texts.extend(["[描述生成失败]"] * len(batch))
else:
                    # 非CUDA错误，直接添加占位符
generated_texts.extend(["[描述生成失败]"] * len(batch))
return generated_texts
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
"""Run object detection pipeline"""
results = self.yolo_model(image)[0]
@@ -203,12 +378,14 @@ class VisionAgent:
raise ValueError(error_msg) from e
def __call__(self, image_source: str) -> List[UIElement]:
def __call__(self, image_path: str) -> List[UIElement]:
"""Process an image from file path."""
image = self.load_image(image_source)
# image = self.load_image(image_source)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"Vision agent: 图片读取失败")
return self.analyze_image(image)

View File

@@ -42,12 +42,23 @@ class Sender(StrEnum):
BOT = "assistant"
TOOL = "tool"
def setup_state(state):
# 如果存在config，则从config中加载数据
config = Config()
if config.OPENAI_API_KEY:
state["api_key"] = config.OPENAI_API_KEY
else:
state["api_key"] = ""
if config.OPENAI_BASE_URL:
state["base_url"] = config.OPENAI_BASE_URL
else:
state["base_url"] = "https://api.openai-next.com/v1"
if config.OPENAI_MODEL:
state["model"] = config.OPENAI_MODEL
else:
state["model"] = "gpt-4o"
if "messages" not in state:
state["messages"] = []
if "model" not in state:
state["model"] = "gpt-4o"
if "api_key" not in state:
state["api_key"] = ""
if "auth_validated" not in state:
state["auth_validated"] = False
if "responses" not in state:
@@ -60,8 +71,6 @@ def setup_state(state):
state['chatbot_messages'] = []
if 'stop' not in state:
state['stop'] = False
if 'base_url' not in state:
state['base_url'] = "https://api.openai-next.com/v1"
async def main(state):
"""Render loop for Gradio"""
@@ -157,7 +166,7 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
def process_input(user_input, state, vision_agent):
def process_input(user_input, state, vision_agent_state):
# Reset the stop flag
if state["stop"]:
state["stop"] = False
@@ -176,6 +185,7 @@ def process_input(user_input, state, vision_agent):
state['chatbot_messages'].append((user_input, None)) # 确保格式正确
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
# Run sampling_loop_sync with the chatbot_output_callback
agent = vision_agent_state["agent"]
for loop_msg in sampling_loop_sync(
model=state["model"],
messages=state["messages"],
@@ -187,7 +197,7 @@ def process_input(user_input, state, vision_agent):
max_tokens=8000,
omniparser_url=args.omniparser_server_url,
base_url = state["base_url"],
vision_agent = vision_agent
vision_agent = agent
):
if loop_msg is None or state.get("stop"):
yield state['chatbot_messages']
@@ -248,14 +258,14 @@ def run():
with gr.Column():
model = gr.Textbox(
label="Model",
value="gpt-4o",
value=state.value["model"],
placeholder="输入模型名称",
interactive=True,
)
with gr.Column():
base_url = gr.Textbox(
label="Base URL",
value="https://api.openai-next.com/v1",
value=state.value["base_url"],
placeholder="输入基础 URL",
interactive=True
)
@@ -272,7 +282,7 @@ def run():
api_key = gr.Textbox(
label="API Key",
type="password",
value=state.value.get("api_key", ""),
value=state.value["api_key"],
placeholder="Paste your API key here",
interactive=True,
)
@@ -291,13 +301,8 @@ def run():
autoscroll=True,
height=580 )
def update_model(model_selection, state):
state["model"] = model_selection
api_key_update = gr.update(
placeholder="API Key",
value=state["api_key"]
)
return api_key_update
def update_model(model, state):
state["model"] = model
def update_api_key(api_key_value, state):
state["api_key"] = api_key_value
@@ -313,12 +318,13 @@ def run():
state['chatbot_messages'] = []
return state['chatbot_messages']
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
model.change(fn=update_model, inputs=[model, state], outputs=None)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
vision_agent = VisionAgent(yolo_model_path="./weights/icon_detect/model.pt",
caption_model_path="./weights/icon_caption")
submit_button.click(process_input, [chat_input, state, vision_agent], chatbot)
caption_model_path="./weights/icon_caption")
vision_agent_state = gr.State({"agent": vision_agent})
submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot)
stop_button.click(stop_app, [state], None)
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
demo.launch(server_name="0.0.0.0", server_port=7888)

View File

@@ -54,10 +54,7 @@ def sampling_loop_sync(
output_callback=output_callback,
tool_output_callback=tool_output_callback,
)
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
task_run_agent = TaskRunAgent()
@@ -65,7 +62,7 @@ def sampling_loop_sync(
while True:
parsed_screen = parse_screen(vision_agent)
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
tools_use_needed = task_run_agent(plan, parsed_screen)
tools_use_needed = task_run_agent(task_plan=plan, screen_info=parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content:
@@ -73,7 +70,5 @@ def sampling_loop_sync(
def parse_screen(vision_agent: VisionAgent):
_, screenshot_path = get_screenshot()
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
parsed_screen = vision_agent(image_base64)
parsed_screen = vision_agent(str(screenshot_path))
return parsed_screen